"""Algorithm for expanding compound expressions into equivalent representations using basic operators."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
#
# Modified by Anders Logg, 2009-2010
from ufl.classes import Product, Grad, Conj
from ufl.core.multiindex import indices, Index
from ufl.tensors import as_tensor, as_matrix, as_vector
from ufl.compound_expressions import deviatoric_expr, determinant_expr, cofactor_expr, inverse_expr
from ufl.corealg.multifunction import MultiFunction
from ufl.algorithms.map_integrands import map_integrand_dags
class LowerCompoundAlgebra(MultiFunction):
    """Rewrite compound tensor-algebra operators in terms of basic operators.

    Every handler method takes the node being lowered as ``o`` (none of the
    handlers below read it) followed by that node's operands, and returns
    an expression built from indexing, products, derivatives and the
    ``as_tensor`` / ``as_matrix`` / ``as_vector`` constructors.
    """

    def __init__(self):
        """Set up the underlying MultiFunction."""
        MultiFunction.__init__(self)

    # Handler bound to the generic name ``ufl_type``.
    # NOTE(review): presumably the fallback for node types that have no
    # dedicated handler in this class -- confirm against MultiFunction.
    ufl_type = MultiFunction.reuse_if_untouched

    # ------------ Compound tensor operators

    def trace(self, o, A):
        """Lower tr(A) to ``A[k, k]`` with one fresh repeated index."""
        k = Index()
        return A[k, k]

    def transposed(self, o, A):
        """Lower a transpose by swapping the two free indices of ``A``."""
        row, col = indices(2)
        return as_tensor(A[row, col], (col, row))

    def deviatoric(self, o, A):
        """Lower dev(A) by delegating to ``deviatoric_expr``."""
        lowered = deviatoric_expr(A)
        return lowered

    def skew(self, o, A):
        """Lower skew(A) to the matrix with entries (A_ij - A_ji) / 2."""
        row, col = indices(2)
        antisymmetric_part = (A[row, col] - A[col, row]) / 2
        return as_matrix(antisymmetric_part, (row, col))

    def sym(self, o, A):
        """Lower sym(A) to the matrix with entries (A_ij + A_ji) / 2."""
        row, col = indices(2)
        symmetric_part = (A[row, col] + A[col, row]) / 2
        return as_matrix(symmetric_part, (row, col))

    def cross(self, o, a, b):
        """Lower cross(a, b) to a vector of three explicit components."""
        def component(p, q):
            """Return a[p]*b[q] - a[q]*b[p]."""
            return Product(a[p], b[q]) - Product(a[q], b[p])

        # Cyclic index pairs, one per output component.
        cyclic_pairs = ((1, 2), (2, 0), (0, 1))
        return as_vector(tuple(component(p, q) for p, q in cyclic_pairs))

    def perp(self, o, a):
        """Lower perp(a) to the vector ``[-a[1], a[0]]``."""
        rotated = [-a[1], a[0]]
        return as_vector(rotated)

    def dot(self, o, a, b):
        """Lower dot(a, b) by contracting the last axis of ``a`` with the first of ``b``."""
        free_a = indices(len(a.ufl_shape) - 1)
        free_b = indices(len(b.ufl_shape) - 1)
        contracted = (Index(),)
        # The shared index yields a single IndexSum over a Product.
        summed = a[free_a + contracted] * b[contracted + free_b]
        return as_tensor(summed, free_a + free_b)

    def inner(self, o, a, b):
        """Lower inner(a, b) to ``a[ii] * Conj(b[ii])`` over all axes.

        Raises:
            ValueError: if the shapes of ``a`` and ``b`` differ.
        """
        shape_a = a.ufl_shape
        shape_b = b.ufl_shape
        if shape_a != shape_b:
            raise ValueError("Nonmatching shapes.")
        shared = indices(len(shape_a))
        # Every index is repeated, giving multiple IndexSums over a Product.
        return a[shared] * Conj(b[shared])

    def outer(self, o, a, b):
        """Lower outer(a, b) to the tensor with entries ``Conj(a[ii]) * b[jj]``."""
        idx_a = indices(len(a.ufl_shape))
        idx_b = indices(len(b.ufl_shape))
        # No index is shared, so this is a plain Product without summation.
        entry = Conj(a[idx_a]) * b[idx_b]
        return as_tensor(entry, idx_a + idx_b)

    def determinant(self, o, A):
        """Lower det(A) by delegating to ``determinant_expr``."""
        lowered = determinant_expr(A)
        return lowered

    def cofactor(self, o, A):
        """Lower cofac(A) by delegating to ``cofactor_expr``."""
        lowered = cofactor_expr(A)
        return lowered

    def inverse(self, o, A):
        """Lower inv(A) by delegating to ``inverse_expr``."""
        lowered = inverse_expr(A)
        return lowered

    # ------------ Compound differential operators

    def div(self, o, a):
        """Lower div(a): differentiate along the last axis, using a repeated index."""
        k = Index()
        last_axis_component = a[..., k]
        return last_axis_component.dx(k)

    def nabla_div(self, o, a):
        """Lower nabla_div(a): differentiate along the first axis, using a repeated index."""
        k = Index()
        first_axis_component = a[k, ...]
        return first_axis_component.dx(k)

    def nabla_grad(self, o, a):
        """Lower nabla_grad(a), placing the derivative index first.

        A scalar operand (shape ``()``) is returned as ``Grad(a)``.
        """
        sh = a.ufl_shape
        if sh == ():
            return Grad(a)

        deriv = Index()
        comp = tuple(indices(len(sh)))
        return as_tensor(a[comp].dx(deriv), (deriv,) + comp)

    def curl(self, o, a):
        """Lower curl(a) for scalar, 2-vector and 3-vector operands.

        Raises:
            ValueError: if ``a`` has any other shape.
        """
        # shape ()   -> vector (a.dx(1), -a.dx(0))
        # shape (2,) -> scalar, the [2] component of cross(nabla, (a0, a1, 0))
        # shape (3,) -> vector cross(nabla, a)
        sh = a.ufl_shape

        def component(p, q):
            """Return d(a[q])/dx_p - d(a[p])/dx_q."""
            return a[q].dx(p) - a[p].dx(q)

        if sh == ():
            return as_vector((a.dx(1), -a.dx(0)))
        if sh == (2,):
            return component(0, 1)
        if sh == (3,):
            cyclic_pairs = ((1, 2), (2, 0), (0, 1))
            return as_vector(tuple(component(p, q) for p, q in cyclic_pairs))
        raise ValueError(f"Invalid shape {sh} of curl argument.")
def apply_algebra_lowering(expr):
    """Lower the compound operators in ``expr`` to basic operators.

    Creates a fresh ``LowerCompoundAlgebra`` and hands it, together with
    ``expr``, to ``map_integrand_dags``; that call's result is returned.
    """
    lowering = LowerCompoundAlgebra()
    return map_integrand_dags(lowering, expr)