Source code for ufl.algorithms.renumbering

"""Algorithms for renumbering of counted objects, currently variables and indices."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

from ufl.algorithms.transformer import ReuseTransformer, apply_transformer
from ufl.classes import Zero
from ufl.core.expr import Expr
from ufl.core.multiindex import FixedIndex, Index, MultiIndex
from ufl.variable import Label, Variable


class VariableRenumberingTransformer(ReuseTransformer):
    """Transformer that relabels variables with consecutive labels.

    Each distinct variable label encountered is given a replacement
    ``Variable`` whose ``Label`` is built from the number of variables
    recorded so far. Replacements are cached in ``variable_map`` so that
    a label seen again yields the same replacement object.
    """

    def __init__(self):
        """Initialise with an empty variable map."""
        ReuseTransformer.__init__(self)
        # Maps each original label to the renumbered Variable replacing it.
        self.variable_map = {}

    def variable(self, o):
        """Apply to variable.

        Args:
            o: Object whose ``ufl_operands`` is an (expression, label) pair.

        Returns:
            The cached replacement for the label if there is one, otherwise
            a newly created ``Variable`` that is also stored in the cache.
        """
        expression, label = o.ufl_operands

        cached = self.variable_map.get(label)
        if cached is not None:
            return cached

        # Visit the wrapped expression *before* sizing the map: the new
        # label is taken from the map size after the visit, as the map
        # may have grown during it (presumably through nested variables).
        visited = self.visit(expression)
        fresh_label = Label(len(self.variable_map))
        replacement = Variable(visited, fresh_label)
        self.variable_map[label] = replacement
        return replacement
class IndexRenumberingTransformer(VariableRenumberingTransformer):
    """Index renumbering transformer.

    This is a poorly designed algorithm. It is used in some tests,
    please do not use for anything else.

    Indices are renumbered in the order they are first encountered:
    every previously unseen index count is mapped to a new ``Index``
    whose count is the number of indices recorded so far.
    """

    def __init__(self):
        """Initialise with an empty index map."""
        VariableRenumberingTransformer.__init__(self)
        # Maps the count of each index seen so far to its replacement Index.
        self.index_map = {}

    def zero(self, o):
        """Apply to zero.

        Args:
            o: Zero-like object exposing ``ufl_shape``, ``ufl_free_indices``
                and ``ufl_index_dimensions``.

        Returns:
            ``o`` itself when it has no free indices; otherwise a new
            ``Zero`` with the same shape whose free indices have been
            renumbered, with (index, dimension) pairs sorted together.
        """
        fi = o.ufl_free_indices
        fid = o.ufl_index_dimensions

        if not fi:
            # Nothing to renumber. Without this guard, zip(*[]) yields no
            # items and the two-way unpacking below raises ValueError.
            return o

        mapped_fi = tuple(self.index(Index(count=i)) for i in fi)
        # Keep each renumbered index paired with the dimension of the index
        # it replaces, so sorting reorders both sequences consistently.
        paired_fid = [(new_index, fid[pos]) for pos, new_index in enumerate(mapped_fi)]
        new_fi, new_fid = zip(*sorted(paired_fid))
        return Zero(o.ufl_shape, new_fi, new_fid)

    def index(self, o):
        """Apply to index.

        Args:
            o: A ``FixedIndex`` or an index object exposing ``_count``.

        Returns:
            ``o`` unchanged if it is a ``FixedIndex``; otherwise the
            replacement ``Index`` recorded for ``o._count``, created and
            stored on first use.
        """
        if isinstance(o, FixedIndex):
            return o

        c = o._count
        i = self.index_map.get(c)
        if i is None:
            # First time this count is seen: hand out the next free number.
            i = Index(count=len(self.index_map))
            self.index_map[c] = i
        return i

    def multi_index(self, o):
        """Apply to multi_index.

        Args:
            o: Object whose ``indices()`` method returns an iterable of
                indices.

        Returns:
            A new ``MultiIndex`` built from the result of applying
            ``self.index`` to each of those indices, in order.
        """
        new_indices = tuple(self.index(i) for i in o.indices())
        return MultiIndex(new_indices)
def renumber_indices(expr):
    """Renumber indices.

    Applies a fresh ``IndexRenumberingTransformer`` to ``expr``.

    Args:
        expr: Object to renumber. The free-index consistency check is
            only performed when this is an ``Expr``.

    Returns:
        The result of applying the transformer to ``expr``.

    Raises:
        ValueError: If ``expr`` is an ``Expr`` and the number of free
            indices of the result differs from that of the input.
    """
    check_free_indices = isinstance(expr, Expr)
    # Record the count up front so it can be compared after transformation.
    expected_count = len(expr.ufl_free_indices) if check_free_indices else None

    renumbered = apply_transformer(expr, IndexRenumberingTransformer())

    if check_free_indices and expected_count != len(renumbered.ufl_free_indices):
        raise ValueError(
            "The number of free indices left in expression "
            "should be invariant w.r.t. renumbering."
        )
    return renumbered