# Source code for ffcx.ir.integral

# Copyright (C) 2013-2020 Martin Sandve Alnæs and Michal Habera
#
# This file is part of FFCx.(https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later
"""Main algorithm for building the integral intermediate representation."""

import collections
import itertools
import logging
import typing

import numpy as np
import ufl
from ufl.algorithms.balancing import balance_modifiers
from ufl.checks import is_cellwise_constant
from ufl.classes import QuadratureWeight

from ffcx.ir.analysis.factorization import compute_argument_factorization
from ffcx.ir.analysis.graph import build_scalar_graph
from ffcx.ir.analysis.modified_terminals import analyse_modified_terminal, is_modified_terminal
from ffcx.ir.analysis.visualise import visualise_graph
from ffcx.ir.elementtables import UniqueTableReferenceT, build_optimized_tables

logger = logging.getLogger("ffcx")


class ModifiedArgumentDataT(typing.NamedTuple):
    """Modified argument data.

    Pairs the graph node index of a modified argument with the element table
    reference attached to that node. ``compute_integral_ir`` builds these
    positionally, so the field order below is part of the interface.
    """

    ma_index: int  # index of the modified-argument node in the factorised graph
    tabledata: UniqueTableReferenceT  # table reference stored on that node under "tr"


class BlockDataT(typing.NamedTuple):
    """Block data.

    Describes one factorised contribution to an argument block of an
    integral. Instances may be constructed positionally, so the declaration
    order of the fields is part of the interface.
    """

    ttypes: tuple[str, ...]  # table type for each block rank
    factor_indices_comp_indices: list[tuple[int, int]]  # list of (factor index, component index)
    all_factors_piecewise: bool  # True if all factors for this block are piecewise
    unames: tuple[str, ...]  # unique FE table name for each block rank
    restrictions: tuple[typing.Optional[str], ...]  # restriction "+" | "-" | None for each block rank
    transposed: bool  # block is the transpose of another
    is_uniform: bool  # True if the argument table of every block rank is uniform
    ma_data: tuple[ModifiedArgumentDataT, ...]  # used in "full", "safe" and "partial"
    is_permuted: bool  # Do quad points on facets need to be permuted?


def _referenced_table_names(tr):
    """Return the set of table names that the table reference ``tr`` reads from.

    A tensor-factorised reference is backed by its factor tables rather than by
    a table of its own, so the factor names are returned in that case.
    """
    if tr.has_tensor_factorisation:
        return {t.name for t in tr.tensor_factors}
    return {tr.name}


