Matrix-free conjugate gradient solver for the Poisson equation
This demo illustrates how to solve the Poisson equation using a matrix-free conjugate gradient (CG) solver. In particular, it illustrates how to:
Solve a linear partial differential equation using a matrix-free conjugate gradient (CG) solver.
Create and apply Dirichlet boundary conditions.
Compute approximation error as compared with a known exact solution.
Note
This demo illustrates the use of a matrix-free conjugate gradient solver. Many practical problems will also require a preconditioner to create an efficient solver. This is not covered here.
Problem definition
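For a domain \(\Omega \subset \mathbb{R}^n\) with boundary \(\partial \Omega\), the Poisson equation with Dirichlet boundary conditions reads:
\[
-\nabla^{2} u = f \quad {\rm in} \ \Omega, \qquad u = u_{\rm D} \quad {\rm on} \ \partial\Omega.
\]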
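The variational problem reads: given a suitable function space \(V\) whose functions satisfy the essential boundary condition \(u = u_{\rm D} \ {\rm on} \ \partial\Omega\), and its homogenised counterpart \(V_0\), find \(u \in V\) such that
\[
a(u, v) = L(v) \quad \forall \, v \in V_0,
\]
where the bilinear and linear forms are
\[
a(u, v) = \int_{\Omega} \nabla u \cdot \nabla v \, {\rm d}x, \qquad L(v) = \int_{\Omega} f v \, {\rm d}x,
\]
respectively. In this demo we select: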
\(\Omega = [0,1] \times [0,1]\) (a square)
\(u_{\rm D} = 1 + x^2 + 2y^2\)
\(f = -6\)
The function \(u_{\rm D}\) is also the exact solution of the posed problem, since \(-\nabla^2 u_{\rm D} = -(2 + 4) = -6 = f\).
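As an informal check, not part of the original demo, the short Python sketch below approximates \(-\nabla^2 u_{\rm D}\) with second-order central differences at a few points. Because \(u_{\rm D}\) is quadratic, these differences are exact up to rounding and should give values very close to \(-6\). The helper names u_D and neg_laplacian are illustrative only.

```python
def u_D(x: float, y: float) -> float:
    # Exact solution chosen in the demo
    return 1.0 + x**2 + 2.0 * y**2


def neg_laplacian(u, x: float, y: float, h: float = 1e-3) -> float:
    # Second-order central differences; exact (up to rounding) for quadratics
    d2x = (u(x + h, y) - 2.0 * u(x, y) + u(x - h, y)) / h**2
    d2y = (u(x, y + h) - 2.0 * u(x, y) + u(x, y - h)) / h**2
    return -(d2x + d2y)


for px, py in [(0.2, 0.3), (0.5, 0.5), (0.9, 0.1)]:
    print(f"-laplace(u_D)({px}, {py}) = {neg_laplacian(u_D, px, py):.6f}")
```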
Implementation
The modules that will be used are imported:
from mpi4py import MPI
import numpy as np
import dolfinx
import ufl
from dolfinx import fem, la
We begin by using create_rectangle to create a rectangular Mesh of the domain, and then create a finite element FunctionSpace on the mesh.
dtype = dolfinx.default_scalar_type
real_type = np.real(dtype(0.0)).dtype
comm = MPI.COMM_WORLD
mesh = dolfinx.mesh.create_rectangle(comm, [[0.0, 0.0], [1.0, 1.0]], (10, 10), dtype=real_type)
degree = 2
V = fem.functionspace(mesh, ("Lagrange", degree))
The second argument to functionspace is a tuple (family, degree), where family is the finite element family and degree specifies the polynomial degree. In this case V consists of second-order (degree 2), continuous Lagrange finite element functions.
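To get a feel for the problem size, the sketch below counts the Lagrange degrees of freedom on the mesh created above. It makes two assumptions:

- each of the 10 × 10 squares is split into two triangles, which is the default for create_rectangle;
- each vertex carries one dof, each edge carries degree − 1 interior dofs and each cell carries (degree − 1)(degree − 2)/2 interior dofs.

The function name is hypothetical. For degree 2 the count is 441 = 21² dofs, which one would expect to agree with the global size reported by V.dofmap.index_map.size_global.

```python
def lagrange_dofs_on_triangulated_rectangle(nx: int, ny: int, degree: int) -> int:
    # Assumes every one of the nx * ny squares is split into two triangles
    vertices = (nx + 1) * (ny + 1)
    # Horizontal + vertical + one diagonal edge per square
    edges = nx * (ny + 1) + ny * (nx + 1) + nx * ny
    cells = 2 * nx * ny
    dofs_per_edge = degree - 1                        # interior nodes on each edge
    dofs_per_cell = (degree - 1) * (degree - 2) // 2  # interior nodes in each triangle
    return vertices + dofs_per_edge * edges + dofs_per_cell * cells


print(lagrange_dofs_on_triangulated_rectangle(10, 10, 2))
```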
Next, we locate the mesh facets that lie on the domain boundary \(\partial\Omega\). We do this by first calling create_connectivity to compute the facet-to-cell connectivity, and then retrieving all facets on the boundary using exterior_facet_indices.
tdim = mesh.topology.dim
mesh.topology.create_connectivity(tdim - 1, tdim)
facets = dolfinx.mesh.exterior_facet_indices(mesh.topology)
We now find the degrees of freedom that are associated with the boundary facets using locate_dofs_topological:
dofs = fem.locate_dofs_topological(V=V, entity_dim=tdim - 1, entities=facets)
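The number of boundary facets and boundary dofs can be predicted by hand under the same assumptions (serial run, structured triangulation with one diagonal per square). The boundary consists of 2(nx + ny) edges and the same number of vertices, and each boundary edge carries degree − 1 additional nodes. The sketch below uses a hypothetical helper and gives 40 facets and 80 dofs for this mesh. In parallel, each process only sees its local share.

```python
def boundary_counts(nx: int, ny: int, degree: int) -> tuple[int, int]:
    # Exterior facets (edges) of a structured nx-by-ny rectangle mesh:
    # nx on the bottom and top, ny on the left and right; diagonals are interior
    n_facets = 2 * (nx + ny)
    # The boundary is a closed loop, so it has as many vertices as edges
    n_vertices = n_facets
    # Each boundary edge also carries (degree - 1) interior Lagrange nodes
    n_dofs = n_vertices + (degree - 1) * n_facets
    return n_facets, n_dofs


print(boundary_counts(10, 10, 2))
```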
We then use dirichletbc to define the essential boundary condition. On the boundary we prescribe the Function uD, which we create by interpolating the expression \(u_{\rm D}\) into the finite element space \(V\).
uD = fem.Function(V, dtype=dtype)
uD.interpolate(lambda x: 1 + x[0] ** 2 + 2 * x[1] ** 2)
bc = fem.dirichletbc(value=uD, dofs=dofs)
Next, we express the variational problem using UFL.
x = ufl.SpatialCoordinate(mesh)
u = ufl.TrialFunction(V)
v = ufl.TestFunction(V)
f = fem.Constant(mesh, dtype(-6.0))
a = ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx
L = ufl.inner(f, v) * ufl.dx
L_fem = fem.form(L, dtype=dtype)
For the matrix-free solver we also define a second linear form M as the action of the bilinear form \(a\) on an arbitrary Function ui, that is \(M(v) = a(u_i, v)\):
ui = fem.Function(V, dtype=dtype)
M = ufl.action(a, ui)
M_fem = fem.form(M, dtype=dtype)
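Conceptually, assembling the linear form M with coefficient ui computes the product \(A u_i\) cell by cell, without ever storing \(A\). The sketch below illustrates this on a hypothetical 1D analogue: piecewise-linear elements for \(-u''\) on [0, 1]. It compares a product with an explicitly assembled dense matrix against element-by-element accumulation. It is plain Python for illustration only, not how DOLFINx implements assembly.

```python
def element_stiffness(h: float) -> list[list[float]]:
    # P1 element stiffness matrix for -u'' on an interval of length h
    return [[1.0 / h, -1.0 / h], [-1.0 / h, 1.0 / h]]


def assemble_matrix(n_cells: int, h: float) -> list[list[float]]:
    # Explicitly assemble the global dense matrix A
    n = n_cells + 1
    A = [[0.0] * n for _ in range(n)]
    Ke = element_stiffness(h)
    for c in range(n_cells):
        dofs = (c, c + 1)
        for i in range(2):
            for j in range(2):
                A[dofs[i]][dofs[j]] += Ke[i][j]
    return A


def assemble_action(n_cells: int, h: float, x: list[float]) -> list[float]:
    # Matrix-free: accumulate y = A x cell by cell, never storing A
    y = [0.0] * (n_cells + 1)
    Ke = element_stiffness(h)
    for c in range(n_cells):
        dofs = (c, c + 1)
        for i in range(2):
            y[dofs[i]] += sum(Ke[i][j] * x[dofs[j]] for j in range(2))
    return y


n_cells, h = 8, 1.0 / 8
x = [float(k * k) for k in range(n_cells + 1)]
A = assemble_matrix(n_cells, h)
y_matrix = [sum(A[i][j] * x[j] for j in range(len(x))) for i in range(len(x))]
y_free = assemble_action(n_cells, h, x)
print(max(abs(a - b) for a, b in zip(y_matrix, y_free)))
```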
Matrix-free conjugate gradient solver
The right-hand side vector \(b - A x_{\rm bc}\) is obtained by assembling the linear form \(L\) and applying the essential Dirichlet boundary conditions by lifting. Since we want to avoid assembling the matrix A, we compute the required matrix-vector product \(A x_{\rm bc}\) explicitly using the linear form M.
# Apply lifting: b <- b - A * x_bc
b = fem.assemble_vector(L_fem)
ui.x.array[:] = 0.0
bc.set(ui.x.array, alpha=-1.0)
fem.assemble_vector(b.array, M_fem)
b.scatter_reverse(la.InsertMode.add)
# Set BC dofs to zero on right-hand side
bc.set(b.array, alpha=0.0)
b.scatter_forward()
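The sequence above relies on bc.set writing \(\alpha\,u_{\rm D}\) into the constrained entries of an array, which is its documented behaviour when no x0 array is passed. The sketch below replays the same steps on a hypothetical 1D problem, \(-u'' = -2\) on [0, 1] with \(u_{\rm D} = 1 + x^2\). It uses plain Python lists and a stand-in bc_set function. For this problem, linear elements reproduce the exact nodal values, so the exact solution with its boundary entries zeroed should satisfy the lifted system up to rounding.

```python
# Hypothetical 1D stand-in: -u'' = -2 on [0, 1], u_D = 1 + x**2, P1 elements
n_cells = 8
h = 1.0 / n_cells
n = n_cells + 1
nodes = [k * h for k in range(n)]
bc_dofs = [0, n - 1]

# Dense P1 stiffness matrix and load vector b_i = integral of f * phi_i
A = [[0.0] * n for _ in range(n)]
for c in range(n_cells):
    for i, j, val in ((c, c, 1.0), (c, c + 1, -1.0), (c + 1, c, -1.0), (c + 1, c + 1, 1.0)):
        A[i][j] += val / h
b = [-2.0 * h] * n
b[0] = b[-1] = -1.0 * h  # end hat functions have half support
g = [1.0 + xk**2 for xk in nodes]  # interpolated boundary data (only bc_dofs used)


def bc_set(x, alpha):
    # Stand-in for the documented behaviour x[bc_dofs] = alpha * g[bc_dofs] (no x0)
    for d in bc_dofs:
        x[d] = alpha * g[d]


def matvec(x):
    return [sum(A[i][j] * x[j] for j in range(n)) for i in range(n)]


# Lifting, following the demo's sequence of operations
ui = [0.0] * n
bc_set(ui, alpha=-1.0)                          # ui = -x_bc
b = [bi + yi for bi, yi in zip(b, matvec(ui))]  # b <- b - A x_bc
bc_set(b, alpha=0.0)                            # zero the BC rows

# Check: exact nodal solution with BC entries zeroed satisfies the lifted system
x0 = [1.0 + xk**2 for xk in nodes]
bc_set(x0, alpha=0.0)
y = matvec(x0)
bc_set(y, alpha=0.0)
print(max(abs(yi - bi) for yi, bi in zip(y, b)))
```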
To implement the matrix-free CG solver using DOLFINx vectors, we define the function action_A to compute the matrix-vector product \(y = A x\).
def action_A(x, y):
    # Set coefficient vector of the linear form M and ensure it is
    # updated across processes
    ui.x.array[:] = x.array
    ui.x.scatter_forward()

    # Compute action of A on ui using the linear form M
    y.array[:] = 0.0
    fem.assemble_vector(y.array, M_fem)
    y.scatter_reverse(la.InsertMode.add)

    # Set BC dofs to zero
    bc.set(y.array, alpha=0.0)
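In the demo, x, b and therefore every CG iterate are zero at the Dirichlet dofs. Zeroing those entries of y then makes action_A act as the interior block of \(A\), which stays symmetric positive definite, as CG requires. The sketch below checks both properties numerically for a hypothetical 1D matrix-free operator written in plain Python. action_A_1d and the other names are illustrative.

```python
import random

n_cells = 8
h = 1.0 / n_cells
n = n_cells + 1
bc_dofs = (0, n - 1)


def action_A_1d(x):
    # Matrix-free y = A x for the 1D P1 stiffness matrix, assembled cell by cell
    y = [0.0] * n
    for c in range(n_cells):
        diff = (x[c] - x[c + 1]) / h
        y[c] += diff
        y[c + 1] -= diff
    # As in the demo's action_A: zero the entries at Dirichlet dofs
    for d in bc_dofs:
        y[d] = 0.0
    return y


def random_interior_vector(rng):
    # Random vector that is zero at the Dirichlet dofs, like the CG iterates
    v = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    for d in bc_dofs:
        v[d] = 0.0
    return v


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


rng = random.Random(0)
u, v = random_interior_vector(rng), random_interior_vector(rng)
Au, Av = action_A_1d(u), action_A_1d(v)
print(abs(dot(Au, v) - dot(u, Av)))  # symmetry on the constrained subspace
print(dot(u, Au) > 0.0)              # positive definiteness for a nonzero u
```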
Basic conjugate gradient solver
Solves the problem \(A x = b\), using the function action_A as the operator, x as an initial guess of the solution, and b as the right-hand side vector. comm is the MPI communicator, max_iter is the maximum number of iterations, and rtol is the relative tolerance on the residual norm, i.e. the iteration stops once \(\|r_k\| / \|r_0\| < {\rm rtol}\).
def cg(comm, action_A, x: la.Vector, b: la.Vector, max_iter: int = 200, rtol: float = 1e-6):
    rtol2 = rtol**2

    nr = b.index_map.size_local

    def _global_dot(comm, v0, v1):
        # Only use the owned dofs in vector (up to nr)
        return comm.allreduce(np.vdot(v0[:nr], v1[:nr]), MPI.SUM)

    # Get initial y = A.x
    y = la.vector(b.index_map, 1, dtype)
    action_A(x, y)

    # Copy residual to p
    r = b.array - y.array
    p = la.vector(b.index_map, 1, dtype)
    p.array[:] = r

    # Iterations of CG; squared residual norms are taken as real so the
    # comparison with rtol2 also works for complex scalar types
    rnorm0 = _global_dot(comm, r, r).real
    rnorm = rnorm0
    for k in range(max_iter):
        action_A(p, y)
        alpha = rnorm / _global_dot(comm, p.array, y.array)
        x.array[:] += alpha * p.array
        r -= alpha * y.array
        rnorm_new = _global_dot(comm, r, r).real
        beta = rnorm_new / rnorm
        rnorm = rnorm_new
        if comm.rank == 0:
            print(k, np.sqrt(rnorm / rnorm0))
        if rnorm / rnorm0 < rtol2:
            x.scatter_forward()
            return k
        p.array[:] = beta * p.array + r

    raise RuntimeError(f"Solver exceeded max iterations ({max_iter}).")
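For experimenting outside DOLFINx, the same algorithm can be written as a serial, pure-Python sketch in which the operator is any function mapping a list to a list. The operator used here is a tridiagonal (−1, 2, −1) matrix applied without storing it, a hypothetical stand-in for action_A. The stopping test matches the demo's: it compares squared residual norms against rtol². The names cg_serial and laplace_1d are illustrative.

```python
import math


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def cg_serial(action_A, b, x, max_iter=200, rtol=1e-6):
    # Unpreconditioned CG for an SPD operator given only as a function x -> A x.
    # Stops when ||r_k|| / ||r_0|| < rtol, comparing squared norms.
    y = action_A(x)
    r = [bi - yi for bi, yi in zip(b, y)]
    p = list(r)
    rnorm0 = rnorm = dot(r, r)
    for k in range(max_iter):
        y = action_A(p)
        alpha = rnorm / dot(p, y)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * yi for ri, yi in zip(r, y)]
        rnorm_new = dot(r, r)
        beta = rnorm_new / rnorm
        rnorm = rnorm_new
        if rnorm / rnorm0 < rtol**2:
            return x, k
        p = [beta * pi + ri for pi, ri in zip(p, r)]
    raise RuntimeError(f"Solver exceeded max iterations ({max_iter}).")


def laplace_1d(x):
    # Matrix-free tridiag(-1, 2, -1) operator (SPD), standing in for A
    m = len(x)
    return [
        2.0 * x[i] - (x[i - 1] if i > 0 else 0.0) - (x[i + 1] if i < m - 1 else 0.0)
        for i in range(m)
    ]


m = 20
x_true = [math.sin(0.3 * i) for i in range(m)]
b = laplace_1d(x_true)
x, iters = cg_serial(laplace_1d, b, [0.0] * m, rtol=1e-10)
print(iters, max(abs(a - t) for a, t in zip(x, x_true)))
```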
This matrix-free solver is now used to compute the finite element solution. The finite element solution’s approximation error as compared with the exact solution is measured in the \(L_2\)-norm.
rtol = 1e-6
u = fem.Function(V, dtype=dtype)
iter_cg1 = cg(mesh.comm, action_A, u.x, b, max_iter=200, rtol=rtol)
# Set BC values in the solution vector
bc.set(u.x.array, alpha=1.0)
def L2Norm(u):
    # assemble_scalar returns the process-local contribution; sum over all processes
    val = fem.assemble_scalar(fem.form(ufl.inner(u, u) * ufl.dx, dtype=dtype))
    return np.sqrt(comm.allreduce(val, op=MPI.SUM))


# Print CG iteration number and error
error_L2_cg1 = L2Norm(u - uD)
if mesh.comm.rank == 0:
print("Matrix-free CG solver using DOLFINx vectors:")
print(f"CG iterations until convergence: {iter_cg1}")
print(f"L2 approximation error: {error_L2_cg1:.4e}")
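The allreduce in L2Norm matters because fem.assemble_scalar returns only the current process's contribution to the integral. The squared contributions must therefore be summed before taking the square root. The sketch below mimics this with a hypothetical split of the unit square into four strips that play the role of ranks. It integrates \(e^2\) for \(e = xy\) with a 2-point Gauss rule, then contrasts the correct reduction with the incorrect alternative of summing per-strip norms.

```python
import math

# 2-point Gauss-Legendre rule on [0, 1]: exact for polynomials up to degree 3
GAUSS_PTS = (0.5 - 0.5 / math.sqrt(3.0), 0.5 + 0.5 / math.sqrt(3.0))
GAUSS_WTS = (0.5, 0.5)


def local_integral_sq(e, x0, x1, y0, y1):
    # Integral of e**2 over [x0, x1] x [y0, y1], i.e. "one rank's contribution"
    total = 0.0
    for px, wx in zip(GAUSS_PTS, GAUSS_WTS):
        for py, wy in zip(GAUSS_PTS, GAUSS_WTS):
            x = x0 + (x1 - x0) * px
            y = y0 + (y1 - y0) * py
            total += wx * wy * (x1 - x0) * (y1 - y0) * e(x, y) ** 2
    return total


def e(x, y):
    # Hypothetical "error" function; its L2 norm on the unit square is 1/3
    return x * y


n_ranks = 4
strips = [(r / n_ranks, (r + 1) / n_ranks) for r in range(n_ranks)]
local_vals = [local_integral_sq(e, x0, x1, 0.0, 1.0) for x0, x1 in strips]

l2_correct = math.sqrt(sum(local_vals))           # like allreduce(val), then sqrt
l2_wrong = sum(math.sqrt(v) for v in local_vals)  # sqrt before reduction: incorrect
print(l2_correct, l2_wrong)
```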