Source code for ufl.exproperators

# -*- coding: utf-8 -*-
"""This module attaches special functions to Expr.
This way we avoid circular dependencies between e.g.
Sum and its superclass Expr."""

# Copyright (C) 2008-2016 Martin Sandve Aln├Žs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later
#
# Modified by Massimiliano Leoni, 2016.

from itertools import chain
import numbers

from ufl.log import error
from ufl.utils.stacks import StackDict
from ufl.core.expr import Expr
from ufl.constantvalue import Zero, as_ufl
from ufl.algebra import Sum, Product, Division, Power, Abs
from ufl.tensoralgebra import Transposed, Inner
from ufl.core.multiindex import MultiIndex, Index, FixedIndex, IndexBase, indices
from ufl.indexed import Indexed
from ufl.indexsum import IndexSum
from ufl.tensors import as_tensor, ComponentTensor
from ufl.restriction import PositiveRestricted, NegativeRestricted
from ufl.differentiation import Grad
from ufl.index_combination_utils import create_slice_indices, merge_overlapping_indices

from ufl.exprequals import expr_equals

# --- Boolean operators ---

from ufl.conditional import LE, GE, LT, GT


def _le(left, right):
    "UFL operator: A boolean expresion (left <= right) for use with conditional."
    return LE(left, right)


def _ge(left, right):
    "UFL operator: A boolean expresion (left >= right) for use with conditional."
    return GE(left, right)


def _lt(left, right):
    "UFL operator: A boolean expresion (left < right) for use with conditional."
    return LT(left, right)


def _gt(left, right):
    "UFL operator: A boolean expresion (left > right) for use with conditional."
    return GT(left, right)


# '==' needs to implement comparison of expression representations for
# use in hashmaps (dict and set), but the others can be overloaded in
# the language.  It is possible that we can overload eq as well, but
# we'll need to fix some issues first and also check for a possible
# significant performance hit with compilation of complex
# forms. Replacing a==b with equiv(a,b) all over the code could be one
# way to reduce such a performance hit, but we cannot do anything
# about dict and set calling __eq__...
Expr.__eq__ = expr_equals


# != is used at least by tests, possibly in code as well, and must
# mean the opposite of ==, i.e. when evaluated as bool it must mean
# 'not equal representation'.
def _ne(self, other):
    return not self.__eq__(other)


Expr.__ne__ = _ne
Expr.__lt__ = _lt
Expr.__gt__ = _gt
Expr.__le__ = _le
Expr.__ge__ = _ge

# Python operators 'and'/'or' cannot be overloaded, and bitwise
# operators &/| don't have the right precedence levels
# Expr.__and__ = _and
# Expr.__or__ = _or


def _as_tensor(self, indices):
    "UFL operator: A^indices := as_tensor(A, indices)."
    if not isinstance(indices, tuple):
        error("Expecting a tuple of Index objects to A^indices := as_tensor(A, indices).")
    if not all(isinstance(i, Index) for i in indices):
        error("Expecting a tuple of Index objects to A^indices := as_tensor(A, indices).")
    return as_tensor(self, indices)


Expr.__xor__ = _as_tensor


# --- Helper functions for product handling ---