def compute_integral_ir(cell, integral_type, entity_type, integrands, argument_shape, p, visualise):
    """Compute intermediate representation for an integral.

    Args:
        cell: Cell of the integration domain. Its ``topological_dimension()`` is
            compared with that of the integrand's domains; it is also forwarded
            to ``build_optimized_tables``.
        integral_type: Integral type, forwarded to ``build_optimized_tables``.
        entity_type: Entity type, forwarded to ``build_optimized_tables``.
        integrands: Mapping from quadrature rule to the UFL integrand that is
            integrated with that rule.
        argument_shape: Sequence whose length is the number of form arguments
            (the rank used for argument factorisation).
        p: Parameter dict; the keys ``"sum_factorization"``, ``"table_rtol"``
            and ``"table_atol"`` are read.
        visualise: If true, the scalar graph and the factorised graph of each
            integrand are written to ``S.pdf`` and ``F.pdf``.

    Returns:
        A dict with keys:

        - ``"unique_tables"`` / ``"unique_table_types"``: table values and
          table types shared by all quadrature rules (zero and one tables are
          not stored);
        - ``"integrand"``: per quadrature rule, a dict holding the
          ``"factorization"`` graph, the ``"modified_arguments"`` and the
          ``"block_contributions"``;
        - ``"needs_facet_permutations"``: True if *any* quadrature rule has
          terminals restricted to both ``"+"`` and ``"-"``, or is
          mixed-dimensional.
    """
    # The intermediate representation dict we're building and returning here
    ir = {}

    # Shared unique tables for all quadrature loops
    ir["unique_tables"] = {}
    ir["unique_table_types"] = {}

    ir["integrand"] = {}

    # If any quadrature rule needs facet permutations, the whole integral does,
    # so this flag is OR-ed over all rules instead of taken from the last one.
    ir["needs_facet_permutations"] = False

    for quadrature_rule, integrand in integrands.items():
        # Rebalance order of nested terminal modifiers
        expression = balance_modifiers(integrand)

        # Remove QuadratureWeight terminals from expression and replace with 1.0
        expression = replace_quadratureweight(expression)

        # Build initial scalar list-based graph representation
        S = build_scalar_graph(expression)

        # Build terminal_data from V here before factorization. Then we
        # can use it to derive table properties for all modified
        # terminals, and then use that to rebuild the scalar graph more
        # efficiently before argument factorization. We can build
        # terminal_data again after factorization if that's necessary.
        initial_terminals = {
            i: analyse_modified_terminal(v["expression"])
            for i, v in S.nodes.items()
            if is_modified_terminal(v["expression"])
        }

        # The integral is mixed-dimensional if any of its domains has a
        # topological dimension different from the cell's
        is_mixed_dim = any(
            domain.topological_dimension() != cell.topological_dimension()
            for domain in ufl.domain.extract_domains(integrand)
        )

        mt_table_reference = build_optimized_tables(
            quadrature_rule,
            cell,
            integral_type,
            entity_type,
            initial_terminals.values(),
            ir["unique_tables"],
            use_sum_factorization=p["sum_factorization"],
            is_mixed_dim=is_mixed_dim,
            rtol=p["table_rtol"],
            atol=p["table_atol"],
        )

        # Fetch unique tables for this quadrature rule
        table_types = {v.name: v.ttype for v in mt_table_reference.values()}
        tables = {v.name: v.values for v in mt_table_reference.values()}

        if "zeros" in table_types.values():
            # If there are any 'zero' tables, replace symbolically and rebuild graph
            S_targets = [i for i, v in S.nodes.items() if v.get("target", False)]
            num_components = int(np.prod(expression.ufl_shape))

            # Set modified terminals with zero tables to zero
            for i, mt in initial_terminals.items():
                tr = mt_table_reference.get(mt)
                if tr is not None and tr.ttype == "zeros":
                    S.nodes[i]["expression"] = ufl.as_ufl(0.0)

            # Propagate expression changes using dependency list.
            # NOTE(review): a single pass assumes S.nodes iterates operands
            # before the nodes that use them - confirm against build_scalar_graph.
            for i, v in S.nodes.items():
                deps = [S.nodes[j]["expression"] for j in S.out_edges[i]]
                if deps:
                    v["expression"] = v["expression"]._ufl_expr_reconstruct_(*deps)

            # Recreate expression with correct ufl_shape
            expressions = [None] * num_components
            for target in S_targets:
                for comp in S.nodes[target]["component"]:
                    assert expressions[comp] is None
                    expressions[comp] = S.nodes[target]["expression"]
            expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape))

            # Rebuild scalar list-based graph representation
            S = build_scalar_graph(expression)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(S, "S.pdf")

        # Compute factorization of arguments
        rank = len(argument_shape)
        F = compute_argument_factorization(S, rank)

        # Map each tuple of argument node indices to the list of
        # (factor index, component index) pairs that multiply it
        FV_targets = [i for i, v in F.nodes.items() if v.get("target", False)]
        argument_factorization = {}
        for fi in FV_targets:
            targets = F.nodes[fi]["target"]
            components = F.nodes[fi]["component"]
            # Number of blocks using this factor must agree with number of components
            # to which this factor contributes. I.e. there are more blocks iff there are
            # more components
            assert len(targets) == len(components)
            for w, comp in zip(targets, components):
                argument_factorization.setdefault(w, []).append((fi, comp))

        # Indices in F which are the arguments (should be at start); sorted so
        # the order of "modified_arguments" does not depend on set iteration order
        argkeys = sorted(set().union(*argument_factorization))

        # Attach the modified terminal ('mt') and, where one exists, its table
        # reference ('tr') to every modified-terminal node of F
        for v in F.nodes.values():
            expr = v["expression"]
            if is_modified_terminal(expr):
                mt = analyse_modified_terminal(expr)
                v["mt"] = mt
                tr = mt_table_reference.get(mt)
                if tr is not None:
                    v["tr"] = tr

        # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying'
        analyse_dependencies(F, mt_table_reference)

        # Output diagnostic graph as pdf
        if visualise:
            visualise_graph(F, "F.pdf")

        # Loop over factorization terms
        block_contributions = collections.defaultdict(list)
        for ma_indices, fi_ci in sorted(argument_factorization.items()):
            # Get a bunch of information about this term
            assert rank == len(ma_indices)
            trs = tuple(F.nodes[ai]["tr"] for ai in ma_indices)

            unames = tuple(tr.name for tr in trs)
            ttypes = tuple(tr.ttype for tr in trs)
            assert "zeros" not in ttypes

            # One dofmap per block rank: tr.offset + i * tr.block_size for each
            # of the tr.values.shape[3] table columns
            blockmap = tuple(
                tuple(tr.offset + i * tr.block_size for i in range(tr.values.shape[3]))
                for tr in trs
            )

            # Collect relevant restrictions to identify blocks correctly
            # in interior facet integrals; uniform tables carry no restriction
            block_restrictions = tuple(
                None if tr.is_uniform else F.nodes[ai]["mt"].restriction
                for ai, tr in zip(ma_indices, trs)
            )

            # Check if *each* factor corresponding to this argument is piecewise
            all_factors_piecewise = all(
                F.nodes[factor_index]["status"] == "piecewise" for factor_index, _ in fi_ci
            )

            # The block is permuted if any of its tables has more than one
            # entry along axis 0
            block_is_permuted = any(tables[name].shape[0] > 1 for name in unames)

            blockdata = BlockDataT(
                ttypes=ttypes,
                factor_indices_comp_indices=fi_ci,
                all_factors_piecewise=all_factors_piecewise,
                unames=unames,
                restrictions=block_restrictions,
                transposed=False,  # FIXME: Handle transposes for these block types
                is_uniform=all(tr.is_uniform for tr in trs),
                ma_data=tuple(ModifiedArgumentDataT(ma, tr) for ma, tr in zip(ma_indices, trs)),
                is_permuted=block_is_permuted,
            )

            # Insert in expr_ir for this quadrature loop
            block_contributions[blockmap].append(blockdata)

        # Figure out which table names are referenced by active nodes ...
        active_table_names = set()
        for v in F.nodes.values():
            tr = v.get("tr")
            if tr is not None and v["status"] != "inactive":
                active_table_names |= _referenced_table_names(tr)

        # ... and by the arguments of every block
        for contributions in block_contributions.values():
            for blockdata in contributions:
                for mad in blockdata.ma_data:
                    active_table_names |= _referenced_table_names(mad.tabledata)

        # Zero and one tables are not stored (sorted for a deterministic order)
        active_table_types = {
            name: table_types[name]
            for name in sorted(active_table_names)
            if table_types[name] not in ("zeros", "ones")
        }
        active_tables = {name: tables[name] for name in active_table_types}

        # Add tables and types for this quadrature rule to global tables dict
        ir["unique_tables"].update(active_tables)
        ir["unique_table_types"].update(active_table_types)

        # Store final ir for this quadrature rule
        ir["integrand"][quadrature_rule] = {
            "factorization": F,
            "modified_arguments": [F.nodes[i]["mt"] for i in argkeys],
            "block_contributions": block_contributions,
        }

        # Facet permutations are needed when terminals are restricted to both
        # sides of a facet, or when the integral is mixed-dimensional
        restrictions = {mt.restriction for mt in initial_terminals.values()}
        if ("+" in restrictions and "-" in restrictions) or is_mixed_dim:
            ir["needs_facet_permutations"] = True

    return ir
