# Source code for ufl.finiteelement.mixedelement

# -*- coding: utf-8 -*-
"This module defines the UFL mixed, vector and tensor finite element classes."

# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
#
# Modified by Kristian B. Oelgaard
# Modified by Marie E. Rognes 2010, 2012
# Modified by Anders Logg 2014
# Modified by Massimiliano Leoni, 2016

from ufl.log import error
from ufl.permutation import compute_indices
from ufl.utils.sequences import product, max_degree
from ufl.utils.dicts import EmptyDict
from ufl.utils.indexflattening import flatten_multiindex, unflatten_index, shape_to_strides
from ufl.cell import as_cell

from ufl.finiteelement.finiteelementbase import FiniteElementBase
from ufl.finiteelement.finiteelement import FiniteElement

class MixedElement(FiniteElementBase):
    """A finite element composed of a nested hierarchy of mixed or simple
    elements."""
    __slots__ = ("_sub_elements", "_cells")

    def __init__(self, *elements, **kwargs):
        """Create mixed finite element from given list of elements.

        :arg elements: The sub elements, given either as separate
            positional arguments or as a single tuple/list. Nested
            tuples/lists are wrapped recursively in MixedElements.
        :arg kwargs: Only accepted when called from a subclass;
            ``value_shape`` and ``reference_value_shape`` override the
            default (flattened) value shapes.
        """
        if type(self) is MixedElement:
            if kwargs:
                error("Not expecting keyword arguments to MixedElement constructor.")

        # Un-nest arguments if we get a single argument with a list of elements
        if len(elements) == 1 and isinstance(elements[0], (tuple, list)):
            elements = elements[0]
        # Interpret nested tuples as sub-mixedelements recursively
        elements = [MixedElement(e) if isinstance(e, (tuple, list)) else e
                    for e in elements]
        self._sub_elements = elements

        # Pick the first cell, for now all should be equal
        cells = tuple(sorted(set(element.cell() for element in elements) - set([None])))
        self._cells = cells
        if cells:
            cell = cells[0]
            # Require that all elements are defined on the same cell
            if not all(c == cell for c in cells[1:]):
                error("Sub elements must live on the same cell.")
        else:
            cell = None

        # Check that all elements use the same quadrature scheme.
        # TODO: We can allow the scheme not to be defined.
        if len(elements) == 0:
            quad_scheme = None
        else:
            quad_scheme = elements[0].quadrature_scheme()
            if not all(e.quadrature_scheme() == quad_scheme for e in elements):
                error("Quadrature scheme mismatch for sub elements of mixed element.")

        # Compute value sizes in global and reference configurations
        value_size_sum = sum(product(s.value_shape()) for s in self._sub_elements)
        reference_value_size_sum = sum(product(s.reference_value_shape()) for s in self._sub_elements)

        # Default value shape: Treated simply as all subelement values
        # unpacked in a vector.
        value_shape = kwargs.get('value_shape', (value_size_sum,))

        # Default reference value shape: Treated simply as all
        # subelement reference values unpacked in a vector.
        reference_value_shape = kwargs.get('reference_value_shape', (reference_value_size_sum,))

        # Validate value_shape (deliberately not for subclasses
        # VectorElement and TensorElement)
        if type(self) is MixedElement:
            # This is not valid for tensor elements with symmetries,
            # assume subclasses deal with their own validation
            if product(value_shape) != value_size_sum:
                error("Provided value_shape doesn't match the "
                      "total value size of all subelements.")

        # Initialize element data; the degree is combined (via max_degree)
        # from the subelements that define one
        degrees = {e.degree() for e in self._sub_elements} - {None}
        degree = max_degree(degrees) if degrees else None
        FiniteElementBase.__init__(self, "Mixed", cell, degree, quad_scheme,
                                   value_shape, reference_value_shape)

        # Cache repr string
        if type(self) is MixedElement:
            self._repr = "MixedElement(%s)" % (
                ", ".join(repr(e) for e in self._sub_elements),)

    def reconstruct_from_elements(self, *elements):
        """Reconstruct a mixed element from new subelements.

        Returns ``self`` when the new subelements are equal (same number,
        pairwise equal) to the current ones.
        """
        # zip() truncates, so the length must be compared explicitly;
        # otherwise a shorter/longer element list could be mistaken for
        # an unchanged one.
        if (len(elements) == len(self._sub_elements) and
                all(a == b for (a, b) in zip(elements, self._sub_elements))):
            return self
        return MixedElement(*elements)

    def symmetry(self):
        """Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`
        meaning that component :math:`c_0` is represented by component
        :math:`c_1`.
        A component is a tuple of one or more ints."""
        # Build symmetry map from symmetries of subelements
        sm = {}
        # Base index of the current subelement into mixed value
        j = 0
        for e in self._sub_elements:
            sh = e.value_shape()
            st = shape_to_strides(sh)
            # Map symmetries of subelement into index space of this
            # element
            for c0, c1 in e.symmetry().items():
                j0 = flatten_multiindex(c0, st) + j
                j1 = flatten_multiindex(c1, st) + j
                sm[(j0,)] = (j1,)
            # Update base index for next element
            j += product(sh)
        if j != product(self.value_shape()):
            error("Size mismatch in symmetry algorithm.")
        return sm or EmptyDict

    def mapping(self):
        """Return "identity" if all subelements use the identity mapping,
        otherwise "undefined"."""
        if all(e.mapping() == "identity" for e in self._sub_elements):
            return "identity"
        else:
            return "undefined"

    def num_sub_elements(self):
        "Return number of sub elements."
        return len(self._sub_elements)

    def sub_elements(self):
        "Return list of sub elements."
        return self._sub_elements

    def extract_subelement_component(self, i):
        """Extract direct subelement index and subelement relative
        component index for a given component index.

        :arg i: Component index, an int or a tuple of ints.
        :returns: Tuple ``(sub_element_index, component)``.
        """
        if isinstance(i, int):
            i = (i,)
        self._check_component(i)

        # Select between indexing modes
        if len(self.value_shape()) == 1:
            # Indexing into a long vector of flattened subelement
            # shapes
            j, = i

            # Find subelement for this index, subtracting the value sizes
            # of the subelements that are skipped over
            for sub_element_index, e in enumerate(self._sub_elements):
                sh = e.value_shape()
                si = product(sh)
                if j < si:
                    break
                j -= si
            else:
                # Loop ran out of subelements (or there were none)
                # without locating j
                error("Moved past last value component!")

            # Convert index into a shape tuple
            st = shape_to_strides(sh)
            component = unflatten_index(j, st)
        else:
            # Indexing into a multidimensional tensor where subelement
            # index is first axis
            sub_element_index = i[0]
            if sub_element_index >= len(self._sub_elements):
                error("Illegal component index (dimension %d)." % sub_element_index)
            component = i[1:]
        return (sub_element_index, component)

    def extract_component(self, i):
        """Recursively extract component index relative to a (simple) element
        and that element for given value component index."""
        sub_element_index, component = self.extract_subelement_component(i)
        return self._sub_elements[sub_element_index].extract_component(component)

    def extract_subelement_reference_component(self, i):
        """Extract direct subelement index and subelement relative
        reference_component index for a given reference_component index.

        :arg i: Reference component index, an int or a 1-tuple of int.
        :returns: Tuple ``(sub_element_index, reference_component)``.
        """
        if isinstance(i, int):
            i = (i,)
        self._check_reference_component(i)

        # Only flat (rank-1) reference value shapes are supported here
        assert len(self.reference_value_shape()) == 1
        # Indexing into a long vector of flattened subelement shapes
        j, = i

        # Find subelement for this index, subtracting the reference value
        # sizes of the subelements that are skipped over
        for sub_element_index, e in enumerate(self._sub_elements):
            sh = e.reference_value_shape()
            si = product(sh)
            if j < si:
                break
            j -= si
        else:
            # Loop ran out of subelements (or there were none) without
            # locating j
            error("Moved past last value reference_component!")

        # Convert index into a shape tuple
        st = shape_to_strides(sh)
        reference_component = unflatten_index(j, st)
        return (sub_element_index, reference_component)

    def extract_reference_component(self, i):
        """Recursively extract reference_component index relative to a (simple) element
        and that element for given value reference_component index."""
        sub_element_index, reference_component = self.extract_subelement_reference_component(i)
        return self._sub_elements[sub_element_index].extract_reference_component(reference_component)

    def is_cellwise_constant(self, component=None):
        """Return whether the basis functions of this
        element is spatially constant over each cell.

        :arg component: Optional value component; if given, only the
            subelement owning that component is inspected.
        """
        if component is None:
            return all(e.is_cellwise_constant() for e in self.sub_elements())
        else:
            _, e = self.extract_component(component)
            return e.is_cellwise_constant()

    def degree(self, component=None):
        """Return polynomial degree of finite element.

        :arg component: Optional value component; if given, the degree of
            the subelement owning that component is returned.
        """
        if component is None:
            return self._degree  # from FiniteElementBase, computed as max of subelements in __init__
        else:
            _, e = self.extract_component(component)
            return e.degree()

    def reconstruct(self, **kwargs):
        """Return a new MixedElement whose subelements are reconstructed
        with the given keyword arguments."""
        return MixedElement(*[e.reconstruct(**kwargs) for e in self.sub_elements()])

    def __str__(self):
        "Format as string for pretty printing."
        tmp = ", ".join(str(element) for element in self._sub_elements)
        return "<Mixed element: (" + tmp + ")>"

    def shortstr(self):
        "Format as string for pretty printing."
        tmp = ", ".join(element.shortstr() for element in self._sub_elements)
        return "Mixed<" + tmp + ">"