def _mult(a, b):
    # Discover repeated indices, which results in index sums
    afi = a.ufl_free_indices
    bfi = b.ufl_free_indices
    afid = a.ufl_index_dimensions
    bfid = b.ufl_index_dimensions
    fi, fid, ri, rid = merge_overlapping_indices(afi, afid, bfi, bfid)

    # Pick out valid non-scalar products here (dot products):
    # - matrix-matrix (A*B, M*grad(u)) => A . B
    # - matrix-vector (A*v) => A . v
    s1, s2 = a.ufl_shape, b.ufl_shape
    r1, r2 = len(s1), len(s2)

    if r1 == 0 and r2 == 0:
        # Create scalar product
        p = Product(a, b)
        ti = ()

    elif r1 == 0 or r2 == 0:
        # Scalar - tensor product
        if r2 == 0:
            a, b = b, a

        # Check for zero, simplifying early if possible
        if isinstance(a, Zero) or isinstance(b, Zero):
            shape = s1 or s2
            return Zero(shape, fi, fid)

        # Repeated indices are allowed, like in:
        # v[i]*M[i,:]

        # Apply product to scalar components
        ti = indices(len(b.ufl_shape))
        p = Product(a, b[ti])

    elif r1 == 2 and r2 in (1, 2):  # Matrix-matrix or matrix-vector
        if ri:
            error("Not expecting repeated indices in non-scalar product.")

        # Check for zero, simplifying early if possible
        if isinstance(a, Zero) or isinstance(b, Zero):
            shape = s1[:-1] + s2[1:]
            return Zero(shape, fi, fid)

        # Return dot product in index notation
        ai = indices(len(a.ufl_shape) - 1)
        bi = indices(len(b.ufl_shape) - 1)
        k = indices(1)

        p = a[ai + k] * b[k + bi]
        ti = ai + bi

    else:
        error("Invalid ranks {0} and {1} in product.".format(r1, r2))

    # TODO: I think applying as_tensor after index sums results in
    # cleaner expression graphs.
    # Wrap as tensor again
    if ti:
        p = as_tensor(p, ti)

    # If any repeated indices were found, apply implicit summation
    # over those
    for i in ri:
        mi = MultiIndex((Index(count=i),))
        p = IndexSum(p, mi)

    return p


# --- Extend Expr with algebraic operators ---

_valid_types = (Expr, numbers.Real, numbers.Integral, numbers.Complex)


