Matrix-free solvers in DOLFINx using PETSc#

Author: Jørgen S. Dokken

This demo can be downloaded as a single Python file demo_matrix-free-petsc.py. In this demo, we will demonstrate how to set up a matrix-free solver using PETSc. We will start by defining our variational problem, and then in turn define a custom PETSc SHELL matrix that will handle assembly without ever forming the PETSc MATAIJ system matrix.

Problem definition#

In this example, we consider a projection problem: find \((u_h, p_h) \in V_h \times Q_h\) that solves

\[ \begin{align} \min_{u_h, p_h} J(u_h, p_h) &= \frac{1}{2} \int_\Omega \vert u_h - f\vert^2~\mathrm{d}x + \frac{1}{2} \int_\Omega \vert p_h - g\vert^2~\mathrm{d}x \end{align} \]

By considering the optimality conditions of this minimization problem, we arrive at the variational problem: find \((u_h, p_h) \in V_h \times Q_h\) such that

\[ \begin{align} \int_\Omega (u_h-f) \cdot v~\mathrm{d}x + \int_\Omega (p_h-g) q ~\mathrm{d}x &= 0 \quad \forall (v, q) \in V_h \times Q_h \end{align} \]
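
To make the step from the minimization problem to the variational problem explicit, here is a short derivation sketch. It takes both misfit terms with the weight \(\frac{1}{2}\); since \(v\) and \(q\) can be varied independently, a different positive weight on either term would not change the minimizer. Differentiating \(J\) at \((u_h, p_h)\) in the direction \((v, q)\) gives

\[ \begin{align} \frac{\mathrm{d}}{\mathrm{d}\epsilon} J(u_h + \epsilon v, p_h + \epsilon q) \Big\vert_{\epsilon = 0} &= \int_\Omega (u_h - f) \cdot v~\mathrm{d}x + \int_\Omega (p_h - g) q~\mathrm{d}x \end{align} \]

Requiring this derivative to vanish for every direction \((v, q) \in V_h \times Q_h\) is exactly the variational problem stated above.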

We start by importing the necessary modules

from mpi4py import MPI
from petsc4py import PETSc

import numpy as np

import basix.ufl
import dolfinx.fem.petsc
import ufl

Matrix-free operator#

Many iterative methods, such as the conjugate gradient method, only require the action of the system matrix on a vector. Thus, one can assemble a form of rank 1 with a given function replacing the trial function and obtain this vector.

We do this using ufl.action to create a rank 1 form. Additionally, some preconditioners, such as the Jacobi preconditioner, need access to the diagonal of the matrix. This can be obtained by assembling the bilinear form with an option that computes only the diagonal, form_compiler_options={"part": "diagonal"}, which is passed to FFCx when calling dolfinx.fem.form.
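
As a rough illustration of what "assembling the action" and "assembling only the diagonal" mean, the following pure-Python sketch does both for a 1D mass matrix with linear elements on a uniform mesh of the unit interval. It is not how DOLFINx or FFCx implement these steps; all function names are made up for this example, and the dense matrix is built only as a reference to compare against.

```python
def local_mass_matrix(h):
    """Mass matrix of one linear (P1) element of length h."""
    return [[h / 3.0, h / 6.0], [h / 6.0, h / 3.0]]


def assemble_action(num_cells, w):
    """Return A @ w by summing cell contributions; A is never stored."""
    m = local_mass_matrix(1.0 / num_cells)
    y = [0.0] * (num_cells + 1)
    for c in range(num_cells):
        dofs = (c, c + 1)  # the two vertices of cell c
        for i in range(2):
            for j in range(2):
                y[dofs[i]] += m[i][j] * w[dofs[j]]
    return y


def assemble_diagonal(num_cells):
    """Return diag(A) by accumulating only the local diagonal entries."""
    m = local_mass_matrix(1.0 / num_cells)
    d = [0.0] * (num_cells + 1)
    for c in range(num_cells):
        d[c] += m[0][0]
        d[c + 1] += m[1][1]
    return d


def assemble_dense(num_cells):
    """Reference only: the global matrix that the two routines above avoid."""
    m = local_mass_matrix(1.0 / num_cells)
    n = num_cells + 1
    A = [[0.0] * n for _ in range(n)]
    for c in range(num_cells):
        for i in range(2):
            for j in range(2):
                A[c + i][c + j] += m[i][j]
    return A


num_cells = 4
w = [1.0, -2.0, 0.5, 3.0, 0.0]
action = assemble_action(num_cells, w)
diagonal = assemble_diagonal(num_cells)

# Compare against the explicitly assembled matrix
A = assemble_dense(num_cells)
dense_action = [sum(A[i][j] * w[j] for j in range(len(w))) for i in range(len(w))]
dense_diagonal = [A[i][i] for i in range(len(w))]
```

Both matrix-free routines need storage proportional to the number of degrees of freedom, whereas the reference matrix also stores every coupling between them.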

For the assembly itself, we require an operator that can compute the action of the matrix on a vector, as well as provide the diagonal. We provide this class below:

class MatrixFreeOperator:
    """Matrix-free operator for a bilinear form.

    This class provides an operator that takes a bilinear form and provides
    the action of the bilinear form on a vector, without explicitly
    assembling the matrix.
    """

    # Data allocation for the operator
    _w: dolfinx.fem.Function | list[dolfinx.fem.Function]  # Store working solution
    _diagonal: PETSc.Vec  # Temporary storage of diagonal
    _vector: PETSc.Vec  # Temporary storage of action

    _vector_product: dolfinx.fem.Form | list[dolfinx.fem.Form]  # Compiled matrix-vector product
    _compiled_diagonal: dolfinx.fem.Form | list[dolfinx.fem.Form]  # Compiled diagonal form

    def __init__(
        self,
        bilinear_form: ufl.Form,
        bcs: list[dolfinx.fem.DirichletBC] | None = None,
        form_compiler_options: dict | None = None,
        jit_options: dict | None = None,
    ):
        """A matrix-free operator for a bilinear form.

        Args:
            bilinear_form: The bilinear form.
            bcs: A list of Dirichlet boundary conditions.
            form_compiler_options: Options to pass to the form compiler.
            jit_options: Options to pass to the JIT compiler.
        """
        jit_options = {} if jit_options is None else jit_options
        form_compiler_options = {} if form_compiler_options is None else form_compiler_options
        diagnal_options = form_compiler_options.copy()
        diagnal_options["part"] = "diagonal"

        # Store the boundary conditions
        self._bcs = [] if bcs is None else bcs

        # Use the number of arguments in the bilinear form to decide if we
        # have a mixed function space
        arguments = bilinear_form.arguments()
        if len(arguments) > 2:
            # Handle MixedFunctionSpace forms
            size = max(arg.part() for arg in arguments) + 1
            assert max(arg.number() for arg in arguments) == 1
            a_blocked = ufl.extract_blocks(bilinear_form)
            assert len(a_blocked) == size
            spaces = [a_blocked[i][i].arguments()[0].ufl_function_space() for i in range(size)]

            self._w = [dolfinx.fem.Function(space) for space in spaces]
            self._diagonal = dolfinx.fem.petsc.create_vector(spaces)
            self._vector = dolfinx.fem.petsc.create_vector(spaces)
            self._vector_product = dolfinx.fem.form(
                ufl.extract_blocks(ufl.action(bilinear_form, self._w)),
                form_compiler_options=form_compiler_options,
                jit_options=jit_options,
            )
            self._compiled_diagonal = dolfinx.fem.form(
                [a_blocked[i][i] for i in range(size)],
                form_compiler_options=diagnal_options,
                jit_options=jit_options,
            )
        else:
            # Handle "standard" bilinear forms
            assert len(arguments) == 2, "Only bilinear forms are supported"
            self._w = dolfinx.fem.Function(bilinear_form.arguments()[-1].ufl_function_space())
            self._diagonal = dolfinx.fem.petsc.create_vector(self._w.function_space)
            self._vector = dolfinx.fem.petsc.create_vector(self._w.function_space)
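            # Pass the compiler options here as well, as in the blocked branch
            self._vector_product = dolfinx.fem.form(
                ufl.action(bilinear_form, self._w),
                form_compiler_options=form_compiler_options,
                jit_options=jit_options,
            )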
            self._compiled_diagonal = dolfinx.fem.form(
                bilinear_form,
                form_compiler_options=diagnal_options,
                jit_options=jit_options,
            )

    def mult(self, mat, X, Y):
        """Compute Y = A * X, where A is the bilinear form.

        Note:
            This method never assembles the full matrix A.

        Args:
            mat: The PETSc matrix (not used).
            X: The input vector.
            Y: The output vector.

        """
        # Move data into local working array

        dolfinx.fem.petsc.assign(X, self._w)

        # Zero out any input from Dirichlet BCs
        if isinstance(self._compiled_diagonal, dolfinx.fem.Form):
            bcs0 = self._bcs
            for bc in self._bcs:
                di = bc.dof_indices()
                odi = di[0][: di[1]]
                self._w.x.array[odi] = 0
            self._w.x.scatter_forward()

        else:
            bcs0 = dolfinx.fem.bcs_by_block(
                dolfinx.fem.extract_function_spaces(self._compiled_diagonal), self._bcs
            )
            for i, bcs in enumerate(bcs0):
                for bc in bcs:
                    di = bc.dof_indices()
                    odi = di[0]
                    self._w[i].x.array[odi] = 0
                self._w[i].x.scatter_forward()

        # Assemble action
        with self._vector.localForm() as loc:
            loc.set(0)

        dolfinx.fem.petsc.assemble_vector(self._vector, self._vector_product)
        dolfinx.la.petsc._ghost_update(
            self._vector, PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE
        )

        # Insert X at Dirichlet dofs
        if isinstance(self._compiled_diagonal, dolfinx.fem.Form):
            bcs0 = self._bcs
            for bc in self._bcs:
                di = bc.dof_indices()
                odi = di[0][: di[1]]
                self._vector.array_w[odi] = X.array_r[odi]
        else:
            bcs0 = dolfinx.fem.bcs_by_block(
                dolfinx.fem.extract_function_spaces(self._compiled_diagonal), self._bcs
            )
            offset0, _ = self._vector.getAttr("_blocks")
            for bcs, off0, off1 in zip(bcs0, offset0[:-1], offset0[1:], strict=True):  # type: ignore[assignment]
                v_array = self._vector.array_w[off0:off1]
                x_array = X.array_r[off0:off1]
                for bc in bcs:
                    di = bc.dof_indices()
                    odi = di[0][: di[1]]
                    v_array[odi] = x_array[odi]
        dolfinx.la.petsc._ghost_update(
            self._vector, PETSc.InsertMode.INSERT, PETSc.ScatterMode.FORWARD
        )
        Y.setArray(self._vector)
        Y.ghostUpdate(PETSc.InsertMode.INSERT, PETSc.ScatterMode.FORWARD)

    def getDiagonal(self, mat, vec):
        """Compute the diagonal of the bilinear form.

        Note:
            This is required for Jacobi preconditioning.

        Args:
            mat: The PETSc matrix (not used).
            vec: The output vector to store the diagonal.
        """
        # NOTE: Only have to go through a DOLFINx vector due to a
        # bug in PETSc, similar to:
        # https://gitlab.com/petsc/petsc/-/issues/1645
        with self._diagonal.localForm() as loc:
            loc.set(0)
        dolfinx.fem.petsc.assemble_vector(self._diagonal, self._compiled_diagonal)
        self._diagonal.ghostUpdate(PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE)
        if isinstance(self._compiled_diagonal, dolfinx.fem.Form):
            for bc in self._bcs:
                di = bc.dof_indices()
                odi = di[0][: di[1]]
                self._diagonal.array_w[odi] = 1
        else:
            bcs0 = dolfinx.fem.bcs_by_block(
                dolfinx.fem.extract_function_spaces(self._compiled_diagonal), self._bcs
            )
            offset0, _ = self._diagonal.getAttr("_blocks")
            for bcs, off0 in zip(bcs0, offset0[:-1], strict=True):  # type: ignore[assignment]
                for bc in bcs:
                    di = bc.dof_indices()
                    odi = di[0][: di[1]]
                    self._diagonal.array_w[off0 + odi] = 1
        self._diagonal.ghostUpdate(PETSc.InsertMode.INSERT, PETSc.ScatterMode.FORWARD)
        vec.setArray(self._diagonal)
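
The Dirichlet handling in mult follows three steps: zero the input at constrained dofs, apply the unconstrained action, then copy the input back at constrained dofs. The small sketch below, with a hypothetical 4x4 matrix and hypothetical helper names, suggests why this corresponds to an operator whose constrained rows and columns are replaced by the identity, which is also why getDiagonal places 1 at those dofs.

```python
def apply_constrained(A, bc_dofs, x):
    """Mimic the three steps of mult: zero, multiply, re-insert."""
    n = len(x)
    w = list(x)
    for i in bc_dofs:
        w[i] = 0.0  # zero the input at constrained dofs
    y = [sum(A[i][j] * w[j] for j in range(n)) for i in range(n)]
    for i in bc_dofs:
        y[i] = x[i]  # act as the identity on constrained dofs
    return y


def constrained_matrix(A, bc_dofs):
    """Explicit matrix with constrained rows and columns set to identity."""
    n = len(A)
    At = [row[:] for row in A]
    for i in bc_dofs:
        for j in range(n):
            At[i][j] = 0.0
            At[j][i] = 0.0
        At[i][i] = 1.0
    return At


A = [
    [2.0, -1.0, 0.0, 0.0],
    [-1.0, 2.0, -1.0, 0.0],
    [0.0, -1.0, 2.0, -1.0],
    [0.0, 0.0, -1.0, 2.0],
]
bc_dofs = [0, 3]
x = [5.0, 1.0, -2.0, 7.0]

y_free = apply_constrained(A, bc_dofs, x)
At = constrained_matrix(A, bc_dofs)
y_dense = [sum(At[i][j] * x[j] for j in range(4)) for i in range(4)]
```

Because rows and columns are treated alike, the constrained operator stays symmetric, which the conjugate gradient method relies on.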

Setting up a Krylov subspace solver with the matrix-free operator#

As we will solve the problem below with different representations of the bilinear form, we provide a convenience function for attaching the matrix-free operator to a PETSc KSP object.

def attach_matrix_free_operator(
    ksp: PETSc.KSP,
    bilinear_form: ufl.Form,
    bcs: list[dolfinx.fem.DirichletBC] | None = None,
):
    """Attach a matrix-free operator to a PETSc KSP object.

    Args:
        ksp: The PETSc KSP object.
        bilinear_form: The bilinear form.
        bcs: A list of Dirichlet boundary conditions.
    """
    # Create the operator; mixed function spaces are detected inside the class
    operator = MatrixFreeOperator(bilinear_form, bcs=bcs)

    A = PETSc.Mat().create(ksp.getComm().tompi4py())
    sizes = operator._diagonal.getSizes()
    A.setSizes([sizes, sizes])

    A.setType("python")
    A.setPythonContext(operator)
    A.assemble()
    ksp.setOperators(A)

We are ready to solve our problem. In this demo we will consider two approaches: using basix.ufl.mixed_element and using a ufl.MixedFunctionSpace.

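We start by defining our mesh and finite element spaces. The code below refers to them as mesh, el_0 (the element for \(u_h\), vector-valued to match f) and el_1 (the element for \(p_h\), scalar to match g).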

We define the analytical solutions f and g

x = ufl.SpatialCoordinate(mesh)
f = ufl.as_vector((ufl.cos(2 * x[0]) * ufl.sin(x[1]), x[1]))
g = ufl.sin(3 * x[0]) * ufl.cos(4 * x[1])

Next, we create a general function to extract the bilinear and linear forms from the weak formulation.

def extract_system(
    W: dolfinx.fem.FunctionSpace | ufl.MixedFunctionSpace,
    f: ufl.core.expr.Expr,
    g: ufl.core.expr.Expr,
) -> tuple[ufl.Form, ufl.Form]:
    """Extract the bilinear and linear forms."""
    u_h, p_h = ufl.TrialFunctions(W)
    v, q = ufl.TestFunctions(W)
    residual = ufl.inner(u_h - f, v) * ufl.dx + ufl.inner(p_h - g, q) * ufl.dx
    return ufl.system(residual)
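
For reference, splitting the residual by hand into the terms that contain a trial function and the terms that do not gives the two forms that ufl.system is expected to return here. This is a sketch in the notation of the problem definition, not output produced by UFL:

\[ \begin{align} a((u_h, p_h), (v, q)) &= \int_\Omega u_h \cdot v~\mathrm{d}x + \int_\Omega p_h q~\mathrm{d}x \\ L((v, q)) &= \int_\Omega f \cdot v~\mathrm{d}x + \int_\Omega g q~\mathrm{d}x \end{align} \]

The bilinear form consists of two uncoupled mass matrices, so it is symmetric and positive definite, which is what makes the conjugate gradient method with a Jacobi preconditioner a reasonable choice.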

We also define a convenience function for creating the Krylov subspace solver and attaching the matrix-free operator.

def create_matrix_free_ksp(
    a: ufl.Form,
    bcs: list[dolfinx.fem.DirichletBC],
    prefix: str,
) -> PETSc.KSP:
    """Create a KSP solver and attach a matrix-free operator."""
    comm = a.ufl_domain().ufl_cargo().comm
    ksp = PETSc.KSP().create(comm)
    attach_matrix_free_operator(ksp, a, bcs=bcs)

    ksp.setMonitor(
        lambda _, its, rnorm: PETSc.Sys.Print(f"{prefix} Iter: {its}, residual norm: {rnorm:.5e}")
    )
    ksp.setType("cg")
    pc = ksp.getPC()
    pc.setType("jacobi")
    dtype = dolfinx.default_scalar_type
    tol = 1e-10 if dtype == np.float64 else 1e-6
    ksp.setTolerances(atol=tol, rtol=tol, max_it=300)
    ksp.setErrorIfNotConverged(False)
    return ksp
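
To show why an action and a diagonal are all that this solver configuration needs, here is a minimal pure-Python sketch of a Jacobi-preconditioned conjugate gradient iteration. It is not PETSc's implementation: the stopping test on the unpreconditioned residual norm, the zero initial guess, and the names jacobi_pcg and mass_action are assumptions made for this example. The test operator is an unscaled 1D mass-matrix stencil.

```python
import math


def dot(a, b):
    """Euclidean inner product of two lists."""
    return sum(ai * bi for ai, bi in zip(a, b))


def jacobi_pcg(mult, diagonal, b, rtol=1e-10, atol=1e-10, max_it=300):
    """Jacobi-preconditioned CG using only an action and a diagonal."""
    x = [0.0] * len(b)
    r = list(b)  # residual for the zero initial guess
    z = [ri / di for ri, di in zip(r, diagonal)]  # apply the preconditioner
    p = list(z)
    rz = dot(r, z)
    rnorm0 = math.sqrt(dot(r, r))
    for it in range(max_it):
        rnorm = math.sqrt(dot(r, r))
        if rnorm <= max(rtol * rnorm0, atol):
            return x, it
        Ap = mult(p)  # the only place the operator is used
        alpha = rz / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        z = [ri / di for ri, di in zip(r, diagonal)]
        rz_new = dot(r, z)
        p = [zi + (rz_new / rz) * pi for zi, pi in zip(z, p)]
        rz = rz_new
    return x, max_it


def mass_action(v):
    """Action of an unscaled 1D P1 mass matrix, applied as a stencil."""
    n = len(v)
    y = [0.0] * n
    for i in range(n):
        y[i] = (2.0 if i in (0, n - 1) else 4.0) * v[i]
        if i > 0:
            y[i] += v[i - 1]
        if i < n - 1:
            y[i] += v[i + 1]
    return y


n = 20
mass_diagonal = [2.0 if i in (0, n - 1) else 4.0 for i in range(n)]
x_exact = [math.sin(0.3 * i) for i in range(n)]
b = mass_action(x_exact)
x, its = jacobi_pcg(mass_action, mass_diagonal, b)
```

The two callbacks play the roles of mult and getDiagonal in the operator class above; the solver itself never sees a matrix.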

Mixed-element approach#

We start by defining the function spaces using a mixed element from basix.ufl.mixed_element.

def mixed_element(
    mesh: dolfinx.mesh.Mesh, f: ufl.core.expr.Expr, g: ufl.core.expr.Expr
) -> tuple[dolfinx.fem.Function, dolfinx.fem.Function]:
    """Blocked problem using a {py:class}`basix.ufl.mixed_element`."""
    # Define function space for mixed element and extract subspaces
    W = dolfinx.fem.functionspace(mesh, basix.ufl.mixed_element([el_0, el_1]))
    V, _ = W.sub(0).collapse()
    Q, _ = W.sub(1).collapse()

    # Extract bilinear and linear forms
    a, L = extract_system(W, f, g)

    # Set up Dirichlet boundary conditions using the analytical solution.
    u_bc_expr = dolfinx.fem.Expression(f, V.element.interpolation_points)
    u_bc = dolfinx.fem.Function(V)
    u_bc.interpolate(u_bc_expr)
    mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim)
    bc_facets = dolfinx.mesh.exterior_facet_indices(mesh.topology)
    bc_dofs_u = dolfinx.fem.locate_dofs_topological((W.sub(0), V), mesh.topology.dim - 1, bc_facets)
    p_bc_expr = dolfinx.fem.Expression(g, Q.element.interpolation_points)
    p_bc = dolfinx.fem.Function(Q)
    p_bc.interpolate(p_bc_expr)
    bc_dofs_p = dolfinx.fem.locate_dofs_topological((W.sub(1), Q), mesh.topology.dim - 1, bc_facets)
    bcs = [
        dolfinx.fem.dirichletbc(u_bc, bc_dofs_u, W.sub(0)),
        dolfinx.fem.dirichletbc(p_bc, bc_dofs_p, W.sub(1)),
    ]

    # Assemble RHS with boundary conditions
    b = dolfinx.fem.petsc.assemble_vector(dolfinx.fem.form(L))
    dolfinx.fem.petsc.apply_lifting(b, [dolfinx.fem.form(a)], [bcs])
    b.ghostUpdate(PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE)
    dolfinx.fem.petsc.set_bc(b, bcs)

    # Setup matrix free KSP
    ksp = create_matrix_free_ksp(a, bcs, "MixedElement")

    # Solve the system
    wh = dolfinx.fem.Function(W)
    ksp.solve(b, wh.x.petsc_vec)
    wh.x.scatter_forward()

    # Extract subspace solutions
    uh = wh.sub(0).collapse()
    ph = wh.sub(1).collapse()
    return uh, ph
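
The right-hand side treatment above (apply_lifting followed by set_bc) can be sketched on a small dense system. The example below is an illustration only: it uses a hypothetical 4x4 matrix stored explicitly for brevity, whereas DOLFINx performs the lifting cell by cell, and the helper names are made up. It checks that the modified right-hand side is consistent with an operator that acts as the identity on constrained dofs.

```python
def lift_and_set_bc(A, b, bc_values):
    """Return b - A g with the prescribed values inserted at constrained dofs.

    bc_values maps a constrained dof index to its prescribed value.
    """
    n = len(b)
    g = [bc_values.get(j, 0.0) for j in range(n)]  # zero away from the boundary
    b_mod = [b[i] - sum(A[i][j] * g[j] for j in range(n)) for i in range(n)]
    for i, value in bc_values.items():
        b_mod[i] = value
    return b_mod


def apply_constrained(A, bc_values, x):
    """Operator with identity rows and columns at constrained dofs."""
    n = len(x)
    w = [0.0 if j in bc_values else x[j] for j in range(n)]
    y = [sum(A[i][j] * w[j] for j in range(n)) for i in range(n)]
    for i in bc_values:
        y[i] = x[i]
    return y


A = [
    [2.0, -1.0, 0.0, 0.0],
    [-1.0, 2.0, -1.0, 0.0],
    [0.0, -1.0, 2.0, -1.0],
    [0.0, 0.0, -1.0, 2.0],
]
bc_values = {0: 1.0, 3: -2.0}
x_exact = [1.0, 0.5, 3.0, -2.0]  # matches the boundary values at dofs 0 and 3

# Unconstrained right-hand side that x_exact satisfies
b = [sum(A[i][j] * x_exact[j] for j in range(4)) for i in range(4)]
b_mod = lift_and_set_bc(A, b, bc_values)
```

With this modified right-hand side, the exact solution satisfies the constrained system, so the matrix-free operator and the lifted vector describe the same problem.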

MixedFunctionSpace approach#

We can also define the function space using ufl.MixedFunctionSpace. This approach is more efficient when there are many subspaces with little cross coupling. It is also more flexible, as each subspace can be defined on a different mesh, such as submeshes of codimension 0 and 1.

def mixed_function_space(
    mesh: dolfinx.mesh.Mesh, f: ufl.core.expr.Expr, g: ufl.core.expr.Expr
) -> tuple[dolfinx.fem.Function, dolfinx.fem.Function]:
    """Blocked problem using a {py:class}`ufl.MixedFunctionSpace`."""
    # Create mixed function space from two function spaces
    V = dolfinx.fem.functionspace(mesh, el_0)
    Q = dolfinx.fem.functionspace(mesh, el_1)
    W = ufl.MixedFunctionSpace(V, Q)

    # Extract bilinear and linear forms
    a, L = extract_system(W, f, g)

    # Create Dirichlet boundary conditions using the analytical solution.
    u_bc_expr = dolfinx.fem.Expression(f, V.element.interpolation_points)
    u_bc = dolfinx.fem.Function(V)
    u_bc.interpolate(u_bc_expr)
    mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim)
    bc_facets = dolfinx.mesh.exterior_facet_indices(mesh.topology)
    bc_dofs_u = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, bc_facets)
    p_bc_expr = dolfinx.fem.Expression(g, Q.element.interpolation_points)
    p_bc = dolfinx.fem.Function(Q)
    p_bc.interpolate(p_bc_expr)
    bc_dofs_p = dolfinx.fem.locate_dofs_topological(Q, mesh.topology.dim - 1, bc_facets)
    bcs = [
        dolfinx.fem.dirichletbc(u_bc, bc_dofs_u),
        dolfinx.fem.dirichletbc(p_bc, bc_dofs_p),
    ]

    # Compile forms and assemble the RHS with boundary conditions.
    # The lifting operation never assembles the full matrix A; instead it
    # assembles the local product A_local g_local, where g_local is the
    # local representation of the Dirichlet data.

    L_compiled = dolfinx.fem.form(ufl.extract_blocks(L))
    a_compiled = dolfinx.fem.form(ufl.extract_blocks(a))
    b = dolfinx.fem.petsc.assemble_vector(L_compiled)
    bcs0 = dolfinx.fem.bcs_by_block(dolfinx.fem.extract_function_spaces(L_compiled), bcs)
    dolfinx.fem.petsc.apply_lifting(b, a_compiled, bcs0)
    b.ghostUpdate(PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE)
    dolfinx.fem.petsc.set_bc(b, bcs0)
    b.ghostUpdate(PETSc.InsertMode.INSERT, PETSc.ScatterMode.FORWARD)

    # We define the matrix free KSP and solve the linear system
    ksp = create_matrix_free_ksp(a, bcs, "MixedFunctionSpace")
    wh = b.duplicate()
    ksp.solve(b, wh)
    wh.ghostUpdate(PETSc.InsertMode.INSERT, PETSc.ScatterMode.FORWARD)

    # Assign solution to DOLFINx functions
    uh = dolfinx.fem.Function(V)
    ph = dolfinx.fem.Function(Q)
    dolfinx.fem.petsc.assign(wh, [uh, ph])
    return uh, ph
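
In the blocked approach, one vector holds the degrees of freedom of all subspaces, and the operator class above slices it with per-block offsets. The following serial sketch of that bookkeeping is a simplification: it ignores ghost entries, the block sizes are hypothetical, and the helper names are not part of DOLFINx.

```python
from itertools import accumulate


def block_offsets(block_sizes):
    """Cumulative offsets: block i occupies [offsets[i], offsets[i + 1])."""
    return [0, *accumulate(block_sizes)]


def split_blocks(vector, block_sizes):
    """Split a blocked vector into one list per subspace."""
    offsets = block_offsets(block_sizes)
    return [vector[o0:o1] for o0, o1 in zip(offsets[:-1], offsets[1:])]


def to_blocked_index(block, local_dofs, block_sizes):
    """Map dof indices local to one block to indices in the blocked vector."""
    offset = block_offsets(block_sizes)[block]
    return [offset + dof for dof in local_dofs]


block_sizes = [6, 4]  # hypothetical number of entries for V and Q
wh = [float(i) for i in range(sum(block_sizes))]

uh_values, ph_values = split_blocks(wh, block_sizes)
bc_rows = to_blocked_index(1, [0, 3], block_sizes)
```

The same offset shift appears in getDiagonal, where boundary dofs of each block are moved by the start of that block before the diagonal entry is set to 1.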

Checking solution accuracy#

We can now solve the problem using both approaches and compare the solution accuracy. We compute the L2-error between the numerical solution and the analytical solution.

u_me, p_me = mixed_element(mesh, f, g)

u_mfs, p_mfs = mixed_function_space(mesh, f, g)


def compute_L2_error(uh: ufl.core.expr.Expr, u_ex: ufl.core.expr.Expr) -> float:
    """Compute the L2-error between two expressions.

    Args:
        uh: The approximate solution.
        u_ex: The exact solution.
    """
    error = ufl.inner(uh - u_ex, uh - u_ex) * ufl.dx
    error = dolfinx.fem.assemble_scalar(dolfinx.fem.form(error))
    return np.sqrt(mesh.comm.allreduce(error, op=MPI.SUM))


error_u_me = compute_L2_error(u_me, f)
error_p_me = compute_L2_error(p_me, g)

PETSc.Sys.Print("Mixed element:")
PETSc.Sys.Print(f"L2 error (u): {error_u_me:.5e}")
PETSc.Sys.Print(f"L2 error (p): {error_p_me:.5e}")

error_u_mfs = compute_L2_error(u_mfs, f)
error_p_mfs = compute_L2_error(p_mfs, g)

PETSc.Sys.Print("Mixed function space:")
PETSc.Sys.Print(f"L2 error (u): {error_u_mfs:.5e}")
PETSc.Sys.Print(f"L2 error (p): {error_p_mfs:.5e}")
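
The order of operations in compute_L2_error matters in parallel: each process integrates the squared error over its own cells, the partial sums are added across processes, and the square root is taken last. The pure-Python sketch below imitates this on the unit interval; the two-way partition stands in for two MPI ranks, and all names are hypothetical.

```python
import math


def local_squared_error(uh, u_ex, cells):
    """Integrate (uh - u_ex)^2 over the given cells with 2-point Gauss quadrature."""
    gauss_points = (-1.0 / math.sqrt(3.0), 1.0 / math.sqrt(3.0))  # weights are 1
    total = 0.0
    for a, b in cells:
        mid, half = 0.5 * (a + b), 0.5 * (b - a)
        for xi in gauss_points:
            x = mid + half * xi
            total += half * (uh(x) - u_ex(x)) ** 2
    return total


def uh(x):
    """Stand-in for the numerical solution."""
    return x


def u_ex(x):
    """Stand-in for the exact solution."""
    return 0.0


cells = [(i / 8, (i + 1) / 8) for i in range(8)]
partition = [cells[:3], cells[3:]]  # hypothetical split over two ranks

local_sums = [local_squared_error(uh, u_ex, owned) for owned in partition]
l2_error = math.sqrt(sum(local_sums))  # sum first (the allreduce), then sqrt
```

Taking the square root on each process before summing would give a different, larger number, which is why the reduction is applied to the squared error.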