class VectorElement(MixedElement):
    "A special case of a mixed finite element where all elements are equal."

    def __init__(self, family, cell=None, degree=None, dim=None,
                 form_degree=None, quad_scheme=None, variant=None):
        """
        Create vector element (repeated mixed element)

        *Arguments*
            family (string)
               The finite element family (or an existing FiniteElement)
            cell
               The geometric cell, ignored if family is a FiniteElement
            degree (int)
               The polynomial degree, ignored if family is a FiniteElement
            dim (int)
               The value dimension of the element (optional)
            form_degree (int)
               The form degree (FEEC notation, used when field is
               viewed as k-form), ignored if family is a FiniteElement
            quad_scheme
               The quadrature scheme (optional), ignored if family is a FiniteElement
            variant
               Hint for the local basis function variant (optional)
        """
        if isinstance(family, FiniteElementBase):
            sub_element = family
            cell = sub_element.cell()
            variant = sub_element.variant()
        else:
            if cell is not None:
                cell = as_cell(cell)
            # Create sub element
            sub_element = FiniteElement(family, cell, degree,
                                        form_degree=form_degree,
                                        quad_scheme=quad_scheme,
                                        variant=variant)

        # Set default size if not specified
        if dim is None:
            if cell is None:
                error("Cannot infer vector dimension without a cell.")
            dim = cell.geometric_dimension()

        self._mapping = sub_element.mapping()
        # Create list of sub elements for mixed element constructor
        sub_elements = [sub_element] * dim

        # Compute value shapes
        value_shape = (dim,) + sub_element.value_shape()
        reference_value_shape = (dim,) + sub_element.reference_value_shape()

        # Initialize element data
        MixedElement.__init__(self, sub_elements, value_shape=value_shape,
                              reference_value_shape=reference_value_shape)

        # Re-initialize the base data so this element reports the family
        # and degree of its sub element rather than "Mixed". The
        # quadrature scheme is taken from the sub element: when the sub
        # element was built above it already carries quad_scheme, and when
        # an existing element was passed the quad_scheme argument is
        # documented as ignored and must not reset the scheme to None.
        FiniteElementBase.__init__(self, sub_element.family(), cell,
                                   sub_element.degree(),
                                   sub_element.quadrature_scheme(),
                                   value_shape, reference_value_shape)
        self._sub_element = sub_element

        if variant is None:
            var_str = ""
        else:
            var_str = ", variant='" + variant + "'"

        # Cache repr string
        self._repr = "VectorElement(%s, dim=%d%s)" % (
            repr(sub_element), len(self._sub_elements), var_str)

    def reconstruct(self, **kwargs):
        """Return a new VectorElement with the same dimension whose sub
        element is reconstructed with the given keyword arguments."""
        sub_element = self._sub_element.reconstruct(**kwargs)
        return VectorElement(sub_element, dim=len(self.sub_elements()))

    def variant(self):
        """Return the variant used to initialise the element."""
        return self._sub_element.variant()

    def mapping(self):
        """Return the mapping of the (repeated) sub element."""
        return self._mapping

    def __str__(self):
        "Format as string for pretty printing."
        return ("<vector element with %d components of %s>" %
                (len(self._sub_elements), self._sub_element))

    def shortstr(self):
        "Format as string for pretty printing."
        return "Vector<%d x %s>" % (len(self._sub_elements),
                                    self._sub_element.shortstr())