def _mul(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    o = as_ufl(o)
    return _mult(self, o)


Expr.__mul__ = _mul


def _rmul(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    o = as_ufl(o)
    return _mult(o, self)


Expr.__rmul__ = _rmul


def _add(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    return Sum(self, o)


Expr.__add__ = _add


def _radd(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    if isinstance(o, numbers.Number) and o == 0:
        # Allow adding scalar int 0 as a no-op, even for shaped self,
        # needed for sum([a,b])
        return self
    return Sum(o, self)


Expr.__radd__ = _radd


def _sub(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    return Sum(self, -o)


Expr.__sub__ = _sub


def _rsub(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    return Sum(o, -self)


Expr.__rsub__ = _rsub


def _div(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    sh = self.ufl_shape
    if sh:
        ii = indices(len(sh))
        d = Division(self[ii], o)
        return as_tensor(d, ii)
    return Division(self, o)


Expr.__div__ = _div
Expr.__truediv__ = _div


def _rdiv(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    return Division(o, self)


Expr.__rdiv__ = _rdiv
Expr.__rtruediv__ = _rdiv


def _pow(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    if o == 2 and self.ufl_shape:
        return Inner(self, self)
    return Power(self, o)


Expr.__pow__ = _pow


def _rpow(self, o):
    if not isinstance(o, _valid_types):
        return NotImplemented
    return Power(o, self)


Expr.__rpow__ = _rpow


# TODO: Add Negated class for this? Might simplify reductions in Add.
def _neg(self):
    return -1 * self


Expr.__neg__ = _neg


def _abs(self):
    return Abs(self)


Expr.__abs__ = _abs


# --- Extend Expr with restiction operators a("+"), a("-") ---

def _restrict(self, side):
    if side == "+":
        return PositiveRestricted(self)
    if side == "-":
        return NegativeRestricted(self)
    error("Invalid side '%s' in restriction operator." % (side,))


def _eval(self, coord, mapping=None, component=()):
    # Evaluate expression at this particular coordinate, with provided
    # values for other terminals in mapping

    # Evaluate derivatives first
    from ufl.algorithms import expand_derivatives
    f = expand_derivatives(self)

    # Evaluate recursively
    if mapping is None:
        mapping = {}
    index_values = StackDict()
    return f.evaluate(coord, mapping, component, index_values)


def _call(self, arg, mapping=None, component=()):
    # Taking the restriction or evaluating depending on argument
    if arg in ("+", "-"):
        if mapping is not None:
            error("Not expecting a mapping when taking restriction.")
        return _restrict(self, arg)
    else:
        return _eval(self, arg, mapping, component)


Expr.__call__ = _call


# --- Extend Expr with the transpose operation A.T ---

def _transpose(self):
    """Transpose a rank-2 tensor expression. For more general transpose
    operations of higher order tensor expressions, use indexing and Tensor."""
    return Transposed(self)


Expr.T = property(_transpose)


# --- Extend Expr with indexing operator a[i] ---

[docs]def analyse_key(ii, rank): """Takes something the user might input as an index tuple inside [], which could include complete slices (:) and ellipsis (...), and returns tuples of actual UFL index objects. The return value is a tuple (indices, axis_indices), each being a tuple of IndexBase instances. The return value 'indices' corresponds to all input objects of these types: - Index - FixedIndex - int => Wrapped in FixedIndex The return value 'axis_indices' corresponds to all input objects of these types: - Complete slice (:) => Replaced by a single new index - Ellipsis (...) => Replaced by multiple new indices """ # Wrap in tuple if not isinstance(ii, (tuple, MultiIndex)): ii = (ii,) else: # Flatten nested tuples, happens with f[...,ii] where ii is a # tuple of indices jj = [] for j in ii: if isinstance(j, (tuple, MultiIndex)): jj.extend(j) else: jj.append(j) ii = tuple(jj) # Convert all indices to Index or FixedIndex objects. If there is # an ellipsis, split the indices into before and after. axis_indices = set() pre = [] post = [] indexlist = pre for i in ii: if i == Ellipsis: # Switch from pre to post list when an ellipsis is # encountered if indexlist is not pre: error("Found duplicate ellipsis.") indexlist = post else: # Convert index to a proper type if isinstance(i, numbers.Integral): idx = FixedIndex(i) elif isinstance(i, IndexBase): idx = i elif isinstance(i, slice): if i == slice(None): idx = Index() axis_indices.add(idx) else: # TODO: Use ListTensor to support partial slices? error("Partial slices not implemented, only complete slices like [:]") else: error("Can't convert this object to index: %s" % (i,)) # Store index in pre or post list indexlist.append(idx) # Handle ellipsis as a number of complete slices, that is create a # number of new axis indices num_axis = rank - len(pre) - len(post) if indexlist is post: ellipsis_indices = indices(num_axis) axis_indices.update(ellipsis_indices) else: ellipsis_indices = () # Construct final tuples to return all_indices = tuple(chain(pre, ellipsis_indices, post)) axis_indices = tuple(i for i in all_indices if i in axis_indices) return all_indices, axis_indices
def _getitem(self, component): # Treat component consistently as tuple below if not isinstance(component, tuple): component = (component,) shape = self.ufl_shape # Analyse slices (:) and Ellipsis (...) all_indices, slice_indices, repeated_indices = create_slice_indices(component, shape, self.ufl_free_indices) # Check that we have the right number of indices for a tensor with # this shape if len(shape) != len(all_indices): error("Invalid number of indices {0} for expression of rank {1}.".format(len(all_indices), len(shape))) # Special case for simplifying foo[...] => foo, foo[:] => foo or # similar if len(slice_indices) == len(all_indices): return self # Special case for simplifying as_tensor(ai,(i,))[i] => ai if isinstance(self, ComponentTensor): if all_indices == self.indices().indices(): return self.ufl_operands[0] # Apply all indices to index self, yielding a scalar valued # expression mi = MultiIndex(all_indices) a = Indexed(self, mi) # TODO: I think applying as_tensor after index sums results in # cleaner expression graphs. # If the Ellipsis or any slices were found, wrap as tensor valued # with the slice indices created at the top here if slice_indices: a = as_tensor(a, slice_indices) # If any repeated indices were found, apply implicit summation # over those for i in repeated_indices: mi = MultiIndex((i,)) a = IndexSum(a, mi) # Check for zero (last so we can get indices etc from a, could # possibly be done faster by checking early instead) if isinstance(self, Zero): shape = a.ufl_shape fi = a.ufl_free_indices fid = a.ufl_index_dimensions a = Zero(shape, fi, fid) return a Expr.__getitem__ = _getitem # --- Extend Expr with spatial differentiation operator a.dx(i) --- def _dx(self, *ii): "Return the partial derivative with respect to spatial variable number *ii*." d = self # Unwrap ii to allow .dx(i,j) and .dx((i,j)) if len(ii) == 1 and isinstance(ii[0], tuple): ii = ii[0] # Apply all derivatives for i in ii: d = Grad(d) # Take all components, applying repeated index sums in the [] operation return d.__getitem__((Ellipsis,) + ii) Expr.dx = _dx