# Source code for ufl.split_functions

"""Algorithm for splitting a Coefficient or Argument into subfunctions."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later
#
# Modified by Anders Logg, 2008

from ufl.functionspace import FunctionSpace
from ufl.indexed import Indexed
from ufl.permutation import compute_indices
from ufl.tensors import ListTensor, as_matrix, as_vector
from ufl.utils.indexflattening import flatten_multiindex, shape_to_strides
from ufl.utils.sequences import product


def split(v):
    """Split a coefficient or argument.

    If v is a Coefficient or Argument in a mixed space, returns a tuple
    with the function components corresponding to the subelements.

    The output of an earlier call to ``split`` may be passed in again:

    * an ``Indexed`` expression is returned unchanged in a 1-tuple;
    * a ``ListTensor`` whose operands all index the same expression is
      split according to the subelement whose value size equals the
      indexed component range.

    Args:
        v: The expression to split.

    Returns:
        A tuple of expressions, one per subelement of the element being
        split. Scalar subelements yield a single component of ``v``,
        rank 1 subelements a vector of components and rank 2 subelements
        a matrix of components. If the element of ``v`` has no
        subelements the result is ``(v,)``.

    Raises:
        ValueError: If ``v`` is a ``ListTensor`` that is not made up of
            components of one single expression, if ``v`` is mixed but
            its shape is not rank 1, if no subelement matching the
            requested component range can be found, if a subelement has
            rank greater than 2, or if the extracted components do not
            cover the whole intended range.
    """
    domain = v.ufl_domain()

    # Default range is all of v; ``end`` stays None until the value size
    # of v is known (or until a sub-range is read off a ListTensor below).
    begin = 0
    end = None

    if isinstance(v, Indexed):
        # Special case: split previous output of split again
        # Consistent with simple element, just return function in a tuple
        return (v,)

    if isinstance(v, ListTensor):
        # Special case: split previous output of split again
        ops = v.ufl_operands
        if not all(isinstance(comp, Indexed) for comp in ops):
            raise ValueError(f"Don't know how to split {v}.")

        args = [comp.ufl_operands[0] for comp in ops]
        if not all(args[0] == other for other in args[1:]):
            raise ValueError(f"Don't know how to split {v}.")

        # Get innermost terminal here and its element
        v = args[0]

        # Get relevant range of v components. Only the first and the last
        # operand are inspected; the operands in between are presumably
        # the contiguous components in that range (not verified here).
        (begin,) = ops[0].ufl_operands[1]
        (end,) = ops[-1].ufl_operands[1]
        begin = int(begin)
        end = int(end) + 1

    # Special case: simple element, just return function in a tuple
    element = v.ufl_element()
    if element.num_sub_elements == 0:
        # Internal consistency check: a sub-range is not expected here.
        assert end is None
        return (v,)

    if len(v.ufl_shape) != 1:
        raise ValueError(
            "Don't know how to split tensor valued mixed functions without flattened index space."
        )

    # Compute value size and set default range end
    value_size = v.ufl_function_space().value_size
    if end is None:
        end = value_size
    else:
        # Dive into the mixed element, level by level, following the
        # subelement that contains the beginning of the range.
        # ``j`` is the offset of ``begin`` relative to the current element.
        j = begin
        while True:
            for e in element.sub_elements:
                sub_value_size = FunctionSpace(domain, e).value_size
                if j < sub_value_size:
                    element = e
                    break
                j -= sub_value_size
            else:
                # No subelement was selected, so another pass cannot make
                # progress; fail instead of looping forever.
                raise ValueError(
                    "Function splitting failed to find a subelement matching "
                    "the requested component range."
                )
            # Then break when we find the subelement that covers the whole range
            if FunctionSpace(domain, element).value_size == (end - begin):
                break

    # Build expressions representing the subfunction of v for each subelement
    offset = begin
    sub_functions = []
    for e in element.sub_elements:
        # Get shape, size, indices, and v components
        # corresponding to subelement value
        shape = FunctionSpace(domain, e).value_shape
        strides = shape_to_strides(shape)
        rank = len(shape)
        sub_size = product(shape)
        subindices = [flatten_multiindex(c, strides) for c in compute_indices(shape)]
        components = [v[k + offset] for k in subindices]

        # Shape components into same shape as subelement
        if rank == 0:
            (subv,) = components
        elif rank == 1:
            subv = as_vector(components)
        elif rank == 2:
            ncols = shape[1]
            subv = as_matrix(
                [components[row * ncols : (row + 1) * ncols] for row in range(shape[0])]
            )
        else:
            raise ValueError(
                f"Don't know how to split functions with sub functions of rank {rank}."
            )

        offset += sub_size
        sub_functions.append(subv)

    # Every component in [begin, end) must have been consumed exactly once.
    if end != offset:
        raise ValueError(
            "Function splitting failed to extract components for whole intended range."
        )

    return tuple(sub_functions)