"""This module defines expression transformation utilities.
These utilities are for expanding free indices in expressions to explicit fixed indices only.
"""
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
#
# Modified by Anders Logg, 2009.
from ufl.algorithms.transformer import ReuseTransformer, apply_transformer
from ufl.classes import Terminal
from ufl.constantvalue import Zero
from ufl.core.multiindex import FixedIndex, Index, MultiIndex
from ufl.differentiation import Grad
from ufl.utils.stacks import Stack, StackDict
[docs]
class IndexExpander(ReuseTransformer):
    """Transformer that replaces free indices by explicit fixed indices.

    Two pieces of state are maintained while the expression is traversed:

    * ``_components``: a stack of component tuples. The top entry is the
      component that the subexpression currently being visited must be
      evaluated at (``()`` for a scalar subexpression).
    * ``_index2value``: a stack-like mapping from ``Index`` objects to the
      integer value currently assigned to them.

    Traversal itself (``visit``, ``reuse_if_possible``, ``print_visit_stack``)
    is inherited and not defined in this class.
    """

    def __init__(self):
        """Initialise with an empty component stack and index-value map."""
        super().__init__()
        self._components = Stack()
        self._index2value = StackDict()

    def component(self):
        """Return the current component tuple, or ``()`` if none is active."""
        if self._components:
            return self._components.peek()
        return ()

    def _check_free_indices(self, x):
        """Verify that every free index of ``x`` has a value assigned.

        Args:
            x: Expression whose ``ufl_free_indices`` are compared against the
                counts of the indices currently in the index-value map.

        Raises:
            ValueError: If at least one free index has no value assigned.
        """
        assigned = {i.count() for i in self._index2value.keys()}
        s = set(x.ufl_free_indices) - assigned
        if s:
            raise ValueError(
                f"Free index set mismatch, these indices have no value assigned: {s}."
            )

    def terminal(self, x):
        """Apply to terminal.

        A terminal with non-empty shape is indexed with the current component;
        a scalar terminal is returned unchanged.

        Raises:
            ValueError: If the rank of ``x`` differs from the length of the
                current component.
        """
        if x.ufl_shape:
            c = self.component()
            if len(x.ufl_shape) != len(c):
                raise ValueError("Component size mismatch.")
            return x[c]
        return x

    def zero(self, x):
        """Apply to zero.

        Returns:
            A ``Zero()`` without shape or index information; both are
            validated against the current state before being dropped.

        Raises:
            ValueError: If the rank of ``x`` differs from the length of the
                current component, or a free index of ``x`` has no value.
        """
        if len(x.ufl_shape) != len(self.component()):
            raise ValueError("Component size mismatch.")
        self._check_free_indices(x)
        # There is no index/shape info in this zero because that is asserted above
        return Zero()

    def scalar_value(self, x):
        """Apply to scalar_value.

        Returns:
            A new object of the same UFL class as ``x``, built from
            ``x.value()`` only.

        Raises:
            ValueError: If the rank of ``x`` differs from the length of the
                current component, or a free index of ``x`` has no value.
        """
        if len(x.ufl_shape) != len(self.component()):
            # Dump the visit stack first so the failing path can be diagnosed.
            self.print_visit_stack()
            raise ValueError("Component size mismatch.")
        self._check_free_indices(x)
        return x._ufl_class_(x.value())

    def conditional(self, x):
        """Apply to conditional.

        The condition is visited with an empty component, while the two value
        branches are visited with the current component left in place.

        Raises:
            ValueError: If the condition operand is not scalar.
        """
        c, t, f = x.ufl_operands
        # Not accepting nonscalars in condition
        if c.ufl_shape != ():
            raise ValueError("Not expecting tensor in condition.")
        # Conditional may be indexed, push empty component
        self._components.push(())
        c = self.visit(c)
        self._components.pop()
        # Keep possibly non-scalar components for values
        t = self.visit(t)
        f = self.visit(f)
        return self.reuse_if_possible(x, c, t, f)

    def division(self, x):
        """Apply to division.

        Raises:
            ValueError: If either operand is tensor-valued, or a non-empty
                component is active.
        """
        a, b = x.ufl_operands
        # Not accepting nonscalars in division anymore
        if a.ufl_shape != ():
            raise ValueError("Not expecting tensor in division.")
        if self.component() != ():
            raise ValueError("Not expecting component in division.")
        if b.ufl_shape != ():
            raise ValueError("Not expecting division by tensor.")
        # The component is already () at this point, so no empty component
        # has to be pushed while the operands are visited.
        a = self.visit(a)
        b = self.visit(b)
        return self.reuse_if_possible(x, a, b)

    def index_sum(self, x):
        """Apply to index_sum.

        The summand is visited once for every value in
        ``range(x.dimension())`` with the summation index bound to that value,
        and the results are added with the builtin ``sum``.
        """
        ops = []
        summand, multiindex = x.ufl_operands
        (index,) = multiindex
        # TODO: For the list tensor purging algorithm, do something like:
        # if index not in self._to_expand:
        #     return self.expr(x, *[self.visit(o) for o in x.ufl_operands])
        for value in range(x.dimension()):
            self._index2value.push(index, value)
            ops.append(self.visit(summand))
            self._index2value.pop()
        return sum(ops)

    def _multi_index_values(self, x):
        """Return the integer values of the indices in multi-index ``x``.

        A ``FixedIndex`` contributes its own value and an ``Index`` contributes
        the value currently assigned to it in the index-value map. Entries of
        any other type are skipped.
        """
        comp = []
        for i in x._indices:
            if isinstance(i, FixedIndex):
                comp.append(i._value)
            elif isinstance(i, Index):
                comp.append(self._index2value[i])
        return tuple(comp)

    def multi_index(self, x):
        """Apply to multi_index.

        Returns:
            A ``MultiIndex`` made of ``FixedIndex`` entries only, holding the
            current values of the indices in ``x``.
        """
        comp = self._multi_index_values(x)
        return MultiIndex(tuple(FixedIndex(i) for i in comp))

    def indexed(self, x):
        """Apply to indexed.

        The indexed operand is visited with the current values of the indices
        as the active component; the result of that visit is returned.
        """
        A, ii = x.ufl_operands
        # Push new component built from index value map
        self._components.push(self._multi_index_values(ii))
        # NOTE: the values of the indices in ii are deliberately left visible
        # while A is visited; hiding them is not correct behaviour.
        result = self.visit(A)
        # Reset component
        self._components.pop()
        return result

    def component_tensor(self, x):
        """Apply to component_tensor.

        Evaluates the tensor expression with its indices equal to the current
        component tuple.

        Raises:
            ValueError: If the base expression is not scalar, or the number of
                indices differs from the length of the current component.
        """
        expression, indices = x.ufl_operands
        if expression.ufl_shape != ():
            raise ValueError("Expecting scalar base expression.")
        # Update index map with component tuple values
        comp = self.component()
        if len(indices) != len(comp):
            raise ValueError("Index/component mismatch.")
        for i, v in zip(indices.indices(), comp):
            self._index2value.push(i, v)
        # The base expression is scalar, so it is visited with an empty component
        self._components.push(())
        # Evaluate with these indices
        result = self.visit(expression)
        # Revert index map: one pop for each value pushed above
        for _ in comp:
            self._index2value.pop()
        self._components.pop()
        return result

    def list_tensor(self, x):
        """Apply to list_tensor.

        The first entry of the current component selects the operand, which is
        then visited with the remaining entries as its component.

        Raises:
            ValueError: If the current component is empty, so that no operand
                can be selected.
        """
        c = self.component()
        if not c:
            # Fail like the other handlers rather than with a bare IndexError
            raise ValueError("Component size mismatch.")
        # Pick the right subtensor and subcomponent
        c0, c1 = c[0], c[1:]
        op = x.ufl_operands[c0]
        # Evaluate subtensor with this subcomponent
        self._components.push(c1)
        r = self.visit(op)
        self._components.pop()
        return r

    def grad(self, x):
        """Apply to grad.

        Returns:
            ``x`` indexed with the current component.

        Raises:
            ValueError: If the operand is neither a ``Terminal`` nor a
                ``Grad``.
        """
        (f,) = x.ufl_operands
        if not isinstance(f, (Terminal, Grad)):
            raise ValueError("Expecting expand_derivatives to have been applied.")
        # No need to visit child as long as it is on the form [Grad]([Grad](terminal))
        return x[self.component()]
[docs]
def expand_indices(e):
    """Expand indices.

    Args:
        e: Object handed unchanged to ``apply_transformer`` together with a
            newly created ``IndexExpander``.

    Returns:
        Whatever ``apply_transformer`` returns for ``e`` and that expander.
    """
    # A fresh expander per call, so no stack state is shared between calls
    expander = IndexExpander()
    return apply_transformer(e, expander)