Source code for ufl.algorithms.formsplitter
"""Extract part of a form in a mixed FunctionSpace."""
# Copyright (C) 2016-2024 Chris Richardson and Lawrence Mitchell
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
#
# Modified by Cecile Daversin-Catty, 2018
# Modified by Jørgen S. Dokken, 2025
import numpy as np
from ufl.algorithms.map_integrands import map_expr_dag, map_integrand_dags
from ufl.argument import Argument
from ufl.classes import FixedIndex, ListTensor
from ufl.constantvalue import Zero
from ufl.corealg.multifunction import MultiFunction
from ufl.form import Form
from ufl.functionspace import FunctionSpace
from ufl.tensors import as_vector
[docs]
class FormSplitter(MultiFunction):
    """Multifunction extracting a single block of a form on a mixed space.

    The block to keep is chosen through :meth:`split`. While the integrands
    are traversed, every argument (or argument component) that does not
    belong to the chosen block is replaced by zero.
    """

    def __init__(self, replace_argument: bool = True):
        """Create a form splitter.

        Args:
            replace_argument: Only relevant for arguments on a mixed element.
                If True, the kept components are taken from a new argument
                defined on the function space of the selected sub-element.
                If False, the kept components are taken from the original
                argument, e.g. when diagonalizing a form on a mixed element
                where the original argument has to be retained.
        """
        super().__init__()
        # Requested block index, one entry per argument number (0 and 1).
        self.idx = [None, None]
        self.replace_argument = replace_argument

    def split(self, form, ix, iy=None):
        """Return the part of ``form`` belonging to block ``ix`` (and ``iy``)."""
        # The handlers below read the requested block from self.idx.
        self.idx = [ix, iy]
        return map_integrand_dags(self, form)

    def argument(self, obj):
        """Keep, zero out or rebuild an argument depending on the requested block."""
        part = obj.part()
        if part is not None:
            # Argument of a MixedFunctionSpace: the sub-function spaces are
            # indexed by the part, so the argument is kept or dropped whole.
            wanted = self.idx[obj.number()]
            if wanted is not None and part == wanted:
                return obj
            return Zero(obj.ufl_shape)

        # Argument on a MixedElement: the function space of every
        # sub-element has to be created here.
        space = obj.ufl_function_space()
        domain = space.ufl_domain()
        sub_elements = obj.ufl_element().sub_elements

        # Not a mixed element, nothing to split.
        if len(sub_elements) == 0:
            return obj

        number = obj.number()
        components = []
        # Position of the current sub-element in the flattened original argument.
        offset = 0
        for i, sub_elem in enumerate(sub_elements):
            a = Argument(FunctionSpace(domain, sub_elem), number, part=part)
            shape = a.ufl_shape
            # Number of flattened components; 1 for a scalar (shape ``()``).
            size = int(np.prod(shape))
            if i != self.idx[number]:
                # Sub-element outside of the requested block.
                components.extend(Zero() for _ in range(size))
            elif self.replace_argument:
                # Use the components of the new sub-space argument.
                components.extend(a[j] for j in np.ndindex(shape))
            else:
                # Use the matching slice of the flattened original argument.
                components.extend(obj[offset + k] for k in range(size))
            offset += size
        return as_vector(components)

    def indexed(self, o, child, multiindex):
        """Resolve fixed indices into a list tensor directly.

        This avoids creating expressions such as ``(v_0, 0)[1]``.
        """
        indices = multiindex.indices()
        resolvable = isinstance(child, ListTensor) and all(
            isinstance(index, FixedIndex) for index in indices
        )
        if not resolvable:
            return self.expr(o, child, multiindex)

        if len(indices) == 1:
            return child[indices[0]]

        # Indices 0, 1, ..., n-1 with n the number of operands of the tensor.
        selects_all_in_order = len(indices) == len(child.ufl_operands) and all(
            int(index) == position for position, index in enumerate(indices)
        )
        if selects_all_in_order:
            return child
        return ListTensor(*(child[index] for index in indices))

    def restricted(self, o):
        """Split the operand of a restriction before restricting it again."""
        # Apply the splitter to the operand first; a zero result needs no restriction.
        split_operand = map_expr_dag(self, o.ufl_operands[0])
        if isinstance(split_operand, Zero):
            return split_operand
        return split_operand(o._side)

    # Every other expression is rebuilt only if one of its operands changed.
    expr = MultiFunction.reuse_if_untouched
