Source code for ufl.algorithms.formsplitter
"""Extract part of a form in a mixed FunctionSpace."""
# Copyright (C) 2016-2024 Chris Richardson and Lawrence Mitchell
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
#
# Modified by Cecile Daversin-Catty, 2018
from typing import Optional
from ufl.algorithms.map_integrands import map_expr_dag, map_integrand_dags
from ufl.argument import Argument
from ufl.classes import FixedIndex, ListTensor
from ufl.constantvalue import Zero
from ufl.corealg.multifunction import MultiFunction
from ufl.functionspace import FunctionSpace
from ufl.tensors import as_vector
[docs]
class FormSplitter(MultiFunction):
"""Form splitter."""
[docs]
def split(self, form, ix, iy=0):
"""Split."""
# Remember which block to extract
self.idx = [ix, iy]
return map_integrand_dags(self, form)
[docs]
def argument(self, obj):
"""Apply to argument."""
if obj.part() is not None:
# Mixed element built from MixedFunctionSpace,
# whose sub-function spaces are indexed by obj.part()
if obj.part() == self.idx[obj.number()]:
return obj
else:
return Zero(obj.ufl_shape)
else:
# Mixed element built from MixedElement,
# whose sub-elements need their function space to be created
Q = obj.ufl_function_space()
dom = Q.ufl_domain()
sub_elements = obj.ufl_element().sub_elements
# If not a mixed element, do nothing
if len(sub_elements) == 0:
return obj
args = []
for i, sub_elem in enumerate(sub_elements):
Q_i = FunctionSpace(dom, sub_elem)
a = Argument(Q_i, obj.number(), part=obj.part())
indices = [()]
for m in a.ufl_shape:
indices = [(k + (j,)) for k in indices for j in range(m)]
if i == self.idx[obj.number()]:
args += [a[j] for j in indices]
else:
args += [Zero() for j in indices]
return as_vector(args)
[docs]
def indexed(self, o, child, multiindex):
"""Extract indexed entry if multindices are fixed.
This avoids tensors like (v_0, 0)[1] to be created.
"""
indices = multiindex.indices()
if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
if len(indices) == 1:
return child.ufl_operands[indices[0]._value]
else:
return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices()))
return self.expr(o, child, multiindex)
[docs]
def restricted(self, o):
"""Apply to a restricted function."""
# If we hit a restriction first apply form splitter to argument, then check for zero
op_split = map_expr_dag(self, o.ufl_operands[0])
if isinstance(op_split, Zero):
return op_split
else:
return op_split(o._side)
expr = MultiFunction.reuse_if_untouched
[docs]
def extract_blocks(form, i: Optional[int] = None, j: Optional[None] = None):
"""Extract blocks of a form.
If arity is 0, returns the form.
If arity is 1, return the ith block. If ``i`` is ``None``, return all blocks.
If arity is 2, return the ``(i,j)`` entry. If ``j`` is ``None``, return the ith row.
If neither `i` nor `j` are set, return all blocks (as a scalar, vector or tensor).
Args:
form: A form
i: Index of the block to extract. If set to ``None``, ``j`` must be None.
j: Index of the block to extract.
"""
if i is None and j is not None:
raise RuntimeError(f"Cannot extract block with {j=} and {i=}.")
fs = FormSplitter()
arguments = form.arguments()
numbers = tuple(sorted(set(a.number() for a in arguments)))
arity = len(numbers)
assert arity <= 2
if arity == 0:
return (form,)
# If mixed element, each argument has no sub-elements
parts = tuple(sorted(set(part for a in arguments if (part := a.part()) is not None)))
if parts == ():
if i is None and j is None:
num_sub_elements = arguments[0].ufl_element().num_sub_elements
forms = []
for pi in range(num_sub_elements):
form_i = []
for pj in range(num_sub_elements):
f = fs.split(form, pi, pj)
if f.empty():
form_i.append(None)
else:
form_i.append(f)
forms.append(tuple(form_i))
return tuple(forms)
else:
return fs.split(form, i, j)
# If mixed function space, each argument has sub-elements
forms = []
num_parts = len(parts)
for pi in range(num_parts):
form_i = []
if arity > 1:
for pj in range(num_parts):
f = fs.split(form, pi, pj)
if f.empty():
form_i.append(None)
else:
if (num_args := len(f.arguments())) != 2:
raise RuntimeError(f"Expected 2 arguments, got {num_args}")
form_i.append(f)
forms.append(tuple(form_i))
else:
f = fs.split(form, pi)
if f.empty():
forms.append(None)
else:
forms.append(f)
try:
forms_tuple = tuple(forms)
except TypeError:
# Only one form returned
forms_tuple = (forms,)
if i is not None:
if (num_rows := len(forms_tuple)) <= i:
raise RuntimeError(f"Cannot extract block {i} from form with {num_rows} blocks.")
if arity > 1 and j is not None:
if (num_cols := len(forms_tuple[i])) <= j:
raise RuntimeError(
f"Cannot extract block {i},{j} from form with {num_rows}x{num_cols} blocks."
)
return forms_tuple[i][j]
else:
return forms_tuple[i]
else:
return forms_tuple