Source code for ufl.algorithms.replace_derivative_nodes
"""Algorithm for replacing derivative nodes in a BaseForm or Expr."""
import ufl
from ufl.algorithms.analysis import extract_arguments
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.constantvalue import as_ufl
from ufl.corealg.multifunction import MultiFunction
from ufl.tensors import ListTensor
[docs]
class DerivativeNodeReplacer(MultiFunction):
    """Replace derivative nodes with new derivative nodes.

    Each ``coefficient_derivative`` node encountered during traversal is
    rebuilt with :func:`ufl.derivative`, after substituting the coefficients
    it differentiates with respect to according to ``mapping``. Every other
    node goes through ``MultiFunction.reuse_if_untouched``.
    """

    def __init__(self, mapping, **derivative_kwargs):
        """Initialise.

        Args:
            mapping: A dict mapping the coefficients of existing derivative
                nodes to their replacements.
            derivative_kwargs: Keyword arguments forwarded to
                :func:`ufl.derivative` for every rebuilt node
                (e.g. ``argument``, ``coefficient_derivatives``).
        """
        super().__init__()
        self.mapping = mapping
        self.der_kwargs = derivative_kwargs

    expr = MultiFunction.reuse_if_untouched

    def coefficient_derivative(self, cd, o, coefficients, arguments, coefficient_derivatives):
        """Apply to coefficient_derivative.

        Args:
            cd: The node being visited (unused here).
            o: The (already processed) expression being differentiated.
            coefficients: Operand whose ``ufl_operands`` are the coefficients
                the derivative is taken with respect to.
            arguments: Operand whose ``ufl_operands`` are the arguments of the
                original derivative.
            coefficient_derivatives: Operand of the original node; not
                forwarded. Only a ``coefficient_derivatives`` entry supplied
                through the constructor kwargs reaches :func:`ufl.derivative`.

        Returns:
            ``ufl.derivative(o, new_coefficients, **kwargs)``, where
            ``new_coefficients`` are the mapped coefficients.
        """
        # Copy per node: updating ``self.der_kwargs`` in place would make the
        # arguments inferred for the first derivative node visited leak into
        # every subsequent node handled by this (shared) replacer.
        der_kwargs = dict(self.der_kwargs)
        new_coefficients = tuple(self.mapping.get(c, c) for c in coefficients.ufl_operands)

        # Ensure type compatibility for arguments!
        if "argument" not in der_kwargs:
            # Argument's number/part can be retrieved from the former coefficient derivative.
            new_arguments = []
            for c, a in zip(new_coefficients, arguments.ufl_operands):
                if isinstance(a, ListTensor):
                    # Recover the single Argument wrapped in the ListTensor;
                    # unpacking raises ValueError if there is not exactly one.
                    (a,) = extract_arguments(a)
                # Same Argument type/number/part, but on the function space of
                # the replacement coefficient.
                new_arguments.append(type(a)(c.ufl_function_space(), a.number(), a.part()))
            der_kwargs["argument"] = tuple(new_arguments)

        return ufl.derivative(o, new_coefficients, **der_kwargs)
[docs]
def replace_derivative_nodes(expr, mapping, **derivative_kwargs):
    """Replace derivative nodes.

    Replace the variable with respect to which the derivative is taken.
    This is called during apply_derivatives to treat delayed derivatives.

    Example: Let u be a Coefficient, N an ExternalOperator independent
    of u (i.e. N's operands don't depend on u),
    and let uhat and Nhat be Arguments::

        F = u ** 2 * N * dx
        dFdu = derivative(F, u, uhat)
        dFdN = replace_derivative_nodes(dFdu, {u: N}, argument=Nhat)

    Then, by subsequently expanding the derivatives we have::

        dFdu -> 2 * u * uhat * N * dx
        dFdN -> u ** 2 * Nhat * dx

    Args:
        expr: An Expr or BaseForm.
        mapping: A dict with from:to replacements to perform. Values are
            converted with ``as_ufl``.
        derivative_kwargs: Keyword arguments for ``derivative``
            (i.e. ``argument`` and ``coefficient_derivatives``).

    Returns:
        The result of ``map_integrand_dags`` applying a
        ``DerivativeNodeReplacer`` to ``expr``.
    """
    # Normalise replacement values to UFL objects (e.g. Python scalars).
    ufl_mapping = {k: as_ufl(v) for k, v in mapping.items()}
    return map_integrand_dags(DerivativeNodeReplacer(ufl_mapping, **derivative_kwargs), expr)