"""Algorithm to check for 'comparison' nodes in a form when the user is in 'complex mode'."""
from ufl.algebra import Real
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.argument import Argument
from ufl.constantvalue import RealValue, Zero
from ufl.corealg.multifunction import MultiFunction
from ufl.geometry import GeometricQuantity
class CheckComparisons(MultiFunction):
    """Raises an error if comparisons are done with complex quantities.

    If quantities are real, adds the Real operator to the compared quantities.

    Terminals that are real are RealValue, Zero, Argument and
    GeometricQuantity (even in complex FEM, the basis functions and the
    geometry are real). All other terminals default to complex.

    Operations that produce reals are Abs, Real, Imag, and the ordering
    operators that return a number (MaxValue, MinValue, Sign). Sqrt, and Pow
    unless the base is real and the exponent a constant integer, are
    (defensively) treated as complex. Ordering comparisons (gt, lt, ge, le)
    produce "bool". Otherwise, operators preserve the type of their operands.

    Each node produced during traversal is recorded in ``self.nodetype`` as
    one of ``"real"``, ``"complex"`` or ``"bool"``.
    """

    def __init__(self):
        """Initialise."""
        MultiFunction.__init__(self)
        # Maps each (possibly rebuilt) node to "real", "complex" or "bool".
        self.nodetype = {}

    def _with_real_operands(self, o, ops, message, result_type):
        """Rebuild ``o`` with every operand wrapped in ``Real``.

        Args:
            o: The node being visited.
            ops: The already-processed operands of ``o``.
            message: Error message used if any operand is complex.
            result_type: Tag recorded in ``self.nodetype`` for the rebuilt node.

        Returns:
            The reconstructed node.

        Raises:
            ComplexComparisonError: If any operand is tagged "complex".
        """
        if any(self.nodetype[op] == "complex" for op in ops):
            raise ComplexComparisonError(message)
        o = o._ufl_expr_reconstruct_(*map(Real, ops))
        self.nodetype[o] = result_type
        return o

    def expr(self, o, *ops):
        """Defaults expressions to complex unless they only act on real quantities.

        Overridden for specific operators. Rebuilds objects if necessary.
        """
        types = {self.nodetype[op] for op in ops}
        # Real only when there are operands and none of them is complex;
        # operand-less (non-terminal) nodes fall back to complex defensively.
        t = "real" if types and "complex" not in types else "complex"
        o = self.reuse_if_untouched(o, *ops)
        self.nodetype[o] = t
        return o

    def compare(self, o, *ops):
        """Check an ordering comparison and wrap its operands in Real."""
        return self._with_real_operands(
            o, ops, "Ordering undefined for complex values.", "bool"
        )

    gt = compare
    lt = compare
    ge = compare
    le = compare

    def sign(self, o, *ops):
        """Apply to sign; the result (-1, 0 or 1) is a real number."""
        return self._with_real_operands(
            o, ops, "Ordering undefined for complex values.", "real"
        )

    def max_value(self, o, *ops):
        """Apply to max_value; the maximum of real values is real."""
        return self._with_real_operands(
            o, ops, "You can't compare complex numbers with max.", "real"
        )

    def min_value(self, o, *ops):
        """Apply to min_value; the minimum of real values is real."""
        return self._with_real_operands(
            o, ops, "You can't compare complex numbers with min.", "real"
        )

    def real(self, o, *ops):
        """Apply to real."""
        o = self.reuse_if_untouched(o, *ops)
        self.nodetype[o] = "real"
        return o

    def imag(self, o, *ops):
        """Apply to imag."""
        o = self.reuse_if_untouched(o, *ops)
        self.nodetype[o] = "real"
        return o

    def sqrt(self, o, *ops):
        """Apply to sqrt (complex: the argument may be negative)."""
        o = self.reuse_if_untouched(o, *ops)
        self.nodetype[o] = "complex"
        return o

    def power(self, o, base, exponent):
        """Apply to power.

        The result is real only if the base is not complex and the exponent
        evaluates to a constant integer; otherwise it is treated as complex.
        """
        o = self.reuse_if_untouched(o, base, exponent)
        try:
            # Only constant scalar exponents can be converted to float.
            exponent_value = float(exponent)
        except TypeError:
            exponent_value = None
        # is_integer() is False for inf/nan instead of raising like int() did.
        if (
            exponent_value is not None
            and exponent_value.is_integer()
            and self.nodetype[base] != "complex"
        ):
            self.nodetype[o] = "real"
        else:
            self.nodetype[o] = "complex"
        return o

    def abs(self, o, *ops):
        """Apply to abs."""
        o = self.reuse_if_untouched(o, *ops)
        self.nodetype[o] = "real"
        return o

    def terminal(self, term, *ops):
        """Apply to terminal."""
        # default terminals to complex, except the ones we *know* are real
        if isinstance(term, (RealValue, Zero, Argument, GeometricQuantity)):
            self.nodetype[term] = "real"
        else:
            self.nodetype[term] = "complex"
        return term

    def indexed(self, o, expr, multiindex):
        """Apply to indexed; indexing preserves the operand's type."""
        o = self.reuse_if_untouched(o, expr, multiindex)
        self.nodetype[o] = self.nodetype[expr]
        return o
def do_comparison_check(form):
    """Check ``form`` for ordering operations on complex quantities.

    Every integrand is traversed by a newly created CheckComparisons
    visitor, and the resulting (possibly rebuilt) form is returned.

    Raises:
        ComplexComparisonError: If an invalid comparison node is found.
    """
    checker = CheckComparisons()
    return map_integrand_dags(checker, form)
class ComplexComparisonError(Exception):
    """Complex comparison exception.

    Raised when an ordering operation (comparison, max, min, sign) is applied
    to a quantity that may be complex. Derives from ``Exception`` rather than
    ``BaseException``, which is reserved for system-exiting exceptions.
    """