# Source code for ufl.variable

"""Define the Variable and Label classes.

These are used to label expressions as variables for differentiation.
"""
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

from ufl.utils.counted import Counted
from ufl.core.expr import Expr
from ufl.core.ufl_type import ufl_type
from ufl.core.terminal import Terminal
from ufl.core.operator import Operator
from ufl.constantvalue import as_ufl


@ufl_type()
class Label(Terminal, Counted):
    """Counted terminal that gives a Variable its identity.

    A Label is not a tensor expression, so asking for its shape, free
    indices or index dimensions raises ValueError.
    """

    __slots__ = ("_count", "_counted_class")

    def __init__(self, count=None):
        """Initialise the terminal part, then the counted part.

        Args:
            count: Forwarded unchanged to ``Counted.__init__`` together with
                ``Label`` as the counted class.
        """
        Terminal.__init__(self)
        Counted.__init__(self, count, Label)

    def __str__(self):
        """Return the same text as ``repr``, i.e. ``Label(<count>)``."""
        return self.__repr__()

    def __repr__(self):
        """Return ``Label(<count>)`` with the count formatted as an integer."""
        return "Label(%d)" % (self._count,)

    @property
    def ufl_shape(self):
        """Unavailable for labels; always raises ValueError."""
        raise ValueError("Label has no shape (it is not a tensor expression).")

    @property
    def ufl_free_indices(self):
        """Unavailable for labels; always raises ValueError."""
        raise ValueError("Label has no free indices (it is not a tensor expression).")

    @property
    def ufl_index_dimensions(self):
        """Unavailable for labels; always raises ValueError."""
        raise ValueError("Label has no free indices (it is not a tensor expression).")

    def is_cellwise_constant(self):
        """Report that a label is constant on each cell (always True)."""
        return True

    def ufl_domains(self):
        """Return an empty tuple: a label is tied to no domain."""
        return ()

    def _ufl_signature_data_(self, renumbering):
        """Return the ``("Label", key)`` signature tuple.

        ``key`` is this label's entry in *renumbering* when present there,
        and the label's own count otherwise.
        """
        key = renumbering[self] if self in renumbering else self._count
        return ("Label", key)
@ufl_type(is_shaping=True, is_index_free=True, num_ops=1, inherit_shape_from_operand=0)
class Variable(Operator):
    """Stand-in for another expression, tagged with a Label.

    End users mainly create one to name a quantity they want to
    differentiate with respect to, using ``diff``. The operands are stored
    as ``(expression, label)``.

    Example::

        e = <...>
        e = variable(e)
        f = exp(e**2)
        df = diff(f, e)
    """

    __slots__ = ()

    def __init__(self, expression, label=None):
        """Initialise a Variable wrapping *expression*.

        Args:
            expression: Value passed through ``as_ufl`` before wrapping.
            label: Label to attach; a fresh ``Label()`` is created when None.

        Raises:
            ValueError: if the converted expression is not an Expr, if the
                label is not a Label, or if the expression has free indices.
        """
        # Convert first; a new Label is only allocated after conversion succeeds.
        expression = as_ufl(expression)
        label = Label() if label is None else label

        # Validate the pair before handing it to Operator.
        if not isinstance(expression, Expr):
            raise ValueError("Expecting Expr.")
        if not isinstance(label, Label):
            raise ValueError("Expecting a Label.")
        if expression.ufl_free_indices:
            raise ValueError("Variable cannot wrap an expression with free indices.")

        Operator.__init__(self, (expression, label))

    def ufl_domains(self):
        """Return the domains of the wrapped expression."""
        wrapped = self.ufl_operands[0]
        return wrapped.ufl_domains()

    def evaluate(self, x, mapping, component, index_values):
        """Evaluate by delegating to the wrapped expression."""
        wrapped = self.ufl_operands[0]
        return wrapped.evaluate(x, mapping, component, index_values)

    def expression(self):
        """Return the wrapped expression (first operand)."""
        return self.ufl_operands[0]

    def label(self):
        """Return the attached Label (second operand)."""
        return self.ufl_operands[1]

    def __eq__(self, other):
        """Compare equal only to a Variable with equal label and expression."""
        if not isinstance(other, Variable):
            return False
        mine = self.ufl_operands
        theirs = other.ufl_operands
        # The label is compared first: it is cheap and usually decisive.
        return mine[1] == theirs[1] and mine[0] == theirs[0]

    def __str__(self):
        """Format as ``var<count>(<expression>)``."""
        wrapped = self.ufl_operands[0]
        tag = self.ufl_operands[1]
        return "var%d(%s)" % (tag.count(), wrapped)