Matrix-free conjugate gradient solver for the Poisson equation
This demo solves the Poisson equation with a matrix-free conjugate gradient (CG) solver. In particular, it illustrates how to:
Solve a linear partial differential equation using a matrix-free CG solver
Create and apply Dirichlet boundary conditions
Compute the approximation error with respect to a known exact solution
Note
This demo illustrates the use of a matrix-free conjugate gradient solver. Many practical problems will also require a preconditioner to create an efficient solver. This is not covered here.
Problem definition
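For a domain \(\Omega \subset \mathbb{R}^n\) with boundary \(\partial \Omega\), the Poisson equation with Dirichlet boundary conditions reads:
\[
\begin{aligned}
-\nabla^{2} u &= f && \text{in } \Omega, \\
u &= u_{\rm D} && \text{on } \partial\Omega.
\end{aligned}
\]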
The variational problem reads: given a suitable function space \(V\) satisfying the essential boundary condition (\(u = u_{\rm D}\) on \(\partial\Omega\)) and its homogenised counterpart \(V_0\), find \(u \in V\) such that
\[
a(u, v) = L(v) \quad \forall v \in V_0,
\]
where the bilinear and linear forms are
\[
a(u, v) = \int_{\Omega} \nabla u \cdot \nabla v \, {\rm d}x, \qquad L(v) = \int_{\Omega} f v \, {\rm d}x,
\]
respectively. In this demo we select:
\(\Omega = [0,1] \times [0,1]\) (a square)
\(u_{\rm D} = 1 + x^2 + 2y^2\)
\(f = -6\)
The function \(u_{\rm D}\) is also the exact solution of this problem.
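As a quick sanity check that \(u_{\rm D}\) really satisfies \(-\nabla^2 u_{\rm D} = f\), the sketch below applies the five-point finite-difference Laplacian to \(u_{\rm D}\) on a uniform grid. It uses plain NumPy rather than DOLFINx, and the grid size is an arbitrary choice. The five-point stencil is exact for quadratic polynomials, so the only discrepancy should be round-off.

```python
import numpy as np


def u_exact(x, y):
    """Exact solution u_D = 1 + x^2 + 2y^2."""
    return 1.0 + x**2 + 2.0 * y**2


# Uniform grid on the unit square (spacing chosen arbitrarily)
n = 11
h = 1.0 / (n - 1)
xs = np.linspace(0.0, 1.0, n)
X, Y = np.meshgrid(xs, xs, indexing="ij")
U = u_exact(X, Y)

# Five-point Laplacian at the interior grid points (exact for quadratics)
lap = (
    U[2:, 1:-1] + U[:-2, 1:-1] + U[1:-1, 2:] + U[1:-1, :-2] - 4.0 * U[1:-1, 1:-1]
) / h**2

f = -6.0
residual = -lap - f  # should vanish up to round-off
print(np.max(np.abs(residual)))
```

Analytically, \(\nabla^2 u_{\rm D} = 2 + 4 = 6\), which is consistent with the choice \(f = -6\).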
Implementation
The modules that will be used are imported:
from mpi4py import MPI
import numpy as np
import dolfinx
import ufl
from dolfinx import fem, la
from ufl import action, dx, grad, inner
We begin by using create_rectangle to create a rectangular Mesh of the domain, and then create a finite element FunctionSpace on the mesh.
dtype = dolfinx.default_scalar_type
real_type = np.real(dtype(0.0)).dtype
comm = MPI.COMM_WORLD
mesh = dolfinx.mesh.create_rectangle(comm, [[0.0, 0.0], [1.0, 1.0]], [10, 10], dtype=real_type)
# Create function space
degree = 2
V = fem.functionspace(mesh, ("Lagrange", degree))
The second argument to functionspace is a tuple (family, degree), where family is the finite element family and degree specifies the polynomial degree. In this case, V consists of second-order (quadratic), continuous Lagrange finite element functions.
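To make the degree concrete, here is a plain-Python count (not a DOLFINx call) of the global degrees of freedom of a continuous degree-2 Lagrange space on an \(N \times N\) grid of squares. The count assumes each square is split into two triangles by one diagonal, which is our assumption about the mesh layout. Degree-2 Lagrange elements place one dof at each vertex and one at each edge midpoint.

```python
def p2_dof_count(n: int) -> int:
    """Number of P2 Lagrange dofs on an n x n grid of squares, each split into two triangles."""
    n_vertices = (n + 1) ** 2
    # n * (n + 1) horizontal edges, n * (n + 1) vertical edges, one diagonal per square
    n_edges = 2 * n * (n + 1) + n * n
    # Quadratic Lagrange: one dof per vertex plus one dof per edge
    return n_vertices + n_edges


for n in (10, 20, 40):
    print(n, p2_dof_count(n), (2 * n + 1) ** 2)
```

For the 10 x 10 mesh this gives 121 vertex dofs plus 320 edge dofs, 441 in total. This equals \((2N+1)^2\), the number of points in a lattice with spacing \(h/2\).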
Next, we locate the mesh facets that lie on the domain boundary \(\partial\Omega\). We first call create_connectivity to compute the facet-to-cell connectivity, and then retrieve all facets on the boundary using exterior_facet_indices.
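The idea behind exterior_facet_indices can be sketched in serial Python. Once the facet-to-cell connectivity is known, a facet lies on the exterior boundary if it is attached to exactly one cell. The sketch builds its own triangulation of an \(N \times N\) square; the vertex numbering and helper names are ours, not DOLFINx's. It also ignores the ghost-cell bookkeeping that DOLFINx must perform in parallel.

```python
from collections import defaultdict


def triangulate_square(n):
    """Split an n x n grid of squares into two triangles each (vertex index i * (n + 1) + j)."""
    def vid(i, j):
        return i * (n + 1) + j

    cells = []
    for i in range(n):
        for j in range(n):
            v0, v1, v2, v3 = vid(i, j), vid(i + 1, j), vid(i, j + 1), vid(i + 1, j + 1)
            cells += [(v0, v1, v3), (v0, v3, v2)]
    return cells


def exterior_facets(cells):
    """Return the edges (the facets of a triangle mesh) that belong to exactly one cell."""
    facet_to_cells = defaultdict(list)  # analogous to the (tdim - 1) -> tdim connectivity
    for c, (a, b, d) in enumerate(cells):
        for edge in ((a, b), (b, d), (a, d)):
            facet_to_cells[tuple(sorted(edge))].append(c)
    return sorted(f for f, attached in facet_to_cells.items() if len(attached) == 1)


cells = triangulate_square(10)
boundary = exterior_facets(cells)
print(len(cells), len(boundary))
```

On a 10 x 10 grid this finds 40 boundary edges, 10 per side of the square.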
tdim = mesh.topology.dim
mesh.topology.create_connectivity(tdim - 1, tdim)
facets = dolfinx.mesh.exterior_facet_indices(mesh.topology)
We now find the degrees of freedom associated with the boundary facets using locate_dofs_topological:
dofs = fem.locate_dofs_topological(V=V, entity_dim=tdim - 1, entities=facets)
We then use dirichletbc to define the essential boundary condition. On the boundary we prescribe the Function uD, which we create by interpolating the expression \(u_{\rm D}\) into the finite element space \(V\).
uD = fem.Function(V, dtype=dtype)
uD.interpolate(lambda x: 1 + x[0] ** 2 + 2 * x[1] ** 2)
bc = fem.dirichletbc(value=uD, dofs=dofs)
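Conceptually, locating dofs topologically means collecting every dof attached to the boundary facets, including the dofs on the facets' vertices. For degree-2 Lagrange on triangles, each boundary edge contributes its two vertex dofs and one edge (midpoint) dof, and the boundary condition then fixes the values at exactly these dofs. The serial sketch below uses a hypothetical dof numbering (vertex dofs first, then one dof per edge). This numbering is our own choice and is not DOLFINx's dof layout.

```python
from collections import defaultdict


def triangulate_square(n):
    """Split an n x n grid of squares into two triangles each (vertex index i * (n + 1) + j)."""
    def vid(i, j):
        return i * (n + 1) + j

    cells = []
    for i in range(n):
        for j in range(n):
            v0, v1, v2, v3 = vid(i, j), vid(i + 1, j), vid(i, j + 1), vid(i + 1, j + 1)
            cells += [(v0, v1, v3), (v0, v3, v2)]
    return cells


def p2_boundary_dofs(cells, n_vertices):
    """Boundary dofs under a hypothetical P2 numbering: vertex v -> v, edge k -> n_vertices + k."""
    facet_to_cells = defaultdict(list)
    for c, (a, b, d) in enumerate(cells):
        for edge in ((a, b), (b, d), (a, d)):
            facet_to_cells[tuple(sorted(edge))].append(c)
    edge_number = {e: k for k, e in enumerate(sorted(facet_to_cells))}

    dofs = set()
    for e, attached in facet_to_cells.items():
        if len(attached) == 1:  # exterior facet
            dofs.update(e)  # the two vertex dofs of the facet
            dofs.add(n_vertices + edge_number[e])  # the midpoint dof on the facet itself
    return sorted(dofs)


n = 10
cells = triangulate_square(n)
bdofs = p2_boundary_dofs(cells, (n + 1) ** 2)
print(len(bdofs))
```

Shared corner vertices are counted once because the dofs are gathered in a set. This gives \(4N\) vertex dofs plus \(4N\) midpoint dofs, 80 in total for \(N = 10\).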
Next, we express the variational problem using UFL.
x = ufl.SpatialCoordinate(mesh)
u = ufl.TrialFunction(V)
v = ufl.TestFunction(V)
f = fem.Constant(mesh, dtype(-6.0))
a = inner(grad(u), grad(v)) * dx
L = inner(f, v) * dx
L_fem = fem.form(L, dtype=dtype)
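To see what the forms \(a\) and \(L\) become after assembly, here is a 1D analogue with linear (degree-1) elements on a uniform mesh of [0, 1], written in NumPy rather than DOLFINx. The term \(\int u' v' \, {\rm d}x\) contributes the local stiffness matrix \(\frac{1}{h}\begin{pmatrix}1 & -1\\-1 & 1\end{pmatrix}\) on each cell. With constant \(f\), the term \(\int f v \, {\rm d}x\) contributes the local load vector \(\frac{f h}{2}(1, 1)^T\). The helper name assemble_1d is hypothetical.

```python
import numpy as np


def assemble_1d(n_cells: int, f: float):
    """Assemble the P1 stiffness matrix A and load vector b for -u'' = f on [0, 1]."""
    h = 1.0 / n_cells
    n_dofs = n_cells + 1
    A = np.zeros((n_dofs, n_dofs))
    b = np.zeros(n_dofs)
    A_local = np.array([[1.0, -1.0], [-1.0, 1.0]]) / h  # from grad(u) . grad(v) on one cell
    b_local = np.full(2, f * h / 2.0)  # from f * v on one cell, constant f
    for c in range(n_cells):
        dofs = [c, c + 1]  # cell-to-dof map for P1 in 1D
        A[np.ix_(dofs, dofs)] += A_local
        b[dofs] += b_local
    return A, b


A, b = assemble_1d(4, -2.0)
print(A)
print(b)
```

The assembled matrix is symmetric with zero row sums, because constants lie in the kernel of the operator before boundary conditions are applied. The load vector sums to \(f\) times the length of the domain.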
For the matrix-free solver we also define a second linear form M as the action of the bilinear form \(a\) on an arbitrary Function ui:
ui = fem.Function(V, dtype=dtype)
M = action(a, ui)
M_fem = fem.form(M, dtype=dtype)
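Assembling the linear form M = action(a, ui) produces the vector \(A u_i\) without ever storing \(A\). Each cell multiplies its local matrix by its local coefficients and adds the result into the global vector. The NumPy sketch below uses 1D degree-1 elements and hypothetical helper names. It checks that this cell-by-cell action agrees with multiplication by an assembled matrix.

```python
import numpy as np


def action_1d(u: np.ndarray) -> np.ndarray:
    """Matrix-free y = A u for the 1D P1 Laplacian on [0, 1], assembled cell by cell."""
    n_cells = u.size - 1
    h = 1.0 / n_cells
    A_local = np.array([[1.0, -1.0], [-1.0, 1.0]]) / h
    y = np.zeros_like(u)
    for c in range(n_cells):
        dofs = [c, c + 1]
        y[dofs] += A_local @ u[dofs]  # local action, accumulated into the global vector
    return y


def assembled_matrix_1d(n_cells: int) -> np.ndarray:
    """Reference: the assembled stiffness matrix, used here only for comparison."""
    h = 1.0 / n_cells
    A = np.zeros((n_cells + 1, n_cells + 1))
    for c in range(n_cells):
        A[c:c + 2, c:c + 2] += np.array([[1.0, -1.0], [-1.0, 1.0]]) / h
    return A


rng = np.random.default_rng(0)
u = rng.standard_normal(9)
print(np.allclose(action_1d(u), assembled_matrix_1d(8) @ u))
```

The matrix-free variant needs only the local matrices, or their on-the-fly evaluation. The price is that every application of the operator repeats the cell loop.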
Matrix-free conjugate gradient solver
The right-hand side vector \(b - A x_{\rm bc}\) is obtained by assembling the linear form \(L\) and imposing the essential Dirichlet boundary conditions by lifting. Since we want to avoid assembling the matrix \(A\), we compute the required matrix-vector product explicitly through the linear form M, with ui set to \(-x_{\rm bc}\). Finally, the right-hand side entries at the constrained dofs are set to zero.
# Apply lifting: b <- b - A * x_bc
b = fem.assemble_vector(L_fem)
ui.x.array[:] = 0.0
bc.set(ui.x.array, alpha=-1.0)
fem.assemble_vector(b.array, M_fem)
b.scatter_reverse(la.InsertMode.add)
# Set BC dofs to zero on right hand side
bc.set(b.array, alpha=0.0)
b.scatter_forward()
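The same lifting steps can be written with an explicit matrix, which is useful for checking the logic on a small problem. This NumPy sketch uses our own 1D test problem, not part of the demo: \(-u'' = -2\) on [0, 1] with \(u(0) = 1\), \(u(1) = 2\), whose exact solution is \(1 + x^2\), discretised with degree-1 elements. For this 1D problem, the nodal values of the discrete solution coincide with the exact solution. Here x_bc holds the boundary values and zeros elsewhere, and bc_dofs plays the role of the constrained dofs.

```python
import numpy as np

n_cells = 8
h = 1.0 / n_cells
x = np.linspace(0.0, 1.0, n_cells + 1)
f = -2.0

# Assemble the 1D P1 stiffness matrix and load vector (dense, for illustration only)
A = np.zeros((n_cells + 1, n_cells + 1))
b = np.zeros(n_cells + 1)
for c in range(n_cells):
    A[c:c + 2, c:c + 2] += np.array([[1.0, -1.0], [-1.0, 1.0]]) / h
    b[c:c + 2] += f * h / 2.0

bc_dofs = np.array([0, n_cells])
x_bc = np.zeros_like(b)
x_bc[bc_dofs] = [1.0, 2.0]  # boundary values of 1 + x^2

# Lifting: b <- b - A x_bc, then zero the constrained entries
b -= A @ x_bc
b[bc_dofs] = 0.0

# Solve on the free dofs only, then insert the boundary values
free = np.setdiff1d(np.arange(n_cells + 1), bc_dofs)
u = np.zeros_like(b)
u[free] = np.linalg.solve(A[np.ix_(free, free)], b[free])
u[bc_dofs] = x_bc[bc_dofs]
print(np.max(np.abs(u - (1.0 + x**2))))
```

The demo performs the same steps, except that the product \(A x_{\rm bc}\) comes from assembling M instead of from a stored matrix.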
To implement the matrix-free CG solver using DOLFINx vectors, we define the
function action_A to compute the matrix-vector product \(y = A x\).
def action_A(x, y):
    # Set coefficient vector of the linear form M and ensure it is updated
    # across processes
    ui.x.array[:] = x.array
    ui.x.scatter_forward()

    # Compute action of A on ui using the linear form M
    y.array[:] = 0.0
    fem.assemble_vector(y.array, M_fem)
    y.scatter_reverse(la.InsertMode.add)

    # Set BC dofs to zero
    bc.set(y.array, alpha=0.0)
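Consider vectors that are already zero at the boundary dofs. This holds for the CG search directions here, because the lifted right-hand side is zero there. For such vectors, zeroing the constrained entries of y means that action_A applies the stiffness matrix restricted to the free dofs, which is symmetric positive definite. The NumPy sketch below mimics the in-place signature action_A(x, y) with a closure over a 1D stencil. The helper make_action is hypothetical, and this is an illustration, not DOLFINx code.

```python
import numpy as np


def make_action(n_cells: int, bc_dofs):
    """Return a hypothetical in-place operator action(x, y) for the 1D P1 Laplacian."""
    h = 1.0 / n_cells

    def action(x: np.ndarray, y: np.ndarray) -> None:
        y[:] = 0.0
        d = x[1:] - x[:-1]  # difference across each cell
        y[:-1] -= d / h  # contribution to the left node of each cell
        y[1:] += d / h  # contribution to the right node of each cell
        y[bc_dofs] = 0.0  # constrained rows are zeroed, as in action_A

    return action


n_cells = 8
bc_dofs = [0, n_cells]
free = np.arange(1, n_cells)
action = make_action(n_cells, bc_dofs)

# Reference: the dense stiffness matrix restricted to the free dofs
h = 1.0 / n_cells
A = np.zeros((n_cells + 1, n_cells + 1))
for c in range(n_cells):
    A[c:c + 2, c:c + 2] += np.array([[1.0, -1.0], [-1.0, 1.0]]) / h
A_ff = A[np.ix_(free, free)]

p = np.zeros(n_cells + 1)
p[free] = np.random.default_rng(1).standard_normal(free.size)  # zero at the boundary
y = np.empty_like(p)
action(p, y)
print(np.allclose(y[free], A_ff @ p[free]), np.all(np.linalg.eigvalsh(A_ff) > 0))
```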
Basic conjugate gradient solver
The function cg solves \(A x = b\) using action_A as the operator, x as the initial guess of the solution, and b as the right-hand side vector. comm is the MPI communicator, max_iter is the maximum number of iterations, and rtol is the relative tolerance on the residual norm.
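For reference, the loop in cg follows the standard unpreconditioned CG recurrences. Here \(x_0\) is the initial guess, \({}^{*}\) denotes the conjugate transpose (np.vdot), and the inner products are taken over the owned entries only:
\[
\begin{aligned}
r_0 &= b - A x_0, \qquad p_0 = r_0, \\
\alpha_k &= \frac{r_k^{*} r_k}{p_k^{*} A p_k}, \qquad
x_{k+1} = x_k + \alpha_k p_k, \qquad
r_{k+1} = r_k - \alpha_k A p_k, \\
\beta_k &= \frac{r_{k+1}^{*} r_{k+1}}{r_k^{*} r_k}, \qquad
p_{k+1} = r_{k+1} + \beta_k p_k.
\end{aligned}
\]
The iteration stops once \(\|r_{k+1}\| / \|r_0\| < \texttt{rtol}\). The code compares squared norms against rtol2 to avoid taking a square root at every iteration.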
def cg(comm, action_A, x: la.Vector, b: la.Vector, max_iter: int = 200, rtol: float = 1e-6):
    rtol2 = rtol**2

    nr = b.index_map.size_local

    def _global_dot(comm, v0, v1):
        # Only use the owned dofs in vector (up to nr)
        return comm.allreduce(np.vdot(v0[:nr], v1[:nr]), MPI.SUM)

    # Get initial y = A.x
    y = la.vector(b.index_map, 1, dtype)
    action_A(x, y)

    # Copy residual to p
    r = b.array - y.array
    p = la.vector(b.index_map, 1, dtype)
    p.array[:] = r

    # Iterations of CG
    rnorm0 = _global_dot(comm, r, r)
    rnorm = rnorm0
    for k in range(max_iter):
        action_A(p, y)
        alpha = rnorm / _global_dot(comm, p.array, y.array)
        x.array[:] += alpha * p.array
        r -= alpha * y.array
        rnorm_new = _global_dot(comm, r, r)
        beta = rnorm_new / rnorm
        rnorm = rnorm_new
        if comm.rank == 0:
            print(k, np.sqrt(rnorm / rnorm0))
        if rnorm / rnorm0 < rtol2:
            x.scatter_forward()
            return k
        p.array[:] = beta * p.array + r

    raise RuntimeError(f"Solver exceeded max iterations ({max_iter}).")
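To experiment with the algorithm without MPI or DOLFINx, the same loop can be written for plain NumPy arrays. In this serial sketch, the name cg_serial and the tridiagonal test matrix are our own. The global dot product reduces to np.vdot because there are no ghost entries to exclude. In the parallel cg above, only the first size_local (owned) entries enter the dot product, so dofs shared between processes are not counted twice.

```python
import numpy as np


def cg_serial(action_A, x, b, max_iter=200, rtol=1e-6):
    """Serial CG with the same structure as cg above; x is updated in place."""
    rtol2 = rtol**2
    y = np.empty_like(b)
    action_A(x, y)
    r = b - y
    p = r.copy()
    rnorm0 = np.vdot(r, r).real
    rnorm = rnorm0
    for k in range(max_iter):
        action_A(p, y)
        alpha = rnorm / np.vdot(p, y)
        x += alpha * p
        r -= alpha * y
        rnorm_new = np.vdot(r, r).real
        beta = rnorm_new / rnorm
        rnorm = rnorm_new
        if rnorm / rnorm0 < rtol2:
            return k
        p = beta * p + r
    raise RuntimeError(f"Solver exceeded max iterations ({max_iter}).")


n = 50
A = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)  # SPD tridiagonal test matrix


def apply_A(x, y):
    y[:] = A @ x


b = np.ones(n)
x = np.zeros(n)
iters = cg_serial(apply_A, x, b, rtol=1e-10)
print(iters, np.linalg.norm(A @ x - b) / np.linalg.norm(b))
```

As in the original, the returned value is the zero-based index of the converged iteration.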
This matrix-free solver is now used to compute the finite element solution. The approximation error of the finite element solution with respect to the exact solution is measured in the \(L_2\)-norm.
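The function L2Norm defined below integrates \(e^2\) over the cells owned by each process. It then adds the per-process values with allreduce and takes the square root of the global sum. The NumPy sketch below imitates that pattern in 1D. It uses Gauss quadrature on each cell and splits the cells between two hypothetical processes. A plain sum stands in for comm.allreduce. For \(e(x) = x^2\) on [0, 1] the exact value is \(\sqrt{1/5}\).

```python
import numpy as np

# 3-point Gauss-Legendre rule on [-1, 1]; exact for polynomials up to degree 5
points, weights = np.polynomial.legendre.leggauss(3)


def local_integral(func, cells):
    """Integrate func(x)**2 over the given list of (a, b) cells."""
    total = 0.0
    for a, b in cells:
        x = 0.5 * (b - a) * points + 0.5 * (a + b)  # map quadrature points to [a, b]
        total += 0.5 * (b - a) * np.sum(weights * func(x) ** 2)
    return total


edges = np.linspace(0.0, 1.0, 9)
cells = list(zip(edges[:-1], edges[1:]))

# Pretend the cells are split between two processes
partial = [local_integral(lambda x: x**2, cells[:4]), local_integral(lambda x: x**2, cells[4:])]
l2 = np.sqrt(sum(partial))  # sum() stands in for comm.allreduce(..., op=MPI.SUM)
print(l2, np.sqrt(1.0 / 5.0))
```

The square root must be taken after the global sum. Adding the square roots of the per-process partial sums would give a different, incorrect value.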
rtol = 1e-6
u = fem.Function(V, dtype=dtype)
iter_cg1 = cg(mesh.comm, action_A, u.x, b, max_iter=200, rtol=rtol)
# Set BC values in the solution vector
bc.set(u.x.array, alpha=1.0)
def L2Norm(u):
    val = fem.assemble_scalar(fem.form(inner(u, u) * dx, dtype=dtype))
    return np.sqrt(comm.allreduce(val, op=MPI.SUM))


# Print CG iteration number and error
error_L2_cg1 = L2Norm(u - uD)
if mesh.comm.rank == 0:
    print("Matrix-free CG solver using DOLFINx vectors:")
    print(f"CG iterations until convergence: {iter_cg1}")
    print(f"L2 approximation error: {error_L2_cg1:.4e}")
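Because \(u_{\rm D}\) is a quadratic polynomial, it belongs to the degree-2 Lagrange space V. With exactly integrated forms, the Galerkin solution should therefore coincide with it. We would expect the printed \(L_2\) error to be governed mainly by the CG tolerance and round-off rather than by the mesh size. The 1D sketch below illustrates the key ingredient on a single element with nodes at the ends and the midpoint: interpolating a quadratic, here \(g(x) = 1 + x^2\), into degree-2 Lagrange functions reproduces it exactly. The element and helper names are illustrative choices, not DOLFINx code.

```python
import numpy as np


def lagrange_p2_basis(xi):
    """Quadratic Lagrange basis on the reference interval [0, 1] with nodes 0, 1/2, 1."""
    return np.array([
        2.0 * (xi - 0.5) * (xi - 1.0),
        -4.0 * xi * (xi - 1.0),
        2.0 * xi * (xi - 0.5),
    ])


def g(x):
    return 1.0 + x**2


# A single physical element [a, b]; Lagrange interpolation evaluates g at the nodes
a, b = 0.3, 0.7
nodes = np.array([a, 0.5 * (a + b), b])
coeffs = g(nodes)

# Evaluate the interpolant at points throughout the element
xi = np.linspace(0.0, 1.0, 7)
x = a + (b - a) * xi
interp = coeffs @ lagrange_p2_basis(xi)
print(np.max(np.abs(interp - g(x))))
```

Repeating the experiment with a cubic in place of g would show a nonzero interpolation error. For such a solution, the finite element error would depend on the mesh size.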