def analyse_dependencies(F, mt_unique_table_reference):
    """Analyse dependencies.

    Sets 'status' of all nodes to either: 'inactive', 'piecewise' or 'varying'.

    Children of 'target' nodes are either 'piecewise' or 'varying'.
    All other nodes are 'inactive'.
    Varying nodes are identified by their tables ('tr'). All their parent
    nodes are also set to 'varying' - any remaining active nodes are 'piecewise'.

    Args:
        F: Graph whose ``F.nodes`` maps node index to an attribute dict, and
            whose ``F.out_edges[i]`` / ``F.in_edges[i]`` give the nodes that
            node ``i`` depends on / that depend on node ``i``.
        mt_unique_table_reference: Not used; table references are read from
            each node's ``"tr"`` entry instead. Kept for interface compatibility.

    Raises:
        RuntimeError: If a node's table has an unknown ttype, or if a modified
            terminal without a table is not cellwise constant.
    """
    # Start with every node inactive, then mark everything reachable from a
    # 'target' node by following out_edges as active
    for v in F.nodes.values():
        v["status"] = "inactive"

    stack = [i for i, v in F.nodes.items() if v.get("target")]
    while stack:
        s = stack.pop()
        F.nodes[s]["status"] = "active"
        stack.extend(j for j in F.out_edges[s] if F.nodes[j]["status"] == "inactive")

    # Build piecewise/varying markers for factorized_vertices
    varying_ttypes = ("varying", "quadrature", "uniform")
    non_varying_ttypes = ("fixed", "piecewise", "ones", "zeros")
    varying_indices = []
    for i, v in F.nodes.items():
        if v.get("mt") is None:
            continue
        tr = v.get("tr")
        if tr is not None:
            # Check if table computations have revealed values varying over points
            if tr.ttype in varying_ttypes:
                varying_indices.append(i)
            elif tr.ttype not in non_varying_ttypes:
                raise RuntimeError(f"Invalid ttype {tr.ttype}.")
        elif not is_cellwise_constant(v["expression"]):
            # Kept as a safety net; it is unclear which cases (if any) reach it.
            raise RuntimeError(
                f"Modified terminal {v['expression']} has no table "
                "but is not cellwise constant."
            )

    # Set all parents of active varying nodes to 'varying'
    while varying_indices:
        s = varying_indices.pop()
        if F.nodes[s]["status"] == "active":
            F.nodes[s]["status"] = "varying"
            varying_indices.extend(F.in_edges[s])

    # Any remaining active nodes must be 'piecewise'
    for v in F.nodes.values():
        if v["status"] == "active":
            v["status"] = "piecewise"


def replace_quadratureweight(expression):
    """Remove any QuadratureWeight terminals and replace with 1.0."""
    replace_map = {
        node: 1.0
        for node in ufl.corealg.traversal.unique_pre_traversal(expression)
        if is_modified_terminal(node) and isinstance(node, QuadratureWeight)
    }
    return ufl.algorithms.replace(expression, replace_map)