# Copyright (C) 2021 Matthew W. Scroggs and Chris Richardson
#
# This file is part of FFCx (https://www.fenicsproject.org).
#
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Finite element interface."""
from __future__ import annotations
import typing
if typing.TYPE_CHECKING:
import ufl.finiteelement.FiniteElementBase
import warnings
from abc import ABC, abstractmethod
import basix
import numpy
import ufl
import basix.ufl_wrapper
import functools
@functools.lru_cache()
def create_basix_element(family_type, cell_type, degree, variant_info, discontinuous):
    """Create a Basix element, reusing a cached one where possible.

    Calls are memoised with ``functools.lru_cache``, so every argument
    must be hashable (pass ``variant_info`` as a tuple, not a list).

    Args:
        family_type: Basix element family
        cell_type: Basix cell type
        degree: Degree of the element
        variant_info: Variant arguments; they are unpacked into the call
            to ``basix.create_element`` directly after the degree
        discontinuous: Forwarded as the final positional argument

    Returns:
        The object returned by ``basix.create_element``
    """
    arguments = (family_type, cell_type, degree, *variant_info, discontinuous)
    return basix.create_element(*arguments)
def create_element(element: ufl.finiteelement.FiniteElementBase) -> BaseElement:
    """Convert a UFL element into the matching FFCx element wrapper.

    Args:
        element: A UFL finite element

    Returns:
        A FFCx finite element

    Raises:
        ValueError: If the UFL element has a variant set, other than
            ``"equispaced"`` on a P family element.
    """
    # TODO: EnrichedElement
    if isinstance(element, basix.ufl_wrapper.BasixElement):
        return BasixElement(element.basix_element)

    # Vector and tensor elements are tested before the general mixed case
    if isinstance(element, ufl.VectorElement):
        scalar_element = create_element(element.sub_elements()[0])
        return BlockedElement(scalar_element, element.num_sub_elements())
    if isinstance(element, ufl.TensorElement):
        scalar_element = create_element(element.sub_elements()[0])
        # TODO: block shape
        return BlockedElement(scalar_element, element.num_sub_elements(), None)
    if isinstance(element, ufl.MixedElement):
        return MixedElement([create_element(sub) for sub in element.sub_elements()])
    if element.family() == "Quadrature":
        return QuadratureElement(element)

    # Work out the underlying family name and whether the element is
    # the discontinuous version of that family
    prefix = "Discontinuous "
    family_name = element.family()
    discontinuous = family_name.startswith(prefix)
    if discontinuous:
        family_name = family_name[len(prefix):]
    # Short names that always denote a discontinuous element
    discontinuous_aliases = {"DP": "P", "DQ": "Q", "DPC": "DPC"}
    if family_name in discontinuous_aliases:
        family_name = discontinuous_aliases[family_name]
        discontinuous = True

    cellname = element.cell().cellname()
    family_type = basix.finite_element.string_to_family(family_name, cellname)
    cell_type = basix.cell.string_to_type(cellname)

    EF = basix.ElementFamily
    variant = element.variant()
    if family_type == EF.P and variant == "equispaced":
        # This is used for elements defining cells
        variant_info = (basix.LagrangeVariant.equispaced,)
    elif variant is not None:
        raise ValueError("UFL variants are not supported by FFCx. Please wrap a Basix element directly.")
    elif family_type == EF.P:
        variant_info = (basix.LagrangeVariant.gll_warped,)
    elif family_type in (EF.RT, EF.N1E):
        variant_info = (basix.LagrangeVariant.legendre,)
    elif family_type in (EF.serendipity, EF.BDM, EF.N2E):
        variant_info = (basix.LagrangeVariant.legendre, basix.DPCVariant.legendre)
    elif family_type == EF.DPC:
        variant_info = (basix.DPCVariant.diagonal_gll,)
    else:
        # All remaining families are created without variant arguments
        variant_info = ()

    basix_element = create_basix_element(
        family_type, cell_type, element.degree(), variant_info, discontinuous)
    return BasixElement(basix_element)
def basix_index(*args):
    """Return the Basix index of a derivative.

    Args:
        *args: Positional arguments forwarded unchanged to ``basix.index``

    Returns:
        Whatever ``basix.index`` returns for the given arguments
    """
    derivative_index = basix.index(*args)
    return derivative_index
def create_quadrature(cellname, degree, rule):
    """Create a quadrature rule.

    Args:
        cellname: Name of the reference cell, e.g. ``"vertex"``
        degree: Degree passed on to ``basix.make_quadrature``
        rule: Name of the quadrature rule; converted with
            ``basix.quadrature.string_to_type``

    Returns:
        A pair whose first entry holds the quadrature points and whose
        second entry holds one value per point (the weights). For a
        vertex this is an array of shape ``(1, 0)`` and an array
        containing the single weight 1; for any other cell it is the
        result of ``basix.make_quadrature``.

    Warns:
        UserWarning: If the rule has 100 or more points.
    """
    if cellname == "vertex":
        # A vertex has one zero-dimensional point with unit weight.
        # Arrays (not nested lists) are returned so that callers can use
        # ``.shape`` and ``.size`` exactly as for every other cell type.
        return numpy.ones((1, 0), dtype=numpy.float64), numpy.ones(1, dtype=numpy.float64)

    quadrature = basix.make_quadrature(
        basix.quadrature.string_to_type(rule), basix.cell.string_to_type(cellname), degree)

    # The quadrature degree from UFL can be very high for some
    # integrals. Print a warning if the number of quadrature points
    # is 100 or more.
    num_points = quadrature[1].size
    if num_points >= 100:
        warnings.warn(
            f"Number of integration points per cell is: {num_points}. Consider using 'quadrature_degree' "
            "to reduce number.")
    return quadrature
def reference_cell_vertices(cellname):
    """Return the vertices of the reference cell with the given name.

    Args:
        cellname: Name of the cell; converted with
            ``basix.cell.string_to_type``

    Returns:
        The result of ``basix.geometry`` for that cell type
    """
    cell = basix.cell.string_to_type(cellname)
    return basix.geometry(cell)
def map_facet_points(points, facet, cellname):
    """Map points given on a reference facet onto a facet of a cell.

    Each point is mapped to ``v0 + sum_k (v_k - v0) * p_k`` where
    ``v0, v1, ...`` are the vertices of the chosen facet taken from the
    Basix geometry of the cell and ``p_k`` are the point's coordinates.

    Args:
        points: Iterable of points; each point is an iterable of
            coordinates on the facet
        facet: Index of the facet within the cell
        cellname: Name of the cell

    Returns:
        A list with one mapped point per input point
    """
    cell = basix.cell.string_to_type(cellname)
    geometry = basix.geometry(cell)
    # Entry [-2] of the topology lists the entities one dimension below the cell
    vertex_ids = basix.topology(cell)[-2][facet]
    vertices = [geometry[v] for v in vertex_ids]
    origin = vertices[0]
    others = vertices[1:]

    mapped = []
    for point in points:
        offset = sum((vertex - origin) * coord for vertex, coord in zip(others, point))
        mapped.append(origin + offset)
    return mapped
class BaseElement(ABC):
    """Abstract interface shared by all FFCx element wrappers."""

    @abstractmethod
    def tabulate(self, nderivs: int, points: numpy.ndarray):
        """Evaluate the basis functions of the element at a set of points.

        Args:
            nderivs: Number of derivatives to tabulate.
            points: Points to tabulate at

        Returns:
            Tabulated basis functions
        """

    @abstractmethod
    def get_component_element(self, flat_component: int) -> tuple[BaseElement, int, int]:
        """Get the element for one component, with its offset and stride.

        What is returned depends on the kind of element:

        - MixedElement: the sub-element that represents the given
          component, the offset of that sub-element, and a stride of 1.
        - BlockedElement: the sub-element, an offset equal to the
          component number, and a stride equal to the block size.
        - Vector-valued elements (eg H(curl) and H(div) elements): a
          ComponentElement with an offset of 0 and a stride of 1. When
          tabulate is called on the ComponentElement, only the part of
          the table for the given component is returned.

        Args:
            flat_component: The component

        Returns:
            component element, offset of the component, stride of the component
        """

    @property
    def element_type(self):
        """Element type."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def dim(self) -> int:
        """Number of DOFs the element has."""

    @property
    @abstractmethod
    def value_size(self) -> int:
        """Value size of the element.

        Equal to ``numpy.prod(value_shape)``.
        """

    @property
    @abstractmethod
    def value_shape(self) -> typing.Tuple[int, ...]:
        """Value shape of the element basis function.

        Note:
            For scalar elements, ``(1,)`` is returned. This is different
            from Basix where the value shape for scalar elements is
            ``()``.
        """

    @property
    @abstractmethod
    def num_entity_dofs(self):
        """Number of DOFs associated with each entity."""

    @property
    @abstractmethod
    def entity_dofs(self):
        """DOF numbers associated with each entity."""

    @property
    @abstractmethod
    def num_entity_closure_dofs(self):
        """Number of DOFs associated with the closure of each entity."""

    @property
    @abstractmethod
    def entity_closure_dofs(self):
        """DOF numbers associated with the closure of each entity."""

    @property
    @abstractmethod
    def num_global_support_dofs(self):
        """Number of DOFs with global support."""

    @property
    @abstractmethod
    def reference_topology(self):
        """Topology of the reference element."""

    @property
    @abstractmethod
    def reference_geometry(self):
        """Geometry of the reference element."""

    @property
    @abstractmethod
    def family_name(self) -> str:
        """Family name of the element."""

    @property
    @abstractmethod
    def element_family(self):
        """Basix element family used to initialise the element."""

    @property
    @abstractmethod
    def lagrange_variant(self):
        """Basix Lagrange variant used to initialise the element."""

    @property
    @abstractmethod
    def dpc_variant(self):
        """Basix DPC variant used to initialise the element."""

    @property
    @abstractmethod
    def cell_type(self):
        """Basix cell type used to initialise the element."""

    @property
    @abstractmethod
    def discontinuous(self) -> bool:
        """True if the discontinuous version of the element is used."""

    @property
    def is_custom_element(self) -> bool:
        """True if the element is a custom Basix element.

        The base implementation always reports False; subclasses
        override it where needed.
        """
        return False
