# Source code for ufl.algorithms.balancing
"""Balancing."""
# -*- coding: utf-8 -*-
# Copyright (C) 2011-2017 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
from ufl.classes import (
CellAvg,
FacetAvg,
Grad,
Indexed,
NegativeRestricted,
PositiveRestricted,
ReferenceGrad,
ReferenceValue,
)
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction
# Canonical ranking of terminal modifiers, keyed by each class's
# ``_ufl_handler_name_`` so it can be looked up straight from an expression
# node. balance_modified_terminal sorts modifier layers by this rank and
# re-applies them in ascending order, so a lower rank ends up closer to the
# terminal.
modifier_precedence = {
    modifier_class._ufl_handler_name_: rank
    for rank, modifier_class in enumerate(
        (
            ReferenceValue,
            ReferenceGrad,
            Grad,
            CellAvg,
            FacetAvg,
            PositiveRestricted,
            NegativeRestricted,
            Indexed,
        )
    )
}
def balance_modified_terminal(expr):
    """Balance modified terminal.

    Rebuild a terminal wrapped in terminal-modifier layers (e.g. ``grad``
    or restrictions) so that the layers appear in the canonical order given
    by ``modifier_precedence``: lower-ranked modifiers are placed closer to
    the terminal.

    Args:
        expr: A terminal, or a terminal modifier whose chain of first
            operands consists only of terminal modifiers ending in a
            terminal (checked with ``assert``).

    Returns:
        ``expr`` itself if it is a terminal or if the rebuilt expression
        compares equal to it; otherwise the rebuilt expression.
    """
    # NB! Assuming e.g. grad(cell_avg(expr)) does not occur,
    # i.e. it is simplified to 0 immediately.
    if expr._ufl_is_terminal_:
        return expr
    assert expr._ufl_is_terminal_modifier_

    original = expr

    # Peel off the modifier layers, outermost first, until a terminal is hit.
    modifiers = []
    node = expr
    while not node._ufl_is_terminal_:
        assert node._ufl_is_terminal_modifier_
        modifiers.append(node)
        node = node.ufl_operands[0]
    terminal = node
    assert terminal._ufl_is_terminal_

    # Re-wrap the terminal, lowest precedence innermost. list.sort is stable,
    # so modifiers of equal rank keep their outer-to-inner list order.
    modifiers.sort(key=lambda m: modifier_precedence[m._ufl_handler_name_])
    rebuilt = terminal
    for modifier in modifiers:
        # Operand 0 is replaced by the partially rebuilt chain; any further
        # operands (e.g. an index tuple) are carried over unchanged.
        rebuilt = modifier._ufl_expr_reconstruct_(rebuilt, *modifier.ufl_operands[1:])

    # Preserve id if nothing has changed
    if rebuilt == original:
        return original
    return rebuilt
class BalanceModifiers(MultiFunction):
    """Balance modifiers.

    Handlers meant for ``map_expr_dag``. Every node of a type routed to
    ``_modifier`` is handed to ``balance_modified_terminal``. Other operators
    are rebuilt from their processed operands, and terminals pass through
    unchanged.
    """

    def expr(self, expr, *ops):
        """Apply to expr.

        Rebuild ``expr`` from its processed operands ``ops``. If every
        processed operand is the very same object as the corresponding
        original operand, ``expr`` itself is returned instead of an equal
        copy.
        """
        # Mirror balance_modified_terminal's "preserve id if nothing has
        # changed": reusing the node keeps untouched subtrees shared and
        # avoids allocating a duplicate of every operator in the DAG.
        operands = expr.ufl_operands
        if len(ops) == len(operands) and all(new is old for new, old in zip(ops, operands)):
            return expr
        return expr._ufl_expr_reconstruct_(*ops)

    def terminal(self, expr):
        """Apply to terminal."""
        # A bare terminal has no modifier layers to reorder.
        return expr

    def _modifier(self, expr, *ops):
        """Apply to _modifier."""
        # The processed operands are deliberately ignored: the whole modifier
        # chain down to the terminal is re-balanced from the original node.
        return balance_modified_terminal(expr)

    # Route these terminal-modifier types to _modifier. ``indexed`` has no
    # entry here, so Indexed nodes presumably fall back to a more generic
    # handler such as ``expr`` (per MultiFunction dispatch; verify).
    reference_value = _modifier
    reference_grad = _modifier
    grad = _modifier
    cell_avg = _modifier
    facet_avg = _modifier
    positive_restricted = _modifier
    negative_restricted = _modifier
def balance_modifiers(expr):
    """Balance modifiers.

    Traverse ``expr`` with ``map_expr_dag`` using a fresh
    ``BalanceModifiers`` handler set, and return the transformed expression.
    """
    return map_expr_dag(BalanceModifiers(), expr)