Basix 0.9.0

Home     Installation     Demos     C++ docs     Python docs

math.h
1 // Copyright (C) 2021 Igor Baratta
2 //
3 // This file is part of DOLFINx (https://www.fenicsproject.org)
4 //
5 // SPDX-License-Identifier: LGPL-3.0-or-later
6 
7 #pragma once
8 
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <concepts>
#include <limits>
#include <span>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "mdspan.hpp"
19 
// Prototypes for the Fortran BLAS/LAPACK routines called from this
// header. All arguments, scalars included, are passed by pointer.
extern "C"
{
  // syevd: eigenvalues (and optionally eigenvectors) of a real symmetric
  // matrix, single and double precision
  void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
               float* work, int* lwork, int* iwork, int* liwork, int* info);
  void dsyevd_(char* jobz, char* uplo, int* n, double* a, int* lda, double* w,
               double* work, int* lwork, int* iwork, int* liwork, int* info);

  // gesv: solve a general linear system A X = B by LU factorisation
  void sgesv_(int* n, int* nrhs, float* a, int* lda, int* ipiv, float* b,
              int* ldb, int* info);
  void dgesv_(int* n, int* nrhs, double* a, int* lda, int* ipiv, double* b,
              int* ldb, int* info);

  // gemm: general matrix-matrix product C = alpha * op(A) * op(B) + beta * C
  void sgemm_(char* transa, char* transb, int* m, int* n, int* k,
              float* alpha, float* a, int* lda, float* b, int* ldb,
              float* beta, float* c, int* ldc);
  void dgemm_(char* transa, char* transb, int* m, int* n, int* k,
              double* alpha, double* a, int* lda, double* b, int* ldb,
              double* beta, double* c, int* ldc);

  // getrf: LU factorisation with partial pivoting
  int sgetrf_(const int* m, const int* n, float* a, const int* lda, int* ipiv,
              int* info);
  int dgetrf_(const int* m, const int* n, double* a, const int* lda,
              int* ipiv, int* info);
}
44 
49 namespace basix::math
50 {
51 namespace impl
52 {
57 template <std::floating_point T>
58 void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
59  std::span<const T> B, std::array<std::size_t, 2> Bshape,
60  std::span<T> C)
61 {
62  static_assert(std::is_same_v<T, float> or std::is_same_v<T, double>);
63 
64  assert(Ashape[1] == Bshape[0]);
65  assert(C.size() == Ashape[0] * Bshape[1]);
66 
67  int M = Ashape[0];
68  int N = Bshape[1];
69  int K = Ashape[1];
70 
71  T alpha = 1;
72  T beta = 0;
73  int lda = K;
74  int ldb = N;
75  int ldc = N;
76  char trans = 'N';
77  if constexpr (std::is_same_v<T, float>)
78  {
79  sgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
80  const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
81  }
82  else if constexpr (std::is_same_v<T, double>)
83  {
84  dgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
85  const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
86  }
87 }
88 
89 } // namespace impl
90 
/// @brief Compute the outer product of vectors u and v.
/// @param[in] u The first vector
/// @param[in] v The second vector
/// @return The products u[i] * v[j] stored row-major (entry (i, j) at
/// index i * v.size() + j), together with the shape (u.size(), v.size())
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  const std::size_t nrows = u.size();
  const std::size_t ncols = v.size();

  std::vector<typename U::value_type> products(nrows * ncols);
  auto out = products.begin();
  for (std::size_t r = 0; r < nrows; ++r)
  {
    for (std::size_t c = 0; c < ncols; ++c, ++out)
      *out = u[r] * v[c];
  }

  return {std::move(products), {nrows, ncols}};
}
105 
/// @brief Compute the cross product u x v of two vectors of length 3.
/// @param[in] u The first vector; its size must be 3
/// @param[in] v The second vector; its size must be 3
/// @return The three components of u x v
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  assert(u.size() == 3);
  assert(v.size() == 3);

  const auto x = u[1] * v[2] - u[2] * v[1];
  const auto y = u[2] * v[0] - u[0] * v[2];
  const auto z = u[0] * v[1] - u[1] * v[0];
  return {x, y, z};
}
118 
125 template <std::floating_point T>
126 std::pair<std::vector<T>, std::vector<T>> eigh(std::span<const T> A,
127  std::size_t n)
128 {
129  // Copy A
130  std::vector<T> M(A.begin(), A.end());
131 
132  // Allocate storage for eigenvalues
133  std::vector<T> w(n, 0);
134 
135  int N = n;
136  char jobz = 'V'; // Compute eigenvalues and eigenvectors
137  char uplo = 'L'; // Lower
138  int ldA = n;
139  int lwork = -1;
140  int liwork = -1;
141  int info;
142  std::vector<T> work(1);
143  std::vector<int> iwork(1);
144 
145  // Query optimal workspace size
146  if constexpr (std::is_same_v<T, float>)
147  {
148  ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
149  iwork.data(), &liwork, &info);
150  }
151  else if constexpr (std::is_same_v<T, double>)
152  {
153  dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
154  iwork.data(), &liwork, &info);
155  }
156 
157  if (info != 0)
158  throw std::runtime_error("Could not find workspace size for syevd.");
159 
160  // Solve eigen problem
161  work.resize(work[0]);
162  iwork.resize(iwork[0]);
163  lwork = work.size();
164  liwork = iwork.size();
165  if constexpr (std::is_same_v<T, float>)
166  {
167  ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
168  iwork.data(), &liwork, &info);
169  }
170  else if constexpr (std::is_same_v<T, double>)
171  {
172  dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
173  iwork.data(), &liwork, &info);
174  }
175  if (info != 0)
176  throw std::runtime_error("Eigenvalue computation did not converge.");
177 
178  return {std::move(w), std::move(M)};
179 }
180 
185 template <std::floating_point T>
186 std::vector<T>
187 solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
188  const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
189  A,
190  MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
191  const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
192  B)
193 {
194  namespace stdex
195  = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
196 
197  // Copy A and B to column-major storage
198  stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
199  MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
200  _A(A.extents()), _B(B.extents());
201  for (std::size_t i = 0; i < A.extent(0); ++i)
202  for (std::size_t j = 0; j < A.extent(1); ++j)
203  _A(i, j) = A(i, j);
204  for (std::size_t i = 0; i < B.extent(0); ++i)
205  for (std::size_t j = 0; j < B.extent(1); ++j)
206  _B(i, j) = B(i, j);
207 
208  int N = _A.extent(0);
209  int nrhs = _B.extent(1);
210  int lda = _A.extent(0);
211  int ldb = _B.extent(0);
212  // Pivot indices that define the permutation matrix for the LU solver
213  std::vector<int> piv(N);
214  int info;
215  if constexpr (std::is_same_v<T, float>)
216  sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
217  else if constexpr (std::is_same_v<T, double>)
218  dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
219  if (info != 0)
220  throw std::runtime_error("Call to dgesv failed: " + std::to_string(info));
221 
222  // Copy result to row-major storage
223  std::vector<T> rb(_B.extent(0) * _B.extent(1));
224  MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
225  T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
226  r(rb.data(), _B.extents());
227  for (std::size_t i = 0; i < _B.extent(0); ++i)
228  for (std::size_t j = 0; j < _B.extent(1); ++j)
229  r(i, j) = _B(i, j);
230 
231  return rb;
232 }
233 
237 template <std::floating_point T>
239  MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
240  const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
241  A)
242 {
243  // Copy to column major matrix
244  namespace stdex
245  = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
246  stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
247  MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
248  _A(A.extents());
249  for (std::size_t i = 0; i < A.extent(0); ++i)
250  for (std::size_t j = 0; j < A.extent(1); ++j)
251  _A(i, j) = A(i, j);
252 
253  std::vector<T> B(A.extent(1), 1);
254  int N = _A.extent(0);
255  int nrhs = 1;
256  int lda = _A.extent(0);
257  int ldb = B.size();
258 
259  // Pivot indices that define the permutation matrix for the LU solver
260  std::vector<int> piv(N);
261  int info;
262  if constexpr (std::is_same_v<T, float>)
263  sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
264  else if constexpr (std::is_same_v<T, double>)
265  dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
266 
267  if (info < 0)
268  {
269  throw std::runtime_error("dgesv failed due to invalid value: "
270  + std::to_string(info));
271  }
272  else if (info > 0)
273  return true;
274  else
275  return false;
276 }
277 
283 template <std::floating_point T>
284 std::vector<std::size_t>
285 transpose_lu(std::pair<std::vector<T>, std::array<std::size_t, 2>>& A)
286 {
287  std::size_t dim = A.second[0];
288  assert(dim == A.second[1]);
289  int N = dim;
290  int info;
291  std::vector<int> lu_perm(dim);
292 
293  // Comput LU decomposition of M
294  if constexpr (std::is_same_v<T, float>)
295  sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
296  else if constexpr (std::is_same_v<T, double>)
297  dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
298 
299  if (info != 0)
300  {
301  throw std::runtime_error("LU decomposition failed: "
302  + std::to_string(info));
303  }
304 
305  std::vector<std::size_t> perm(dim);
306  for (std::size_t i = 0; i < dim; ++i)
307  perm[i] = static_cast<std::size_t>(lu_perm[i] - 1);
308 
309  return perm;
310 }
311 
317 template <typename U, typename V, typename W>
318 void dot(const U& A, const V& B, W&& C)
319 {
320  assert(A.extent(1) == B.extent(0));
321  assert(C.extent(0) == A.extent(0));
322  assert(C.extent(1) == B.extent(1));
323  if (A.extent(0) * B.extent(1) * A.extent(1) < 512)
324  {
325  std::fill_n(C.data_handle(), C.extent(0) * C.extent(1), 0);
326  for (std::size_t i = 0; i < A.extent(0); ++i)
327  for (std::size_t j = 0; j < B.extent(1); ++j)
328  for (std::size_t k = 0; k < A.extent(1); ++k)
329  C(i, j) += A(i, k) * B(k, j);
330  }
331  else
332  {
333  using T = typename std::decay_t<U>::value_type;
334  impl::dot_blas<T>(
335  std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
336  std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
337  std::span(C.data_handle(), C.size()));
338  }
339 }
340 
344 template <std::floating_point T>
345 std::vector<T> eye(std::size_t n)
346 {
347  std::vector<T> I(n * n, 0);
348  namespace stdex
349  = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
350  MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
351  T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
352  Iview(I.data(), n, n);
353  for (std::size_t i = 0; i < n; ++i)
354  Iview(i, i) = 1;
355  return I;
356 }
357 
362 template <std::floating_point T>
364  MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
365  T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
366  wcoeffs,
367  std::size_t start = 0)
368 {
369  for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
370  {
371  T norm = 0;
372  for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
373  norm += wcoeffs(i, k) * wcoeffs(i, k);
374 
375  norm = std::sqrt(norm);
376  if (norm < 2 * std::numeric_limits<T>::epsilon())
377  {
378  throw std::runtime_error(
379  "Cannot orthogonalise the rows of a matrix with incomplete row rank");
380  }
381 
382  for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
383  wcoeffs(i, k) /= norm;
384 
385  for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
386  {
387  T a = 0;
388  for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
389  a += wcoeffs(i, k) * wcoeffs(j, k);
390  for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
391  wcoeffs(j, k) -= a * wcoeffs(i, k);
392  }
393  }
394 }
395 } // namespace basix::math
Mathematical functions.
Definition: math.h:50
void dot(const U &A, const V &B, W &&C)
Compute C = A * B.
Definition: math.h:318
std::pair< std::vector< T >, std::vector< T > > eigh(std::span< const T > A, std::size_t n)
Definition: math.h:126
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition: math.h:97
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Definition: math.h:111
std::vector< T > solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 >> A, MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 >> B)
Solve A X = B.
Definition: math.h:187
bool is_singular(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 >> A)
Check if A is a singular matrix.
Definition: math.h:238
void orthogonalise(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 >> wcoeffs, std::size_t start=0)
Orthogonalise the rows of a matrix (in place).
Definition: math.h:363
std::vector< T > eye(std::size_t n)
Build an identity matrix.
Definition: math.h:345
std::vector< std::size_t > transpose_lu(std::pair< std::vector< T >, std::array< std::size_t, 2 >> &A)
Compute the LU decomposition of the transpose of a square matrix A.
Definition: math.h:285