# Source code for ufl.conditional

"""This module defines classes for conditional expressions."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

import warnings

from ufl.checks import is_true_ufl_scalar
from ufl.constantvalue import as_ufl
from ufl.core.expr import ufl_err_str
from ufl.core.operator import Operator
from ufl.core.ufl_type import ufl_type
from ufl.exprequals import expr_equals
from ufl.precedence import parstr

# --- Condition classes ---

# TODO: It would be nice to have some kind of type system to show that
# this is a boolean type, not a float type.


@ufl_type(is_abstract=True, is_scalar=True)
class Condition(Operator):
    """Abstract base class for boolean-valued UFL expressions.

    Instances are symbolic; they deliberately refuse conversion to a
    Python ``bool`` so that accidental use in an ``if`` statement fails
    loudly instead of silently picking a branch.
    """

    __slots__ = ()

    def __init__(self, operands):
        """Initialise with the given tuple of operands."""
        Operator.__init__(self, operands)

    def __bool__(self):
        """Refuse implicit conversion to a Python bool.

        Raises:
            ValueError: always; subclasses may override this.
        """
        # Explicit error to protect against misuse in a Python context.
        message = "UFL conditions cannot be evaluated as bool in a Python context."
        raise ValueError(message)

    # Python 2 name of the truth-value hook, kept for compatibility.
    __nonzero__ = __bool__
@ufl_type(is_abstract=True, num_ops=2)
class BinaryCondition(Condition):
    """Base class for conditions with exactly two operands.

    The operator symbol is kept in ``_name``; it selects which operand
    validation applies and is used when formatting the expression.
    """

    __slots__ = ("_name",)

    def __init__(self, name, left, right):
        """Initialise.

        Args:
            name: Operator symbol, e.g. ``"<="``, ``"=="`` or ``"&&"``.
            left: Left operand; converted with ``as_ufl``.
            right: Right operand; converted with ``as_ufl``.

        Raises:
            ValueError: if ``&&``/``||`` receive an operand that is not a
                Condition, or if any other non-equality operator receives
                an operand with a shape or free indices.
        """
        lhs = as_ufl(left)
        rhs = as_ufl(right)
        Condition.__init__(self, (lhs, rhs))
        self._name = name

        if name in ("!=", "=="):
            # Equals and not-equals are also used for comparing
            # representations, so any shape is allowed here. Scalar-ness
            # must be checked when the condition is used in a conditional.
            return

        if name in ("&&", "||"):
            # Logical connectives only combine other conditions.
            for operand in (lhs, rhs):
                if not isinstance(operand, Condition):
                    raise ValueError(f"Expecting a Condition, not {ufl_err_str(operand)}.")
            return

        # Every other operator acts on non-boolean values and requires
        # scalar operands without free indices.
        if (
            lhs.ufl_shape != ()
            or rhs.ufl_shape != ()
            or lhs.ufl_free_indices != ()
            or rhs.ufl_free_indices != ()
        ):
            raise ValueError("Expecting scalar arguments.")

    def __str__(self):
        """Format as ``left <op> right``, wrapping operands via ``parstr``."""
        lhs, rhs = self.ufl_operands
        return f"{parstr(lhs, self)} {self._name} {parstr(rhs, self)}"
# Not associating with __eq__: the concept of equality with == is
# reserved for object equivalence, for use in set and dict.
@ufl_type()
class EQ(BinaryCondition):
    """Equality condition, ``left == right``."""

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise with the two operands to compare."""
        BinaryCondition.__init__(self, "==", left, right)

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both operands numerically and compare them with ``==``."""
        lhs, rhs = (
            operand.evaluate(x, mapping, component, index_values)
            for operand in self.ufl_operands
        )
        return bool(lhs == rhs)

    def __bool__(self):
        """Return whether the operands are equal according to ``expr_equals``.

        Unlike the base ``Condition``, EQ may be used in a Python boolean
        context; the answer comes from comparing the expressions
        themselves, not from a numerical evaluation.
        """
        lhs, rhs = self.ufl_operands
        return expr_equals(lhs, rhs)

    # Python 2 name of the truth-value hook.
    __nonzero__ = __bool__
# Not associating with __ne__: like ==, the != operator is
# reserved for object equivalence, for use in set and dict.
@ufl_type()
class NE(BinaryCondition):
    """Not-equal condition, ``left != right``."""

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise with the two operands to compare."""
        BinaryCondition.__init__(self, "!=", left, right)

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both operands numerically and compare them with ``!=``."""
        args = (x, mapping, component, index_values)
        first, second = self.ufl_operands
        return bool(first.evaluate(*args) != second.evaluate(*args))

    def __bool__(self):
        """Return whether the operands differ according to ``expr_equals``.

        Like EQ, this may be used in a Python boolean context; it negates
        the expression comparison rather than evaluating numerically.
        """
        first, second = self.ufl_operands
        equal = expr_equals(first, second)
        return not equal

    # Python 2 name of the truth-value hook.
    __nonzero__ = __bool__
@ufl_type(binop="__le__")
class LE(BinaryCondition):
    """Less-than-or-equal condition, ``left <= right``.

    Both operands must be scalars without free indices (enforced by
    ``BinaryCondition``).
    """

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise with the two scalar operands to compare."""
        BinaryCondition.__init__(self, "<=", left, right)

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both operands numerically and test ``left <= right``."""
        args = (x, mapping, component, index_values)
        lhs_op, rhs_op = self.ufl_operands
        return bool(lhs_op.evaluate(*args) <= rhs_op.evaluate(*args))
@ufl_type(binop="__ge__")
class GE(BinaryCondition):
    """Greater-than-or-equal condition, ``left >= right``.

    Both operands must be scalars without free indices (enforced by
    ``BinaryCondition``).
    """

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise with the two scalar operands to compare."""
        BinaryCondition.__init__(self, ">=", left, right)

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both operands numerically and test ``left >= right``."""
        left_value, right_value = [
            operand.evaluate(x, mapping, component, index_values)
            for operand in self.ufl_operands
        ]
        return bool(left_value >= right_value)
@ufl_type(binop="__lt__")
class LT(BinaryCondition):
    """Strict less-than condition, ``left < right``.

    Both operands must be scalars without free indices (enforced by
    ``BinaryCondition``).
    """

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise with the two scalar operands to compare."""
        BinaryCondition.__init__(self, "<", left, right)

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both operands numerically and test ``left < right``."""

        def _value_of(operand):
            # Evaluate one operand at the given point/component.
            return operand.evaluate(x, mapping, component, index_values)

        lhs, rhs = map(_value_of, self.ufl_operands)
        return bool(lhs < rhs)
@ufl_type(binop="__gt__")
class GT(BinaryCondition):
    """Strict greater-than condition, ``left > right``.

    Both operands must be scalars without free indices (enforced by
    ``BinaryCondition``).
    """

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise with the two scalar operands to compare."""
        BinaryCondition.__init__(self, ">", left, right)

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both operands numerically and test ``left > right``."""
        args = (x, mapping, component, index_values)
        first = self.ufl_operands[0].evaluate(*args)
        second = self.ufl_operands[1].evaluate(*args)
        result = first > second
        return bool(result)
@ufl_type()
class AndCondition(BinaryCondition):
    """Logical conjunction of two conditions, ``left && right``."""

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise; both operands must be Condition instances."""
        BinaryCondition.__init__(self, "&&", left, right)

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both sub-conditions and combine them with logical AND.

        Both operands are always evaluated (no short-circuiting) before
        they are combined.
        """
        first, second = (
            cond.evaluate(x, mapping, component, index_values)
            for cond in self.ufl_operands
        )
        return bool(first and second)
@ufl_type()
class OrCondition(BinaryCondition):
    """Logical disjunction of two conditions, ``left || right``."""

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise; both operands must be Condition instances."""
        BinaryCondition.__init__(self, "||", left, right)

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both sub-conditions and combine them with logical OR.

        Both operands are always evaluated (no short-circuiting) before
        they are combined.
        """
        args = (x, mapping, component, index_values)
        lhs_cond, rhs_cond = self.ufl_operands
        first = lhs_cond.evaluate(*args)
        second = rhs_cond.evaluate(*args)
        return bool(first or second)
@ufl_type(num_ops=1)
class NotCondition(Condition):
    """Logical negation of a single condition, ``!(condition)``."""

    __slots__ = ()

    def __init__(self, condition):
        """Initialise.

        Raises:
            ValueError: if ``condition`` is not a Condition instance.
        """
        Condition.__init__(self, (condition,))
        if isinstance(condition, Condition):
            return
        raise ValueError("Expecting a condition.")

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate the wrapped condition and return its negation."""
        (inner,) = self.ufl_operands
        return not inner.evaluate(x, mapping, component, index_values)

    def __str__(self):
        """Format as ``!(condition)``."""
        (inner,) = self.ufl_operands
        return "!(" + str(inner) + ")"