class BasixElement(BaseElement):
    """An element defined by Basix.

    Most properties simply forward to the wrapped Basix element, which
    is stored as ``self.element``.
    """

    def __init__(self, element):
        """Wrap an existing Basix element.

        Args:
            element: The Basix element to wrap
        """
        self.element = element

    def tabulate(self, nderivs, points):
        """Tabulate the basis functions and flatten the last two axes.

        The four-axis table from the Basix element has its last two axes
        swapped and then merged into one.
        NOTE(review): presumably these are the DOF and value-component
        axes; confirm against the Basix ``tabulate`` documentation.

        Args:
            nderivs: Number of derivatives to tabulate
            points: Points to tabulate at

        Returns:
            Array with three axes; the first two are those of the Basix table
        """
        raw = self.element.tabulate(nderivs, points)
        leading = raw.shape[:2]
        swapped = raw.transpose((0, 1, 3, 2))
        return swapped.reshape((leading[0], leading[1], -1))

    def get_component_element(self, flat_component):
        """Return a ComponentElement for the component, offset 0 and stride 1.

        Args:
            flat_component: The component; must be less than ``value_size``
        """
        assert flat_component < self.value_size
        component_element = ComponentElement(self, flat_component)
        return component_element, 0, 1

    @property
    def element_type(self) -> str:
        """Element type."""
        return "ufcx_basix_custom_element" if self.is_custom_element else "ufcx_basix_element"

    @property
    def dim(self):
        """Number of DOFs of the wrapped element."""
        return self.element.dim

    @property
    def value_size(self):
        """Value size of the wrapped element."""
        return self.element.value_size

    @property
    def value_shape(self):
        """Get the value shape of the element.

        An empty Basix value shape is reported as ``(1,)``.
        """
        shape = self.element.value_shape
        if len(shape) != 0:
            return shape
        return (1,)

    @property
    def num_entity_dofs(self):
        """Number of DOFs associated with each entity."""
        return self.element.num_entity_dofs

    @property
    def entity_dofs(self):
        """DOF numbers associated with each entity."""
        return self.element.entity_dofs

    @property
    def num_entity_closure_dofs(self):
        """Number of DOFs associated with the closure of each entity."""
        return self.element.num_entity_closure_dofs

    @property
    def entity_closure_dofs(self):
        """DOF numbers associated with the closure of each entity."""
        return self.element.entity_closure_dofs

    @property
    def num_global_support_dofs(self):
        """Number of DOFs with global support (currently always 0)."""
        # TODO
        return 0

    @property
    def family_name(self):
        """Name of the Basix family of the wrapped element."""
        return self.element.family.name

    @property
    def reference_topology(self):
        """Basix topology of the wrapped element's cell type."""
        return basix.topology(self.element.cell_type)

    @property
    def reference_geometry(self):
        """Basix geometry of the wrapped element's cell type."""
        return basix.geometry(self.element.cell_type)

    @property
    def element_family(self):
        """Basix element family of the wrapped element."""
        return self.element.family

    @property
    def lagrange_variant(self):
        """Basix Lagrange variant of the wrapped element."""
        return self.element.lagrange_variant

    @property
    def dpc_variant(self):
        """Basix DPC variant of the wrapped element."""
        return self.element.dpc_variant

    @property
    def cell_type(self):
        """Basix cell type of the wrapped element."""
        return self.element.cell_type

    @property
    def discontinuous(self):
        """True if the wrapped element is the discontinuous version."""
        return self.element.discontinuous

    @property
    def is_custom_element(self) -> bool:
        """True if the element is a custom Basix element."""
        return self.element.family == basix.ElementFamily.custom
