# Source code for ufl.algorithms.formsplitter

"""Extract part of a form in a mixed FunctionSpace."""

# Copyright (C) 2016 Chris Richardson and Lawrence Mitchell
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later
#
# Modified by Cecile Daversin-Catty, 2018

from itertools import product

from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.argument import Argument
from ufl.constantvalue import Zero
from ufl.corealg.multifunction import MultiFunction
from ufl.functionspace import FunctionSpace
from ufl.tensors import as_vector


class FormSplitter(MultiFunction):
    """Form splitter.

    Maps over the integrands of a form. Arguments that belong to the
    selected block are kept; all other arguments are replaced by zeros.
    """

    def split(self, form, ix, iy=0):
        """Split.

        Args:
            form: The form to split.
            ix: Block index selected for the argument numbered 0.
            iy: Block index selected for the argument numbered 1.

        Returns:
            The result of applying this splitter to every integrand of
            ``form`` via ``map_integrand_dags``.
        """
        # Remember which block to extract; indexed by Argument.number()
        self.idx = [ix, iy]
        return map_integrand_dags(self, form)

    def argument(self, obj):
        """Apply to argument.

        Returns ``obj`` (or its components belonging to the selected
        block) unchanged, and zeros for everything outside that block.
        """
        if obj.part() is not None:
            # Mixed element built from MixedFunctionSpace, whose
            # sub-function spaces are indexed by obj.part().  Return the
            # argument itself, or a zero of the very same shape, so that
            # tensor-valued arguments are not flattened into vectors
            # (which would break shape checks in the enclosing operators).
            if obj.part() == self.idx[obj.number()]:
                return obj
            return Zero(obj.ufl_shape)

        # Mixed element built from MixedElement,
        # whose sub-elements need their function space to be created
        Q = obj.ufl_function_space()
        dom = Q.ufl_domain()
        sub_elements = obj.ufl_element().sub_elements()

        # If not a mixed element, do nothing
        if len(sub_elements) == 0:
            return obj

        selected = self.idx[obj.number()]
        args = []
        for i, sub_elem in enumerate(sub_elements):
            Q_i = FunctionSpace(dom, sub_elem)
            a = Argument(Q_i, obj.number(), part=obj.part())

            # All component indices of the sub-argument, last index
            # varying fastest; a scalar yields the single index ().
            indices = list(product(*(range(m) for m in a.ufl_shape)))

            if i == selected:
                args.extend(a[j] for j in indices)
            else:
                args.extend(Zero() for _ in indices)

        return as_vector(args)

    def multi_index(self, obj):
        """Apply to multi_index."""
        # Multi-indices are left untouched
        return obj

    # Any other expression is rebuilt only if one of its operands changed
    expr = MultiFunction.reuse_if_untouched
def extract_blocks(form, i=None, j=None):
    """Extract blocks.

    Split ``form`` into one sub-form per combination of argument parts
    (as reported by ``Argument.part()``).

    Args:
        form: The form to split.
        i: Optional position of the block to return.
        j: Optional column position; only used when the form has two
            arguments.

    Returns:
        ``(form,)`` if the form has no arguments.  Otherwise, if ``i`` is
        ``None``, a flat tuple with one entry per block (ordered row by
        row for bilinear forms), where an entry is ``None`` when the
        corresponding split form is empty.  If ``i`` is given together
        with ``j`` for a bilinear form, the single block at ``(i, j)``;
        otherwise entry ``i`` of that flat tuple.

    Raises:
        AssertionError: If the form has more than two arguments.
    """
    fs = FormSplitter()
    arguments = form.arguments()
    numbers = tuple(sorted({a.number() for a in arguments}))
    arity = len(numbers)
    parts = tuple(sorted({a.part() for a in arguments}))
    assert arity <= 2
    if arity == 0:
        return (form,)

    def _block(*idx):
        # Split out one block; an empty result is reported as None
        f = fs.split(form, *idx)
        return None if f.empty() else f

    if arity > 1:
        forms_tuple = tuple(_block(pi, pj) for pi in parts for pj in parts)
    else:
        forms_tuple = tuple(_block(pi) for pi in parts)

    if i is None:
        return forms_tuple
    if arity > 1 and j is not None:
        # Blocks are stored row by row in the flat tuple
        return forms_tuple[i * len(parts) + j]
    return forms_tuple[i]