@ufl_type(num_ops=3, inherit_shape_from_operand=1, inherit_indices_from_operand=1)
class Conditional(Operator):
    """Conditional expression.

    In C++ these take the format `(condition ? true_value : false_value)`.
    Shape and free indices are inherited from ``true_value`` (operand 1),
    so both branches are required to agree on them.
    """

    __slots__ = ()

    def __init__(self, condition, true_value, false_value):
        """Initialise.

        Args:
            condition: A Condition selecting the branch.
            true_value: Value when the condition holds; converted with ``as_ufl``.
            false_value: Value otherwise; converted with ``as_ufl``.

        Raises:
            ValueError: if ``condition`` is not a Condition, if the branches
                differ in shape or free indices, or if an EQ/NE condition
                has non-scalar operands.
        """
        if not isinstance(condition, Condition):
            raise ValueError("Expecting condition as first argument.")
        true_value = as_ufl(true_value)
        false_value = as_ufl(false_value)

        if true_value.ufl_shape != false_value.ufl_shape:
            raise ValueError("Shape mismatch between conditional branches.")
        if true_value.ufl_free_indices != false_value.ufl_free_indices:
            raise ValueError("Free index mismatch between conditional branches.")

        # BinaryCondition accepts any shape for == and !=, so scalar
        # operands are enforced here, where the condition is actually used.
        if isinstance(condition, (EQ, NE)) and any(
            operand.ufl_shape != () or operand.ufl_free_indices != ()
            for operand in condition.ufl_operands
        ):
            raise ValueError("Non-scalar == or != is not allowed.")

        Operator.__init__(self, (condition, true_value, false_value))

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate the condition, then evaluate only the selected branch."""
        condition, true_value, false_value = self.ufl_operands
        holds = condition.evaluate(x, mapping, component, index_values)
        chosen = true_value if holds else false_value
        return chosen.evaluate(x, mapping, component, index_values)

    def __str__(self):
        """Format as ``condition ? true_value : false_value``."""
        cond_str, true_str, false_str = (parstr(o, self) for o in self.ufl_operands)
        return "{} ? {} : {}".format(cond_str, true_str, false_str)
# --- Specific functions higher level than a conditional ---
@ufl_type(is_scalar=True, num_ops=2)
class MinValue(Operator):
    """Take the minimum of two values.

    Both operands must be true UFL scalars (checked with
    ``is_true_ufl_scalar``). The type is declared with ``num_ops=2`` to
    match the two operands it always stores.
    """

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise.

        Args:
            left: First operand; plain numbers are converted with ``as_ufl``.
            right: Second operand; plain numbers are converted with ``as_ufl``.

        Raises:
            ValueError: if either operand is not a true UFL scalar.
        """
        # Convert plain Python numbers to UFL expressions, as
        # BinaryCondition and Conditional do; without this, MinValue(u, 0)
        # would store a raw int and then fail the scalar check.
        left = as_ufl(left)
        right = as_ufl(right)
        Operator.__init__(self, (left, right))
        if not (is_true_ufl_scalar(left) and is_true_ufl_scalar(right)):
            raise ValueError("Expecting scalar arguments.")

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both operands and return the smaller one.

        Raises:
            ValueError: re-raised from ``min`` after emitting a warning that
                names the operands.
        """
        a, b = (
            operand.evaluate(x, mapping, component, index_values)
            for operand in self.ufl_operands
        )
        try:
            res = min(a, b)
        except ValueError:
            warnings.warn("Value error in evaluation of min() of %s and %s." % self.ufl_operands)
            raise
        return res

    def __str__(self):
        """Format as ``min_value(left, right)``."""
        return "min_value(%s, %s)" % self.ufl_operands
@ufl_type(is_scalar=True, num_ops=2)
class MaxValue(Operator):
    """Take the maximum of two values.

    Both operands must be true UFL scalars (checked with
    ``is_true_ufl_scalar``). The type is declared with ``num_ops=2`` to
    match the two operands it always stores.
    """

    __slots__ = ()

    def __init__(self, left, right):
        """Initialise.

        Args:
            left: First operand; plain numbers are converted with ``as_ufl``.
            right: Second operand; plain numbers are converted with ``as_ufl``.

        Raises:
            ValueError: if either operand is not a true UFL scalar.
        """
        # Convert plain Python numbers to UFL expressions, as
        # BinaryCondition and Conditional do; without this, MaxValue(u, 0)
        # would store a raw int and then fail the scalar check.
        left = as_ufl(left)
        right = as_ufl(right)
        Operator.__init__(self, (left, right))
        if not (is_true_ufl_scalar(left) and is_true_ufl_scalar(right)):
            raise ValueError("Expecting scalar arguments.")

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate both operands and return the larger one.

        Raises:
            ValueError: re-raised from ``max`` after emitting a warning that
                names the operands.
        """
        a, b = (
            operand.evaluate(x, mapping, component, index_values)
            for operand in self.ufl_operands
        )
        try:
            res = max(a, b)
        except ValueError:
            warnings.warn("Value error in evaluation of max() of %s and %s." % self.ufl_operands)
            raise
        return res

    def __str__(self):
        """Format as ``max_value(left, right)``."""
        return "max_value(%s, %s)" % self.ufl_operands