class ComponentElement(BaseElement):
    """An element representing one component of a BasixElement.

    Only tabulation and the properties describing how the underlying
    element was created are available; every other property raises
    ``NotImplementedError``.
    """

    def __init__(self, element, component):
        """Store the parent element and the component to extract.

        Args:
            element: The element whose component is represented
            component: Flat index of the component
        """
        self.element = element
        self.component = component

    def tabulate(self, nderivs, points):
        """Tabulate the parent element and keep only this component.

        Args:
            nderivs: Number of derivatives to tabulate
            points: Points to tabulate at

        Returns:
            A list with one table per table returned by the parent element

        Raises:
            NotImplementedError: If the parent's value shape has more
                than two axes.
        """
        tables = self.element.tabulate(nderivs, points)
        value_shape = tuple(self.element.value_shape)
        output = []
        for tbl in tables:
            # Split the flat second axis into the value axes plus a trailing axis
            tbl = tbl.reshape((tbl.shape[0],) + value_shape + (-1,))
            if len(value_shape) == 1:
                output.append(tbl[:, self.component, :])
            elif len(value_shape) == 2:
                # TODO: Something different may need doing here if
                # tensor is symmetric
                # The reshape above is row-major, so the flat component
                # splits into (row, column) using the size of the last
                # value axis. Using the first axis size only agrees with
                # this for square value shapes.
                num_cols = value_shape[1]
                row, col = divmod(self.component, num_cols)
                output.append(tbl[:, row, col, :])
            else:
                raise NotImplementedError
        return output

    def get_component_element(self, flat_component):
        """Return this element (offset 0, stride 1) for component 0.

        Raises:
            NotImplementedError: For any other component.
        """
        if flat_component == 0:
            return self, 0, 1
        raise NotImplementedError

    @property
    def dim(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def value_size(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def value_shape(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def num_entity_dofs(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def entity_dofs(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def num_entity_closure_dofs(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def entity_closure_dofs(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def num_global_support_dofs(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def family_name(self) -> str:
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def reference_topology(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def reference_geometry(self):
        """Not available for a component element."""
        raise NotImplementedError

    @property
    def element_family(self):
        """Basix element family of the parent element."""
        return self.element.element_family

    @property
    def lagrange_variant(self):
        """Basix Lagrange variant of the parent element."""
        return self.element.lagrange_variant

    @property
    def dpc_variant(self):
        """Basix DPC variant of the parent element."""
        return self.element.dpc_variant

    @property
    def cell_type(self):
        """Basix cell type of the parent element."""
        return self.element.cell_type

    @property
    def discontinuous(self):
        """True if the parent element is the discontinuous version."""
        return self.element.discontinuous
class MixedElement(BaseElement):
    """A mixed element that combines two or more elements."""

    def __init__(self, sub_elements):
        """Store the sub-elements.

        Args:
            sub_elements: Non-empty list of FFCx elements
        """
        assert len(sub_elements) > 0
        self.sub_elements = sub_elements

    def tabulate(self, nderivs, points):
        """Tabulate all sub-elements and combine their tables.

        For each derivative a zero table with ``len(points)`` rows and
        ``value_size * dim`` columns is created. Column chunks of width
        ``sub.value_size`` are then copied from each sub-element's table
        into it, with the destination column advancing by the mixed
        element's value size after every chunk.

        Args:
            nderivs: Number of derivatives to tabulate
            points: Points to tabulate at

        Returns:
            A list of combined tables, one per derivative
        """
        per_element = [sub.tabulate(nderivs, points) for sub in self.sub_elements]
        total_value_size = self.value_size
        table_shape = (len(points), total_value_size * self.dim)

        combined = []
        for deriv_tables in zip(*per_element):
            table = numpy.zeros(table_shape)
            column = 0
            for sub, sub_table in zip(self.sub_elements, deriv_tables):
                width = sub.value_size
                for first in range(0, sub.dim, width):
                    table[:, column: column + width] = sub_table[:, first: first + width]
                    column += total_value_size
            combined.append(table)
        return combined

    def get_component_element(self, flat_component):
        """Find the sub-element holding a component and delegate to it.

        Args:
            flat_component: Component index within the mixed element

        Returns:
            The component element and stride reported by the sub-element,
            and the sub-element's offset shifted by the number of DOFs of
            all preceding sub-elements
        """
        dof_starts = numpy.cumsum([0] + [sub.dim for sub in self.sub_elements])
        component_starts = numpy.cumsum([0] + [sub.value_size for sub in self.sub_elements])

        # Find index of sub element which corresponds to the current
        # flat component: the last one whose first component is not
        # beyond flat_component
        not_beyond = numpy.where(component_starts <= flat_component)[0]
        index = len(not_beyond) - 1

        sub = self.sub_elements[index]
        e, offset, stride = sub.get_component_element(flat_component - component_starts[index])
        # TODO: is this offset correct?
        return e, dof_starts[index] + offset, stride

    @staticmethod
    def _sum_entity_counts(per_element):
        """Add up per-entity counts over all sub-elements."""
        template = per_element[0]
        return [[sum(counts[tdim][entity_n] for counts in per_element)
                 for entity_n, _ in enumerate(entities)]
                for tdim, entities in enumerate(template)]

    def _merge_entity_dofs(self, per_element):
        """Concatenate per-entity DOF lists, shifting each sub-element's numbering."""
        merged = [[[] for _ in entities] for entities in per_element[0]]
        first_dof = 0
        for sub, sub_dofs in zip(self.sub_elements, per_element):
            for tdim, entities in enumerate(sub_dofs):
                for entity_n, dofs in enumerate(entities):
                    merged[tdim][entity_n].extend(first_dof + d for d in dofs)
            first_dof += sub.dim
        return merged

    @property
    def element_type(self) -> str:
        """Get the element type."""
        return "ufcx_mixed_element"

    @property
    def dim(self):
        """Total number of DOFs of all sub-elements."""
        return sum(sub.dim for sub in self.sub_elements)

    @property
    def value_size(self):
        """Sum of the value sizes of all sub-elements."""
        return sum(sub.value_size for sub in self.sub_elements)

    @property
    def value_shape(self):
        """One-axis shape whose length is the total value size."""
        return (sum(sub.value_size for sub in self.sub_elements),)

    @property
    def num_entity_dofs(self):
        """Number of DOFs per entity, summed over the sub-elements."""
        return self._sum_entity_counts([sub.num_entity_dofs for sub in self.sub_elements])

    @property
    def entity_dofs(self):
        """DOF numbers per entity, numbered sub-element after sub-element."""
        return self._merge_entity_dofs([sub.entity_dofs for sub in self.sub_elements])

    @property
    def num_entity_closure_dofs(self):
        """Number of closure DOFs per entity, summed over the sub-elements."""
        return self._sum_entity_counts([sub.num_entity_closure_dofs for sub in self.sub_elements])

    @property
    def entity_closure_dofs(self):
        """Closure DOF numbers per entity, numbered sub-element after sub-element."""
        return self._merge_entity_dofs([sub.entity_closure_dofs for sub in self.sub_elements])

    @property
    def num_global_support_dofs(self):
        """Number of DOFs with global support, summed over the sub-elements."""
        return sum(sub.num_global_support_dofs for sub in self.sub_elements)

    @property
    def family_name(self):
        """Family name; always ``"mixed element"``."""
        return "mixed element"

    @property
    def reference_topology(self):
        """Reference topology of the first sub-element."""
        return self.sub_elements[0].reference_topology

    @property
    def reference_geometry(self):
        """Reference geometry of the first sub-element."""
        return self.sub_elements[0].reference_geometry

    @property
    def lagrange_variant(self):
        """Always None for a mixed element."""
        return None

    @property
    def dpc_variant(self):
        """Always None for a mixed element."""
        return None

    @property
    def element_family(self):
        """Always None for a mixed element."""
        return None

    @property
    def cell_type(self):
        """Always None for a mixed element."""
        return None

    @property
    def discontinuous(self):
        """Always False for a mixed element."""
        return False
class BlockedElement(BaseElement):
    """An element with a block size that contains multiple copies of a sub element."""

    def __init__(self, sub_element, block_size, block_shape=None):
        """Create a blocked element from a scalar sub-element.

        Args:
            sub_element: The element to repeat; its value size must be 1
            block_size: Number of copies; must be positive
            block_shape: Shape of the block; ``(block_size,)`` if None

        Raises:
            ValueError: If the sub-element is not scalar-valued.
        """
        assert block_size > 0
        if sub_element.value_size != 1:
            raise ValueError("Blocked elements (VectorElement and TensorElement) of "
                             "non-scalar elements are not supported. Try using MixedElement "
                             "instead.")
        self.sub_element = sub_element
        self.block_size = block_size
        self.block_shape = (block_size,) if block_shape is None else block_shape

    def tabulate(self, nderivs, points):
        """Tabulate the sub-element and spread it over the blocks.

        Each sub-element table with ``n`` columns becomes a table with
        ``n * block_size**2`` columns. For block ``b`` the sub-element
        columns are written to columns ``b * (block_size + 1) +
        k * block_size**2``; all other entries stay zero.

        Args:
            nderivs: Number of derivatives to tabulate
            points: Points to tabulate at

        Returns:
            A list of expanded tables, one per sub-element table
        """
        assert len(self.block_shape) == 1  # TODO: block shape
        assert self.value_size == self.block_size  # TODO: remove this assumption

        bs = self.block_size
        step = bs ** 2
        expanded = []
        for table in self.sub_element.tabulate(nderivs, points):
            width = table.shape[1] * step
            new_table = numpy.zeros((table.shape[0], width))
            for block in range(bs):
                first = block * (bs + 1)
                new_table[:, first: first + width: step] = table
            expanded.append(new_table)
        return expanded

    def get_component_element(self, flat_component):
        """Return the sub-element, the component as offset, and the block size as stride."""
        return self.sub_element, flat_component, self.block_size

    def _scale_counts(self, counts):
        """Multiply every per-entity count by the block size."""
        bs = self.block_size
        return [[count * bs for count in entities] for entities in counts]

    def _expand_dofs(self, entity_dofs):
        """Replace each sub-element DOF ``k`` by ``k * bs + b`` for every block ``b``."""
        bs = self.block_size
        return [[[dof * bs + b for dof in dofs for b in range(bs)]
                 for dofs in entities] for entities in entity_dofs]

    @property
    def element_type(self):
        """Element type."""
        return self.sub_element.element_type

    @property
    def dim(self):
        """Number of DOFs: the sub-element's DOFs times the block size."""
        return self.sub_element.dim * self.block_size

    @property
    def value_size(self):
        """Block size times the sub-element's value size."""
        return self.block_size * self.sub_element.value_size

    @property
    def value_shape(self):
        """One-axis shape whose length is the value size."""
        return (self.value_size,)

    @property
    def num_entity_dofs(self):
        """Number of DOFs per entity, scaled by the block size."""
        return self._scale_counts(self.sub_element.num_entity_dofs)

    @property
    def entity_dofs(self):
        """DOF numbers per entity, expanded over the blocks."""
        # TODO: should this return this, or should it take blocks into
        # account?
        return self._expand_dofs(self.sub_element.entity_dofs)

    @property
    def num_entity_closure_dofs(self):
        """Number of closure DOFs per entity, scaled by the block size."""
        return self._scale_counts(self.sub_element.num_entity_closure_dofs)

    @property
    def entity_closure_dofs(self):
        """Closure DOF numbers per entity, expanded over the blocks."""
        # TODO: should this return this, or should it take blocks into
        # account?
        return self._expand_dofs(self.sub_element.entity_closure_dofs)

    @property
    def num_global_support_dofs(self):
        """Number of global-support DOFs, scaled by the block size."""
        return self.sub_element.num_global_support_dofs * self.block_size

    @property
    def family_name(self):
        """Family name of the sub-element."""
        return self.sub_element.family_name

    @property
    def reference_topology(self):
        """Reference topology of the sub-element."""
        return self.sub_element.reference_topology

    @property
    def reference_geometry(self):
        """Reference geometry of the sub-element."""
        return self.sub_element.reference_geometry

    @property
    def lagrange_variant(self):
        """Basix Lagrange variant of the sub-element."""
        return self.sub_element.lagrange_variant

    @property
    def dpc_variant(self):
        """Basix DPC variant of the sub-element."""
        return self.sub_element.dpc_variant

    @property
    def element_family(self):
        """Basix element family of the sub-element."""
        return self.sub_element.element_family

    @property
    def cell_type(self):
        """Basix cell type of the sub-element."""
        return self.sub_element.cell_type

    @property
    def discontinuous(self):
        """True if the sub-element is the discontinuous version."""
        return self.sub_element.discontinuous
class QuadratureElement(BaseElement):
    """A quadrature element.

    The element has one DOF per quadrature point of the rule selected by
    the UFL element's cell, degree and quadrature scheme.
    """

    def __init__(self, ufl_element):
        """Create the element and its quadrature points.

        Args:
            ufl_element: The UFL quadrature element
        """
        points, _ = create_quadrature(ufl_element.cell().cellname(),
                                      ufl_element.degree(), ufl_element.quadrature_scheme())
        # Normalise to an array so that ``.shape`` is available in
        # ``tabulate`` and ``dim`` whatever sequence type was returned
        self._points = numpy.asarray(points)
        self._ufl_element = ufl_element

    def tabulate(self, nderivs, points):
        """Tabulate the element at its own quadrature points.

        Args:
            nderivs: Number of derivatives; must be 0
            points: Points to tabulate at; must have the same shape as
                the element's quadrature points

        Returns:
            A list holding a single identity matrix with one row per point

        Raises:
            ValueError: If derivatives are requested or the shape of
                ``points`` differs from that of the element's points.
        """
        if nderivs > 0:
            raise ValueError("Cannot take derivatives of Quadrature element.")
        if points.shape != self._points.shape:
            raise ValueError("Mismatch of tabulation points and element points.")
        num_points = points.shape[0]
        return [numpy.eye(num_points, num_points)]

    def get_component_element(self, flat_component):
        """Return this element with offset 0 and stride 1."""
        return self, 0, 1

    @property
    def element_type(self) -> str:
        """Element type."""
        return "ufcx_quadrature_element"

    @property
    def dim(self):
        """Number of DOFs, equal to the number of quadrature points."""
        return self._points.shape[0]

    @property
    def value_size(self):
        """Value size; always 1."""
        return 1

    @property
    def value_shape(self):
        """Value shape; always ``(1,)``."""
        return (1,)

    @property
    def num_entity_dofs(self):
        """Number of DOFs associated with each entity.

        Every vertex, edge and facet gets zero DOFs; all DOFs belong to
        the single entity in the last entry of the list.
        """
        cell = self._ufl_element.cell()
        tdim = cell.topological_dimension()
        dofs = []
        if tdim >= 1:
            dofs.append([0] * cell.num_vertices())
        if tdim >= 2:
            dofs.append([0] * cell.num_edges())
        if tdim >= 3:
            dofs.append([0] * cell.num_facets())
        dofs.append([self.dim])
        return dofs

    @property
    def entity_dofs(self):
        """DOF numbers associated with each entity, numbered consecutively."""
        next_dof = 0
        entity_dofs = []
        for counts in self.num_entity_dofs:
            dofs_list = []
            for count in counts:
                dofs_list.append(list(range(next_dof, next_dof + count)))
                next_dof += count
            entity_dofs.append(dofs_list)
        return entity_dofs

    @property
    def num_entity_closure_dofs(self):
        """Same as ``num_entity_dofs``."""
        return self.num_entity_dofs

    @property
    def entity_closure_dofs(self):
        """Same as ``entity_dofs``."""
        return self.entity_dofs

    @property
    def num_global_support_dofs(self):
        """Number of DOFs with global support; always 0."""
        return 0

    @property
    def reference_topology(self):
        """Not available for a quadrature element."""
        raise NotImplementedError

    @property
    def reference_geometry(self):
        """Not available for a quadrature element."""
        raise NotImplementedError

    @property
    def family_name(self):
        """Family name reported by the UFL element."""
        return self._ufl_element.family()

    @property
    def lagrange_variant(self):
        """Always None for a quadrature element."""
        return None

    @property
    def dpc_variant(self):
        """Always None for a quadrature element."""
        return None

    @property
    def element_family(self):
        """Always None for a quadrature element."""
        return None

    @property
    def cell_type(self) -> None:
        """Always None for a quadrature element."""
        return None

    @property
    def discontinuous(self):
        """Always False for a quadrature element."""
        return False