class TensorElement(MixedElement):
    """A special case of a mixed finite element where all elements are
    equal.

    """
    __slots__ = ("_sub_element", "_shape", "_symmetry",
                 "_sub_element_mapping",
                 "_flattened_sub_element_mapping",
                 "_mapping")

    def __init__(self, family, cell=None, degree=None, shape=None,
                 symmetry=None, quad_scheme=None, variant=None):
        """Create tensor element (repeated mixed element with optional symmetries).

        :arg family: The family string, or an existing FiniteElement.
        :arg cell: The geometric cell (ignored if family is a FiniteElement).
        :arg degree: The polynomial degree (ignored if family is a FiniteElement).
        :arg shape: The shape of the element (defaults to a square
            tensor given by the geometric dimension of the cell).
        :arg symmetry: Optional symmetries: None (no symmetry), True
            (square matrix whose entries (i, j) with i > j are represented
            by (j, i)), or a dict mapping index tuples to the index tuples
            that represent them.
        :arg quad_scheme: Optional quadrature scheme (ignored if
            family is a FiniteElement).
        :arg variant:  Hint for the local basis function variant (optional)"""
        if isinstance(family, FiniteElementBase):
            sub_element = family
            cell = sub_element.cell()
            variant = sub_element.variant()
        else:
            if cell is not None:
                cell = as_cell(cell)
            # Create scalar sub element
            sub_element = FiniteElement(family, cell, degree, quad_scheme=quad_scheme,
                                        variant=variant)

        # Set default shape if not specified
        if shape is None:
            if cell is None:
                error("Cannot infer tensor shape without a cell.")
            dim = cell.geometric_dimension()
            shape = (dim, dim)

        if symmetry is None:
            symmetry = EmptyDict
        elif symmetry is True:
            # Construct default symmetry dict for matrix elements
            if not (len(shape) == 2 and shape[0] == shape[1]):
                error("Cannot set automatic symmetry for non-square tensor.")
            symmetry = dict(((i, j), (j, i)) for i in range(shape[0])
                            for j in range(shape[1]) if i > j)
        else:
            if not isinstance(symmetry, dict):
                error("Expecting symmetry to be None (unset), True, or dict.")

        # Validate indices in symmetry dict
        for i, j in symmetry.items():
            if len(i) != len(j):
                error("Non-matching length of symmetry index tuples.")
            # Keys must be full tensor indices: shorter tuples would never
            # match an index of the shape below (yet would still be counted
            # in the reference value size), longer ones would index past
            # the end of shape.
            if len(i) != len(shape):
                error("Symmetry index tuples must match the rank of the tensor shape.")
            for a, b, n in zip(i, j, shape):
                if not (0 <= a < n and 0 <= b < n):
                    error("Symmetry dimensions out of bounds.")

        # Compute all index combinations for given shape
        indices = compute_indices(shape)

        # Compute mapping from indices to sub element number,
        # accounting for symmetry
        sub_elements = []
        sub_element_mapping = {}
        for index in indices:
            if index in symmetry:
                continue
            sub_element_mapping[index] = len(sub_elements)
            sub_elements.append(sub_element)

        # Update mapping for symmetry: a symmetric index shares the sub
        # element of the index that represents it
        for index in indices:
            if index in symmetry:
                sub_element_mapping[index] = sub_element_mapping[symmetry[index]]
        flattened_sub_element_mapping = [sub_element_mapping[index] for index in indices]

        # Compute value shape
        value_shape = shape

        # Compute reference value shape based on symmetries: only the
        # independent components are stored, flattened into a vector
        if symmetry:
            reference_value_shape = (product(shape) - len(symmetry),)
            self._mapping = "symmetries"
        else:
            reference_value_shape = shape
            self._mapping = sub_element.mapping()

        value_shape = value_shape + sub_element.value_shape()
        reference_value_shape = reference_value_shape + sub_element.reference_value_shape()
        # Initialize element data
        MixedElement.__init__(self, sub_elements, value_shape=value_shape,
                              reference_value_shape=reference_value_shape)
        # Report family and degree of the repeated sub element instead of
        # the "Mixed" values set by MixedElement.__init__
        self._family = sub_element.family()
        self._degree = sub_element.degree()
        self._sub_element = sub_element
        self._shape = shape
        self._symmetry = symmetry
        self._sub_element_mapping = sub_element_mapping
        self._flattened_sub_element_mapping = flattened_sub_element_mapping

        if variant is None:
            var_str = ""
        else:
            var_str = ", variant='" + variant + "'"

        # Cache repr string
        self._repr = "TensorElement(%s, shape=%s, symmetry=%s%s)" % (
            repr(sub_element), repr(self._shape), repr(self._symmetry), var_str)

    def variant(self):
        """Return the variant used to initialise the element."""
        return self._sub_element.variant()

    def mapping(self):
        """Return "symmetries" if symmetries are present, otherwise the
        mapping of the sub element."""
        return self._mapping

    def flattened_sub_element_mapping(self):
        """Return the sub element number for each tensor index, in the
        order produced by compute_indices(shape)."""
        return self._flattened_sub_element_mapping

    def extract_subelement_component(self, i):
        """Extract direct subelement index and subelement relative
        component index for a given component index.

        Symmetric components are first replaced by the component that
        represents them.
        """
        if isinstance(i, int):
            i = (i,)
        self._check_component(i)

        i = self.symmetry().get(i, i)
        rank = len(self._shape)
        # Leading indices select the tensor entry, trailing ones index
        # into the sub element value
        ii = i[:rank]
        jj = i[rank:]
        if ii not in self._sub_element_mapping:
            error("Illegal component index %s." % (i,))
        k = self._sub_element_mapping[ii]
        return (k, jj)

    def symmetry(self):
        """Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`
        meaning that component :math:`c_0` is represented by component
        :math:`c_1`.
        A component is a tuple of one or more ints."""
        return self._symmetry

    def reconstruct(self, **kwargs):
        """Return a new TensorElement with the same shape and symmetry whose
        sub element is reconstructed with the given keyword arguments."""
        sub_element = self._sub_element.reconstruct(**kwargs)
        return TensorElement(sub_element, shape=self._shape, symmetry=self._symmetry)

    def __str__(self):
        "Format as string for pretty printing."
        if self._symmetry:
            tmp = ", ".join("%s -> %s" % (a, b) for (a, b) in self._symmetry.items())
            sym = " with symmetries (%s)" % tmp
        else:
            sym = ""
        return ("<tensor element with shape %s of %s%s>" %
                (self.value_shape(), self._sub_element, sym))

    def shortstr(self):
        "Format as string for pretty printing."
        if self._symmetry:
            tmp = ", ".join("%s -> %s" % (a, b) for (a, b) in self._symmetry.items())
            sym = " with symmetries (%s)" % tmp
        else:
            sym = ""
        return "Tensor<%s x %s%s>" % (self.value_shape(),
                                      self._sub_element.shortstr(), sym)