Basix 0.8.0

Home     Installation     Demos     C++ docs     Python docs

Loading...
Searching...
No Matches
math.h
// Copyright (C) 2021 Igor Baratta
//
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
// SPDX-License-Identifier: LGPL-3.0-or-later
6
#pragma once

#include "mdspan.hpp"
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <concepts>
#include <cstddef>
#include <limits>
#include <span>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
17
// Fortran BLAS/LAPACK entry points used by this header. Every argument
// is passed by pointer (Fortran calling convention) and all matrices
// are interpreted by the library in column-major order.
extern "C"
{
  // xSYEVD: eigenvalues/eigenvectors of a real symmetric matrix
  // (divide-and-conquer driver), single and double precision.
  void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
               float* work, int* lwork, int* iwork, int* liwork, int* info);
  void dsyevd_(char* jobz, char* uplo, int* n, double* a, int* lda, double* w,
               double* work, int* lwork, int* iwork, int* liwork, int* info);

  // xGESV: solve A X = B for a general square matrix via LU with
  // partial pivoting. A and B are overwritten.
  void sgesv_(int* N, int* NRHS, float* A, int* LDA, int* IPIV, float* B,
              int* LDB, int* INFO);
  void dgesv_(int* N, int* NRHS, double* A, int* LDA, int* IPIV, double* B,
              int* LDB, int* INFO);

  // xGEMM: general matrix-matrix product C = alpha op(A) op(B) + beta C.
  void sgemm_(char* transa, char* transb, int* m, int* n, int* k, float* alpha,
              float* a, int* lda, float* b, int* ldb, float* beta, float* c,
              int* ldc);
  void dgemm_(char* transa, char* transb, int* m, int* n, int* k, double* alpha,
              double* a, int* lda, double* b, int* ldb, double* beta, double* c,
              int* ldc);

  // xGETRF: LU factorisation with partial pivoting (in place). The
  // status is reported through `info`; the declared return value is not
  // used by this header.
  int sgetrf_(const int* m, const int* n, float* a, const int* lda, int* lpiv,
              int* info);
  int dgetrf_(const int* m, const int* n, double* a, const int* lda, int* lpiv,
              int* info);
}
42
namespace basix::math
{
49namespace impl
50{
55template <std::floating_point T>
56void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
57 std::span<const T> B, std::array<std::size_t, 2> Bshape,
58 std::span<T> C)
59{
60 static_assert(std::is_same_v<T, float> or std::is_same_v<T, double>);
61
62 assert(Ashape[1] == Bshape[0]);
63 assert(C.size() == Ashape[0] * Bshape[1]);
64
65 int M = Ashape[0];
66 int N = Bshape[1];
67 int K = Ashape[1];
68
69 T alpha = 1;
70 T beta = 0;
71 int lda = K;
72 int ldb = N;
73 int ldc = N;
74 char trans = 'N';
75 if constexpr (std::is_same_v<T, float>)
76 {
77 sgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
78 const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
79 }
80 else if constexpr (std::is_same_v<T, double>)
81 {
82 dgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
83 const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
84 }
85}
86
87} // namespace impl
88
/// @brief Compute the outer product of vectors u and v.
/// @param[in] u The first vector (length m)
/// @param[in] v The second vector (length n)
/// @return The product as a flat array with entry (i, j) stored at
/// index `i * n + j`, together with its shape `{m, n}`
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  const std::size_t rows = u.size();
  const std::size_t cols = v.size();

  std::vector<typename U::value_type> result(rows * cols);
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      result[i * cols + j] = u[i] * v[j];

  return {std::move(result), {rows, cols}};
}
103
/// @brief Compute the cross product u x v of two 3-vectors.
/// @param[in] u The first vector; must have size 3 (checked by assert)
/// @param[in] v The second vector; must have size 3 (checked by assert)
/// @return The cross product `u x v`
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  assert(u.size() == 3);
  assert(v.size() == 3);
  return {u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
          u[0] * v[1] - u[1] * v[0]};
}
116
123template <std::floating_point T>
124std::pair<std::vector<T>, std::vector<T>> eigh(std::span<const T> A,
125 std::size_t n)
126{
127 // Copy A
128 std::vector<T> M(A.begin(), A.end());
129
130 // Allocate storage for eigenvalues
131 std::vector<T> w(n, 0);
132
133 int N = n;
134 char jobz = 'V'; // Compute eigenvalues and eigenvectors
135 char uplo = 'L'; // Lower
136 int ldA = n;
137 int lwork = -1;
138 int liwork = -1;
139 int info;
140 std::vector<T> work(1);
141 std::vector<int> iwork(1);
142
143 // Query optimal workspace size
144 if constexpr (std::is_same_v<T, float>)
145 {
146 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
147 iwork.data(), &liwork, &info);
148 }
149 else if constexpr (std::is_same_v<T, double>)
150 {
151 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
152 iwork.data(), &liwork, &info);
153 }
154
155 if (info != 0)
156 throw std::runtime_error("Could not find workspace size for syevd.");
157
158 // Solve eigen problem
159 work.resize(work[0]);
160 iwork.resize(iwork[0]);
161 lwork = work.size();
162 liwork = iwork.size();
163 if constexpr (std::is_same_v<T, float>)
164 {
165 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
166 iwork.data(), &liwork, &info);
167 }
168 else if constexpr (std::is_same_v<T, double>)
169 {
170 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
171 iwork.data(), &liwork, &info);
172 }
173 if (info != 0)
174 throw std::runtime_error("Eigenvalue computation did not converge.");
175
176 return {std::move(w), std::move(M)};
177}
178
183template <std::floating_point T>
184std::vector<T>
185solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
186 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
187 A,
188 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
189 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
190 B)
191{
192 namespace stdex
193 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
194
195 // Copy A and B to column-major storage
196 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
197 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
198 _A(A.extents()), _B(B.extents());
199 for (std::size_t i = 0; i < A.extent(0); ++i)
200 for (std::size_t j = 0; j < A.extent(1); ++j)
201 _A(i, j) = A(i, j);
202 for (std::size_t i = 0; i < B.extent(0); ++i)
203 for (std::size_t j = 0; j < B.extent(1); ++j)
204 _B(i, j) = B(i, j);
205
206 int N = _A.extent(0);
207 int nrhs = _B.extent(1);
208 int lda = _A.extent(0);
209 int ldb = _B.extent(0);
210 // Pivot indices that define the permutation matrix for the LU solver
211 std::vector<int> piv(N);
212 int info;
213 if constexpr (std::is_same_v<T, float>)
214 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
215 else if constexpr (std::is_same_v<T, double>)
216 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
217 if (info != 0)
218 throw std::runtime_error("Call to dgesv failed: " + std::to_string(info));
219
220 // Copy result to row-major storage
221 std::vector<T> rb(_B.extent(0) * _B.extent(1));
222 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
223 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
224 r(rb.data(), _B.extents());
225 for (std::size_t i = 0; i < _B.extent(0); ++i)
226 for (std::size_t j = 0; j < _B.extent(1); ++j)
227 r(i, j) = _B(i, j);
228
229 return rb;
230}
231
235template <std::floating_point T>
237 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
238 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
239 A)
240{
241 // Copy to column major matrix
242 namespace stdex
243 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
244 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
245 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
246 _A(A.extents());
247 for (std::size_t i = 0; i < A.extent(0); ++i)
248 for (std::size_t j = 0; j < A.extent(1); ++j)
249 _A(i, j) = A(i, j);
250
251 std::vector<T> B(A.extent(1), 1);
252 int N = _A.extent(0);
253 int nrhs = 1;
254 int lda = _A.extent(0);
255 int ldb = B.size();
256
257 // Pivot indices that define the permutation matrix for the LU solver
258 std::vector<int> piv(N);
259 int info;
260 if constexpr (std::is_same_v<T, float>)
261 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
262 else if constexpr (std::is_same_v<T, double>)
263 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
264
265 if (info < 0)
266 {
267 throw std::runtime_error("dgesv failed due to invalid value: "
268 + std::to_string(info));
269 }
270 else if (info > 0)
271 return true;
272 else
273 return false;
274}
275
281template <std::floating_point T>
282std::vector<std::size_t>
283transpose_lu(std::pair<std::vector<T>, std::array<std::size_t, 2>>& A)
284{
285 std::size_t dim = A.second[0];
286 assert(dim == A.second[1]);
287 int N = dim;
288 int info;
289 std::vector<int> lu_perm(dim);
290
291 // Comput LU decomposition of M
292 if constexpr (std::is_same_v<T, float>)
293 sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
294 else if constexpr (std::is_same_v<T, double>)
295 dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
296
297 if (info != 0)
298 {
299 throw std::runtime_error("LU decomposition failed: "
300 + std::to_string(info));
301 }
302
303 std::vector<std::size_t> perm(dim);
304 for (std::size_t i = 0; i < dim; ++i)
305 perm[i] = static_cast<std::size_t>(lu_perm[i] - 1);
306
307 return perm;
308}
309
315template <typename U, typename V, typename W>
316void dot(const U& A, const V& B, W&& C)
317{
318 assert(A.extent(1) == B.extent(0));
319 assert(C.extent(0) == A.extent(0));
320 assert(C.extent(1) == B.extent(1));
321 if (A.extent(0) * B.extent(1) * A.extent(1) < 512)
322 {
323 std::fill_n(C.data_handle(), C.extent(0) * C.extent(1), 0);
324 for (std::size_t i = 0; i < A.extent(0); ++i)
325 for (std::size_t j = 0; j < B.extent(1); ++j)
326 for (std::size_t k = 0; k < A.extent(1); ++k)
327 C(i, j) += A(i, k) * B(k, j);
328 }
329 else
330 {
331 using T = typename std::decay_t<U>::value_type;
332 impl::dot_blas<T>(
333 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
334 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
335 std::span(C.data_handle(), C.size()));
336 }
337}
338
342template <std::floating_point T>
343std::vector<T> eye(std::size_t n)
344{
345 std::vector<T> I(n * n, 0);
346 namespace stdex
347 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
348 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
349 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
350 Iview(I.data(), n, n);
351 for (std::size_t i = 0; i < n; ++i)
352 Iview(i, i) = 1;
353 return I;
354}
355
360template <std::floating_point T>
362 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
363 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
364 wcoeffs,
365 std::size_t start = 0)
366{
367 for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
368 {
369 T norm = 0;
370 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
371 norm += wcoeffs(i, k) * wcoeffs(i, k);
372
373 norm = std::sqrt(norm);
374 if (norm < 2 * std::numeric_limits<T>::epsilon())
375 {
376 throw std::runtime_error(
377 "Cannot orthogonalise the rows of a matrix with incomplete row rank");
378 }
379
380 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
381 wcoeffs(i, k) /= norm;
382
383 for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
384 {
385 T a = 0;
386 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
387 a += wcoeffs(i, k) * wcoeffs(j, k);
388 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
389 wcoeffs(j, k) -= a * wcoeffs(i, k);
390 }
391 }
392}
} // namespace basix::math
Mathematical functions.
Definition: math.h:48
void dot(const U &A, const V &B, W &&C)
Compute C = A * B.
Definition: math.h:316
std::vector< T > solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > A, MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > B)
Solve A X = B.
Definition: math.h:185
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Definition: math.h:109
bool is_singular(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > A)
Check if A is a singular matrix.
Definition: math.h:236
std::vector< std::size_t > transpose_lu(std::pair< std::vector< T >, std::array< std::size_t, 2 > > &A)
Compute the LU decomposition of the transpose of a square matrix A.
Definition: math.h:283
void orthogonalise(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > wcoeffs, std::size_t start=0)
Orthogonalise the rows of a matrix (in place).
Definition: math.h:361
std::pair< std::vector< T >, std::vector< T > > eigh(std::span< const T > A, std::size_t n)
Definition: math.h:124
std::vector< T > eye(std::size_t n)
Build an identity matrix.
Definition: math.h:343
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition: math.h:95