[docs]
def extract_blocks(
    form,
    i: int | None = None,
    j: int | None = None,
    arity: int | None = None,
    replace_argument: bool = True,
) -> Form | tuple[Form | None, ...] | tuple[tuple[Form | None, ...], ...]:
    """Extract blocks of a form.

    If arity is 0, return the one-element tuple ``(form,)``.
    If arity is 1, return the ith block. If ``i`` is ``None``, return all blocks.
    If arity is 2, return the ``(i,j)`` entry. If ``j`` is ``None``, return the ith row.
    If neither ``i`` nor ``j`` are set, return all blocks.

    Blocks that are empty after splitting are returned as ``None`` when all
    blocks are requested. For arguments carrying a part (MixedFunctionSpace),
    a block whose number of arguments differs from the arity is ``None`` as well.

    Args:
        form: A form
        i: Index of the block to extract. If set to ``None``, ``j`` must be None.
        j: Index of the block to extract.
        arity: Arity of the form. If not set, it will be inferred from the form.
        replace_argument: If True, replace the argument by a new argument
            in the (collapsed) sub-function space. If False, keep the original argument.

    Returns:
        The requested block, row of blocks, or all blocks as described above.
        For a form on a non-mixed element without parts and with no index
        given, the form itself is returned.

    Raises:
        RuntimeError: If ``j`` is given without ``i``, or if ``i`` or ``j``
            is not smaller than the number of parts of the arguments.
    """
    if i is None and j is not None:
        raise RuntimeError(f"Cannot extract block with {j=} and {i=}.")

    fs = FormSplitter(replace_argument=replace_argument)
    arguments = form.arguments()

    if arity is None:
        # One per distinct argument number found in the form.
        arity = len({a.number() for a in arguments})

    # Kept as an assertion so that callers see the same exception type as before.
    assert arity <= 2
    if arity == 0:
        return (form,)

    # Parts are only set on arguments created from a MixedFunctionSpace.
    parts = {part for a in arguments if (part := a.part()) is not None}

    if not parts:
        # No argument carries a part: blocks are the sub-elements of a
        # (possibly mixed) element.
        if i is not None:
            return fs.split(form, i, j)

        num_sub_elements = arguments[0].ufl_element().num_sub_elements
        # If form has no sub elements, return the form itself.
        if num_sub_elements == 0:
            return form

        # NOTE(review): the grid is always num_sub_elements x num_sub_elements,
        # sized from the first argument, also when the form has a single
        # argument number. Kept unchanged for backward compatibility.
        rows = []
        for pi in range(num_sub_elements):
            row: list[object | None] = []
            for pj in range(num_sub_elements):
                f = fs.split(form, pi, pj)
                row.append(None if f.empty() else f)
            rows.append(tuple(row))
        return tuple(rows)  # type: ignore[return-value]

    # Arguments carry parts: blocks are indexed by the part numbers.
    num_parts = max(parts) + 1
    bilinear = arity > 1
    expected_rank = 2 if bilinear else 1

    def split_block(pi, pj=None):
        """Return block ``(pi, pj)``, or None if it is empty or of the wrong rank."""
        f = fs.split(form, pi, pj)
        if f.empty() or len(f.arguments()) != expected_rank:
            return None
        return f

    def split_row(pi):
        """Return row ``pi``: a list of blocks if bilinear, else a single block."""
        if bilinear:
            return [split_block(pi, pj) for pj in range(num_parts)]
        return split_block(pi)

    if i is None:
        if bilinear:
            return tuple(  # type: ignore[return-value]
                tuple(split_row(pi)) for pi in range(num_parts)
            )
        return tuple(split_row(pi) for pi in range(num_parts))  # type: ignore[return-value]

    # Validate the request first and split only the blocks that are returned,
    # instead of splitting all blocks and indexing into the result.
    if num_parts <= i:
        raise RuntimeError(f"Cannot extract block {i} from form with {num_parts} blocks.")
    # Indexing a range resolves negative indices (and rejects too negative
    # ones with an IndexError) exactly like indexing a list of that length.
    row_index = range(num_parts)[i]

    if bilinear and j is not None:
        if num_parts <= j:
            raise RuntimeError(
                f"Cannot extract block {i},{j} from form with {num_parts}x{num_parts} blocks."
            )
        return split_block(row_index, range(num_parts)[j])  # type: ignore[return-value]
    return split_row(row_index)  # type: ignore[return-value]