# Source code for ufl.algorithms.balancing

"""Balancing."""
# -*- coding: utf-8 -*-
# Copyright (C) 2011-2017 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

from ufl.classes import (
    CellAvg,
    FacetAvg,
    Grad,
    Indexed,
    NegativeRestricted,
    PositiveRestricted,
    ReferenceGrad,
    ReferenceValue,
)
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction

# Rank of each terminal modifier type, keyed on the type's handler name so an
# expression node can be ranked through its own ``_ufl_handler_name_``.
# ``balance_modified_terminal`` re-applies modifiers in ascending rank starting
# from the terminal, so the lowest rank ends up innermost and the highest rank
# outermost.
modifier_precedence = {
    modifier_type._ufl_handler_name_: rank
    for rank, modifier_type in enumerate(
        (
            ReferenceValue,
            ReferenceGrad,
            Grad,
            CellAvg,
            FacetAvg,
            PositiveRestricted,
            NegativeRestricted,
            Indexed,
        )
    )
}


def balance_modified_terminal(expr):
    """Balance modified terminal.

    Re-nest the chain of terminal modifiers wrapped around a terminal so that
    they appear in the order given by ``modifier_precedence`` (lowest rank
    innermost, i.e. applied to the terminal first).

    NB! Assuming e.g. grad(cell_avg(expr)) does not occur,
    i.e. it is simplified to 0 immediately.

    Args:
        expr: A terminal, or a terminal modifier whose first operand is, in
            turn, a terminal modifier or a terminal.

    Returns:
        ``expr`` itself (the same object) if it is a terminal or if the
        reordered expression compares equal to it; otherwise the newly
        reconstructed expression.

    Raises:
        AssertionError: If a non-terminal node in the chain is not flagged as
            a terminal modifier.
        KeyError: If a modifier's handler name has no entry in
            ``modifier_precedence``.
    """
    if expr._ufl_is_terminal_:
        return expr

    orig = expr

    # Peel off the modifier layers, outermost first, by following the first
    # operand of each layer until the terminal is reached.
    modifiers = []
    terminal = expr
    while not terminal._ufl_is_terminal_:
        assert terminal._ufl_is_terminal_modifier_, (
            "Expecting only terminal modifiers wrapped around a terminal."
        )
        modifiers.append(terminal)
        terminal = terminal.ufl_operands[0]

    # Apply modifiers in order. ``sorted`` is stable, so layers sharing a rank
    # keep their original relative order.
    modifiers = sorted(modifiers, key=lambda e: modifier_precedence[e._ufl_handler_name_])
    balanced = terminal
    for op in modifiers:
        # Only the first operand is replaced; any trailing operands of the
        # layer (everything after index 0) are carried over unchanged.
        ops = (balanced,) + op.ufl_operands[1:]
        balanced = op._ufl_expr_reconstruct_(*ops)

    # Preserve id if nothing has changed
    return orig if balanced == orig else balanced


class BalanceModifiers(MultiFunction):
    """Balance modifiers.

    Handler set intended to be driven over an expression DAG (see
    ``balance_modifiers``): the listed terminal-modifier node types are
    rebalanced with ``balance_modified_terminal``, terminals are returned
    unchanged, and every other node is rebuilt from its processed operands.
    """

    def expr(self, expr, *ops):
        """Apply to expr.

        Generic handler: reconstruct ``expr`` from the already processed
        operands ``ops``.
        """
        return expr._ufl_expr_reconstruct_(*ops)

    def terminal(self, expr):
        """Apply to terminal.

        Terminals carry no modifiers of their own and are returned as is.
        """
        return expr

    def _modifier(self, expr, *ops):
        """Apply to _modifier.

        The processed operands ``ops`` are deliberately ignored:
        ``balance_modified_terminal`` walks the original ``expr`` down to its
        terminal and rebuilds the whole modifier chain itself.
        """
        return balance_modified_terminal(expr)

    # Handlers for the modifier node types that get rebalanced. The attribute
    # names match the handler-name style used as keys of
    # ``modifier_precedence``.
    reference_value = _modifier
    reference_grad = _modifier
    grad = _modifier
    cell_avg = _modifier
    facet_avg = _modifier
    positive_restricted = _modifier
    negative_restricted = _modifier


def balance_modifiers(expr):
    """Balance modifiers.

    Args:
        expr: The expression to process.

    Returns:
        The result of mapping a fresh ``BalanceModifiers`` instance over
        ``expr`` with ``map_expr_dag``.
    """
    mf = BalanceModifiers()
    return map_expr_dag(mf, expr)