DOLFINx 0.9.0
DOLFINx C++ interface
Loading...
Searching...
No Matches
Expression.h
// Copyright (C) 2020-2021 Jack S. Hale and Michal Habera.
//
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
// SPDX-License-Identifier: LGPL-3.0-or-later

#pragma once

#include "Constant.h"
#include "Function.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <concepts>
#include <cstdint>
#include <dolfinx/common/types.h>
#include <dolfinx/mesh/Mesh.h>
#include <functional>
#include <memory>
#include <numeric>
#include <span>
#include <stdexcept>
#include <utility>
#include <vector>
20
21namespace dolfinx::fem
22{
23template <dolfinx::scalar T>
24class Constant;
25
38template <dolfinx::scalar T,
39 std::floating_point U = dolfinx::scalar_value_type_t<T>>
40class Expression
41{
42public:
47 using scalar_type = T;
48
50 using geometry_type = U;
51
67 const std::vector<std::shared_ptr<
69 const std::vector<std::shared_ptr<const Constant<scalar_type>>>&
71 std::span<const geometry_type> X, std::array<std::size_t, 2> Xshape,
72 std::function<void(scalar_type*, const scalar_type*, const scalar_type*,
73 const geometry_type*, const int*, const uint8_t*)>
74 fn,
75 const std::vector<int>& value_shape,
76 std::shared_ptr<const FunctionSpace<geometry_type>>
78 = nullptr)
79 : _coefficients(coefficients), _constants(constants),
80 _x_ref(std::vector<geometry_type>(X.begin(), X.end()), Xshape), _fn(fn),
81 _value_shape(value_shape),
82 _argument_function_space(argument_function_space)
83 {
84 for (auto& c : _coefficients)
85 {
86 assert(c);
87 if (c->function_space()->mesh()
88 != _coefficients.front()->function_space()->mesh())
89 {
90 throw std::runtime_error("Coefficients not all defined on same mesh.");
91 }
92 }
93 }
94
96 Expression(Expression&& e) = default;
97
99 virtual ~Expression() = default;
100
104 std::shared_ptr<const FunctionSpace<geometry_type>>
106 {
107 return _argument_function_space;
108 };
109
112 const std::vector<
113 std::shared_ptr<const Function<scalar_type, geometry_type>>>&
115 {
116 return _coefficients;
117 }
118
123 const std::vector<std::shared_ptr<const Constant<scalar_type>>>&
124 constants() const
125 {
126 return _constants;
127 }
128
134 std::vector<int> coefficient_offsets() const
135 {
136 std::vector<int> n{0};
137 for (auto& c : _coefficients)
138 {
139 if (!c)
140 throw std::runtime_error("Not all form coefficients have been set.");
141 n.push_back(n.back() + c->function_space()->element()->space_dimension());
142 }
143 return n;
144 }
145
157 std::span<const std::int32_t> entities,
158 std::span<scalar_type> values,
159 std::array<std::size_t, 2> vshape) const
160 {
161 std::size_t estride;
162 if (mesh.topology()->dim() == _x_ref.second[1])
163 estride = 1;
164 else if (mesh.topology()->dim() == _x_ref.second[1] + 1)
165 estride = 2;
166 else
167 throw std::runtime_error("Invalid dimension of evaluation points.");
168
169 // Prepare coefficients and constants
170 auto [coeffs, cstride] = pack_coefficients(*this, entities, estride);
171 std::vector<scalar_type> constant_data = pack_constants(*this);
172 auto fn = this->get_tabulate_expression();
173
174 // Prepare cell geometry
175 auto x_dofmap = mesh.geometry().dofmap();
176
177 // Get geometry data
178 auto& cmap = mesh.geometry().cmap();
179
180 std::size_t num_dofs_g = cmap.dim();
181 auto x_g = mesh.geometry().x();
182
183 // Create data structures used in evaluation
184 std::vector<geometry_type> coord_dofs(3 * num_dofs_g);
185
186 int num_argument_dofs = 1;
187 std::span<const std::uint32_t> cell_info;
188 std::function<void(std::span<scalar_type>, std::span<const std::uint32_t>,
189 std::int32_t, int)>
190 post_dof_transform
191 = [](std::span<scalar_type>, std::span<const std::uint32_t>,
192 std::int32_t, int)
193 {
194 // Do nothing
195 };
196
197 if (_argument_function_space)
198 {
199 num_argument_dofs
200 = _argument_function_space->dofmap()->element_dof_layout().num_dofs();
201 auto element = _argument_function_space->element();
202 num_argument_dofs *= _argument_function_space->dofmap()->bs();
203 assert(element);
204 if (element->needs_dof_transformations())
205 {
206 mesh.topology_mutable()->create_entity_permutations();
207 cell_info = std::span(mesh.topology()->get_cell_permutation_info());
208 post_dof_transform
209 = element->template dof_transformation_right_fn<scalar_type>(
211 }
212 }
213
214 // Create get entity index function
215 std::function<const std::int32_t*(std::span<const std::int32_t>,
216 std::size_t)>
217 get_entity_index
218 = []([[maybe_unused]] std::span<const std::int32_t> entities,
219 [[maybe_unused]] std::size_t idx) { return nullptr; };
220 if (estride == 2)
221 {
222 get_entity_index
223 = [](std::span<const std::int32_t> entities, std::size_t idx)
224 { return entities.data() + 2 * idx + 1; };
225 }
226
227 // Iterate over cells and 'assemble' into values
228 int size0 = _x_ref.second[0] * value_size();
229 std::vector<scalar_type> values_local(size0 * num_argument_dofs, 0);
230 for (std::size_t e = 0; e < entities.size() / estride; ++e)
231 {
232 std::int32_t entity = entities[e * estride];
233 auto x_dofs = MDSPAN_IMPL_STANDARD_NAMESPACE::submdspan(
234 x_dofmap, entity, MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent);
235 for (std::size_t i = 0; i < x_dofs.size(); ++i)
236 {
237 std::copy_n(std::next(x_g.begin(), 3 * x_dofs[i]), 3,
238 std::next(coord_dofs.begin(), 3 * i));
239 }
240
241 const scalar_type* coeff_cell = coeffs.data() + e * cstride;
242 const int* entity_index = get_entity_index(entities, e);
243
244 std::ranges::fill(values_local, 0);
245 _fn(values_local.data(), coeff_cell, constant_data.data(),
246 coord_dofs.data(), entity_index, nullptr);
247 post_dof_transform(values_local, cell_info, e, size0);
248 for (std::size_t j = 0; j < values_local.size(); ++j)
249 values[e * vshape[1] + j] = values_local[j];
250 }
251 }
252
255 const std::function<void(scalar_type*, const scalar_type*, const scalar_type*,
256 const geometry_type*, const int*, const uint8_t*)>&
258 {
259 return _fn;
260 }
261
264 int value_size() const
265 {
266 return std::reduce(_value_shape.begin(), _value_shape.end(), 1,
267 std::multiplies{});
268 }
269
272 const std::vector<int>& value_shape() const { return _value_shape; }
273
276 std::pair<std::vector<geometry_type>, std::array<std::size_t, 2>> X() const
277 {
278 return _x_ref;
279 }
280
281private:
282 // Function space for Argument
283 std::shared_ptr<const FunctionSpace<geometry_type>> _argument_function_space;
284
285 // Coefficients associated with the Expression
286 std::vector<std::shared_ptr<const Function<scalar_type, geometry_type>>>
287 _coefficients;
288
289 // Constants associated with the Expression
290 std::vector<std::shared_ptr<const Constant<scalar_type>>> _constants;
291
292 // Function to evaluate the Expression
293 std::function<void(scalar_type*, const scalar_type*, const scalar_type*,
294 const geometry_type*, const int*, const uint8_t*)>
295 _fn;
296
297 // Shape of the evaluated expression
298 std::vector<int> _value_shape;
299
300 // Evaluation points on reference cell. Synonymous with X in public
301 // interface.
302 std::pair<std::vector<geometry_type>, std::array<std::size_t, 2>> _x_ref;
303};
304} // namespace dolfinx::fem
Constant value which can be attached to a Form.
Definition Form.h:29
Represents a mathematical expression evaluated at a pre-defined set of points on the reference cell.
Definition Function.h:32
U geometry_type
Geometry type of the points.
Definition Expression.h:50
std::pair< std::vector< geometry_type >, std::array< std::size_t, 2 > > X() const
Evaluation points on the reference cell.
Definition Expression.h:276
const std::vector< std::shared_ptr< const Function< scalar_type, geometry_type > > > & coefficients() const
Get coefficients.
Definition Expression.h:114
Expression(const std::vector< std::shared_ptr< const Function< scalar_type, geometry_type > > > &coefficients, const std::vector< std::shared_ptr< const Constant< scalar_type > > > &constants, std::span< const geometry_type > X, std::array< std::size_t, 2 > Xshape, std::function< void(scalar_type *, const scalar_type *, const scalar_type *, const geometry_type *, const int *, const uint8_t *)> fn, const std::vector< int > &value_shape, std::shared_ptr< const FunctionSpace< geometry_type > > argument_function_space=nullptr)
Create an Expression.
Definition Expression.h:66
void eval(const mesh::Mesh< geometry_type > &mesh, std::span< const std::int32_t > entities, std::span< scalar_type > values, std::array< std::size_t, 2 > vshape) const
Evaluate Expression on cells or facets.
Definition Expression.h:156
const std::vector< std::shared_ptr< const Constant< scalar_type > > > & constants() const
Get constants.
Definition Expression.h:124
std::vector< int > coefficient_offsets() const
Offset for each coefficient expansion array on a cell.
Definition Expression.h:134
int value_size() const
Get value size.
Definition Expression.h:264
Expression(Expression &&e)=default
Move constructor.
const std::vector< int > & value_shape() const
Get value shape.
Definition Expression.h:272
virtual ~Expression()=default
Destructor.
T scalar_type
Scalar type.
Definition Expression.h:47
const std::function< void(scalar_type *, const scalar_type *, const scalar_type *, const geometry_type *, const int *, const uint8_t *)> & get_tabulate_expression() const
Get function for tabulate_expression.
Definition Expression.h:257
std::shared_ptr< const FunctionSpace< geometry_type > > argument_function_space() const
Get argument function space.
Definition Expression.h:105
This class represents a finite element function space defined by a mesh, a finite element,...
Definition vtk_utils.h:32
Definition XDMFFile.h:29
A Mesh consists of a set of connected and numbered mesh topological entities, and geometry data.
Definition Mesh.h:23
Definition types.h:20
Finite element method functionality.
Definition assemble_matrix_impl.h:26
std::vector< typename U::scalar_type > pack_constants(const U &u)
Pack constants of u into a single array ready for assembly.
Definition utils.h:1348
void pack_coefficients(const Form< T, U > &form, IntegralType integral_type, int id, std::span< T > c, int cstride)
Pack coefficients of a Form for a given integral type and domain id.
Definition utils.h:1062