/// LAPACK SSYEVD: all eigenvalues and, optionally, eigenvectors of a real
/// symmetric single-precision matrix (divide-and-conquer algorithm).
/// Fortran routine: every argument is passed by pointer. A call with
/// *lwork == -1 (or *liwork == -1) is a workspace-size query that stores the
/// required sizes in work[0] / iwork[0]; info is 0 on success.
/// NOTE(review): needs C linkage to bind to the Fortran symbol; presumably
/// declared inside an extern "C" block that is not visible in this listing.
void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
             float* work, int* lwork, int* iwork, int* liwork, int* info);
/// LAPACK DSYEVD: all eigenvalues and, optionally, eigenvectors of a real
/// symmetric double-precision matrix (divide-and-conquer algorithm).
/// Fortran routine: every argument is passed by pointer. A call with
/// *lwork == -1 (or *liwork == -1) is a workspace-size query that stores the
/// required sizes in work[0] / iwork[0]; info is 0 on success.
/// NOTE(review): needs C linkage to bind to the Fortran symbol; presumably
/// declared inside an extern "C" block that is not visible in this listing.
void dsyevd_(char* jobz, char* uplo, int* n, double* a, int* lda, double* w,
             double* work, int* lwork, int* iwork, int* liwork, int* info);
/// LAPACK SGESV: solve A * X = B for a general single-precision matrix A by
/// LU factorisation with partial pivoting. Fortran routine: every argument is
/// passed by pointer and matrices are column-major. On exit A holds the LU
/// factors, B the solution X, IPIV the 1-based pivot indices, and INFO is 0
/// on success (< 0: invalid argument, > 0: exactly singular U).
/// NOTE(review): needs C linkage to bind to the Fortran symbol; presumably
/// declared inside an extern "C" block that is not visible in this listing.
void sgesv_(int* N, int* NRHS, float* A, int* LDA, int* IPIV, float* B,
            int* LDB, int* INFO);
/// LAPACK DGESV: solve A * X = B for a general double-precision matrix A by
/// LU factorisation with partial pivoting. Fortran routine: every argument is
/// passed by pointer and matrices are column-major. On exit A holds the LU
/// factors, B the solution X, IPIV the 1-based pivot indices, and INFO is 0
/// on success (< 0: invalid argument, > 0: exactly singular U).
/// NOTE(review): needs C linkage to bind to the Fortran symbol; presumably
/// declared inside an extern "C" block that is not visible in this listing.
void dgesv_(int* N, int* NRHS, double* A, int* LDA, int* IPIV, double* B,
            int* LDB, int* INFO);
/// BLAS SGEMM: C = alpha * op(A) * op(B) + beta * C in single precision,
/// where op(X) is X or X^T as selected by transa / transb. Fortran routine:
/// every argument is passed by pointer, matrices are column-major with
/// leading dimensions lda, ldb and ldc, and A and B are only read.
/// NOTE(review): needs C linkage to bind to the Fortran symbol; presumably
/// declared inside an extern "C" block that is not visible in this listing.
void sgemm_(char* transa, char* transb, int* m, int* n, int* k, float* alpha,
            float* a, int* lda, float* b, int* ldb, float* beta, float* c,
            int* ldc);
/// BLAS DGEMM: C = alpha * op(A) * op(B) + beta * C in double precision,
/// where op(X) is X or X^T as selected by transa / transb. Fortran routine:
/// every argument is passed by pointer, matrices are column-major with
/// leading dimensions lda, ldb and ldc, and A and B are only read.
/// NOTE(review): needs C linkage to bind to the Fortran symbol; presumably
/// declared inside an extern "C" block that is not visible in this listing.
void dgemm_(char* transa, char* transb, int* m, int* n, int* k,
            double* alpha, double* a, int* lda, double* b, int* ldb,
            double* beta, double* c, int* ldc);
/// LAPACK SGETRF: in-place LU factorisation with partial pivoting of a
/// column-major m x n single-precision matrix. lpiv receives LAPACK's
/// 1-based IPIV pivot indices; info is 0 on success (< 0: invalid argument,
/// > 0: U has an exactly zero diagonal entry).
/// NOTE(review): SGETRF is a Fortran SUBROUTINE with no result, so the `int`
/// return type does not match it and the returned value must never be read.
/// It is kept unchanged for interface compatibility.
/// NOTE(review): needs C linkage to bind to the Fortran symbol; presumably
/// declared inside an extern "C" block that is not visible in this listing.
int sgetrf_(const int* m, const int* n, float* a, const int* lda, int* lpiv,
            int* info);
/// LAPACK DGETRF: in-place LU factorisation with partial pivoting of a
/// column-major m x n double-precision matrix. lpiv receives LAPACK's
/// 1-based IPIV pivot indices; info is 0 on success (< 0: invalid argument,
/// > 0: U has an exactly zero diagonal entry).
/// NOTE(review): DGETRF is a Fortran SUBROUTINE with no result, so the `int`
/// return type does not match it and the returned value must never be read.
/// It is kept unchanged for interface compatibility.
/// NOTE(review): needs C linkage to bind to the Fortran symbol; presumably
/// declared inside an extern "C" block that is not visible in this listing.
int dgetrf_(const int* m, const int* n, double* a, const int* lda, int* lpiv,
            int* info);
// NOTE(review): this listing is a lossy Doxygen source extraction: line
// numbers are fused into the code and several lines of this definition
// (braces and the M, N, K, lda, ldb, ldc and trans locals) are not shown.
/// @brief Compute C = alpha * A * B + beta * C with BLAS gemm (sgemm_ for
/// float, dgemm_ for double).
/// @param[in] A Entries of the left factor, shape Ashape.
/// @param[in] Ashape Shape {rows, cols} of A.
/// @param[in] B Entries of the right factor, shape Bshape.
/// @param[in] Bshape Shape {rows, cols} of B; Bshape[0] must equal Ashape[1].
/// @param[in,out] C Output buffer of Ashape[0] * Bshape[1] entries.
/// @param[in] alpha Scale applied to A * B (default 1).
/// @param[in] beta Scale applied to the incoming C (default 0).
58 template <std::
floating_po
int T>
59 void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
60 std::span<const T> B, std::array<std::size_t, 2> Bshape,
61 std::span<T> C, T alpha = 1, T beta = 0)
// Only the two precisions that have sgemm_/dgemm_ bindings are supported.
63 static_assert(std::is_same_v<T, float> or std::is_same_v<T, double>);
65 assert(Ashape[1] == Bshape[0]);
66 assert(C.size() == Ashape[0] * Bshape[1]);
// B is passed before A (and N before M). Under BLAS's column-major convention
// this computes C^T = B^T * A^T, which is C = A * B for row-major buffers.
// This presumably relies on trans == 'N' and on lda/ldb/ldc being the row
// lengths, all set on lines not shown. The const_casts only satisfy the
// non-const Fortran prototypes; gemm reads A and B without modifying them.
76 if constexpr (std::is_same_v<T, float>)
78 sgemm_(&trans, &trans, &N, &M, &K, &alpha,
const_cast<T*
>(B.data()), &ldb,
79 const_cast<T*
>(A.data()), &lda, &beta, C.data(), &ldc);
81 else if constexpr (std::is_same_v<T, double>)
83 dgemm_(&trans, &trans, &N, &M, &K, &alpha,
const_cast<T*
>(B.data()), &ldb,
84 const_cast<T*
>(A.data()), &lda, &beta, C.data(), &ldc);
/// @brief Compute the outer product of vectors u and v.
///
/// @param u First vector (any indexable container exposing size() and
/// value_type).
/// @param v Second vector (any indexable container exposing size()).
/// @return {data, shape}: data is a row-major buffer of u.size() * v.size()
/// entries with data[i * v.size() + j] = u[i] * v[j], stored with u's
/// value_type; shape is {u.size(), v.size()}.
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  using T = typename U::value_type;
  const std::size_t nrows = u.size();
  const std::size_t ncols = v.size();
  std::vector<T> result(nrows * ncols);
  for (std::size_t i = 0; i < nrows; ++i)
    for (std::size_t j = 0; j < ncols; ++j)
      result[i * ncols + j] = static_cast<T>(u[i] * v[j]);
  return {std::move(result), {nrows, ncols}};
}
/// @brief Compute the cross product u x v.
///
/// @param u First vector; must have size 3.
/// @param v Second vector; must have size 3. Its value_type may differ
/// from u's.
/// @return u x v, stored with u's value_type.
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  using T = typename U::value_type;
  assert(u.size() == 3);
  assert(v.size() == 3);
  // Convert each component explicitly. With mixed or promoted value types
  // (e.g. float u and double v, or short inputs promoted to int) the products
  // are wider than T. Brace-initialising the array from a wider type would be
  // an ill-formed narrowing conversion.
  return {static_cast<T>(u[1] * v[2] - u[2] * v[1]),
          static_cast<T>(u[2] * v[0] - u[0] * v[2]),
          static_cast<T>(u[0] * v[1] - u[1] * v[0])};
}
// NOTE(review): lossy Doxygen extraction. The following are not shown here:
// the rest of the parameter list (`std::size_t n)` per the generated docs),
// the braces, the jobz, uplo, N, ldA, lwork, liwork and info locals, and the
// `info` checks.
/// @brief Compute the eigenvalues and eigenvectors of a real symmetric n x n
/// matrix with LAPACK syevd (ssyevd_ for float, dsyevd_ for double).
/// @param[in] A The n * n matrix entries; copied into M and never modified.
/// @return {w, M}: w holds the n eigenvalues; M is the copy of A that syevd
/// overwrote. Assuming jobz == 'V' and ldA == n (set on lines not shown), M
/// holds the orthonormal eigenvectors, with eigenvector k occupying entries
/// [k * n, (k + 1) * n) because LAPACK stores them as column-major columns.
/// @throws std::runtime_error if the workspace query fails or the
/// eigenvalue computation does not converge.
125 template <std::
floating_po
int T>
126 std::pair<std::vector<T>, std::vector<T>>
eigh(std::span<const T> A,
130 std::vector<T> M(A.begin(), A.end());
133 std::vector<T> w(n, 0);
// First call: a workspace-size query, presumably with lwork == liwork == -1
// set on a line not shown. LAPACK reports the required sizes in work[0] and
// iwork[0].
142 std::vector<T> work(1);
143 std::vector<int> iwork(1);
146 if constexpr (std::is_same_v<T, float>)
148 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
149 iwork.data(), &liwork, &info);
151 else if constexpr (std::is_same_v<T, double>)
153 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
154 iwork.data(), &liwork, &info);
158 throw std::runtime_error(
"Could not find workspace size for syevd.");
// Size the workspaces from the query result, then make the real call.
// NOTE(review): work[0] is a T. In float, a large workspace size may be
// rounded below the true requirement; consider rounding it up.
161 work.resize(work[0]);
162 iwork.resize(iwork[0]);
164 liwork = iwork.size();
165 if constexpr (std::is_same_v<T, float>)
167 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
168 iwork.data(), &liwork, &info);
170 else if constexpr (std::is_same_v<T, double>)
172 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
173 iwork.data(), &liwork, &info);
176 throw std::runtime_error(
"Eigenvalue computation did not converge.");
178 return {std::move(w), std::move(M)};
// NOTE(review): lossy Doxygen extraction. The following are not shown here:
// the return type, the parameter names, the namespace-alias head, the bodies
// of the copy loops, `info` and its check, and the final copy and return.
/// @brief Solve A X = B with LAPACK gesv (LU with partial pivoting).
/// @param[in] A Square coefficient matrix.
/// @param[in] B Right-hand sides, one per column.
/// @return X with the extents of B (a std::vector<T>, per the generated
/// docs), copied out through a default-layout (row-major) mdspan.
/// @throws std::runtime_error carrying gesv's info code when the check on
/// info (not shown) fails.
185 template <std::
floating_po
int T>
187 solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
188 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
190 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
191 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
195 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
// gesv works on column-major storage and overwrites both operands (A with
// its LU factors, B with the solution). A and B are therefore copied into
// layout_left (column-major) arrays.
// NOTE(review): _A / _B (underscore + uppercase letter) are identifiers
// reserved for the implementation in C++; consider renaming them.
198 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
199 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
200 _A(A.extents()), _B(B.extents());
201 for (std::size_t i = 0; i < A.extent(0); ++i)
202 for (std::size_t j = 0; j < A.extent(1); ++j)
204 for (std::size_t i = 0; i < B.extent(0); ++i)
205 for (std::size_t j = 0; j < B.extent(1); ++j)
// LAPACK takes 32-bit int sizes; the extents are narrowed here.
208 int N = _A.extent(0);
209 int nrhs = _B.extent(1);
210 int lda = _A.extent(0);
211 int ldb = _B.extent(0);
214 std::vector<int> piv(N);
216 if constexpr (std::is_same_v<T, float>)
217 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
218 else if constexpr (std::is_same_v<T, double>)
219 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
// NOTE(review): this message names dgesv even on the float (sgesv_) path.
221 throw std::runtime_error(
"Call to dgesv failed: " + std::to_string(info));
// Copy the column-major solution into rb, viewed through the default-layout
// (layout_right, i.e. row-major) mdspan r.
224 std::vector<T> rb(_B.extent(0) * _B.extent(1));
225 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
226 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
227 r(rb.data(), _B.extents());
228 for (std::size_t i = 0; i < _B.extent(0); ++i)
229 for (std::size_t j = 0; j < _B.extent(1); ++j)
// NOTE(review): lossy Doxygen extraction. The following are not shown here:
// the return type and name (`bool is_singular(` per the generated docs), the
// parameter name, the mdarray's name and extents, the copy-loop body,
// nrhs/ldb/info, the info checks and the returns.
/// @brief Check if A is a singular matrix by attempting to solve A x = b,
/// with b all ones, using LAPACK gesv.
/// @param[in] A Square matrix to test.
/// @return true if A is singular (per the generated docs; the return logic
/// is not visible here).
/// NOTE(review): gesv reports singularity (info > 0) only when a pivot of U
/// is exactly zero. Nearly singular, ill-conditioned matrices will therefore
/// be treated as nonsingular.
238 template <std::
floating_po
int T>
240 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
241 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
246 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
// Column-major working copy, because gesv expects column-major storage and
// overwrites A. NOTE(review): the name _A used below (underscore + uppercase
// letter) is reserved for the implementation in C++.
247 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
248 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
250 for (std::size_t i = 0; i < A.extent(0); ++i)
251 for (std::size_t j = 0; j < A.extent(1); ++j)
// Dummy right-hand side of ones; only gesv's info result matters here.
254 std::vector<T> B(A.extent(1), 1);
255 int N = _A.extent(0);
257 int lda = _A.extent(0);
261 std::vector<int> piv(N);
263 if constexpr (std::is_same_v<T, float>)
264 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
265 else if constexpr (std::is_same_v<T, double>)
266 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
// Presumably reached for info < 0 (an invalid argument), matching the
// message; the condition is not shown. NOTE(review): this message names
// dgesv even on the float (sgesv_) path.
270 throw std::runtime_error(
"dgesv failed due to invalid value: "
271 + std::to_string(info));
// NOTE(review): lossy Doxygen extraction. The following are not shown here:
// the name/parameter line (`transpose_lu(std::pair<...>& A)` per the
// generated docs), `int N`, `info` and its check.
/// @brief Compute the LU decomposition of the transpose of a square matrix A
/// in place, with LAPACK getrf (sgetrf_ for float, dgetrf_ for double).
/// @param[in,out] A {entries, shape}; shape must be square. The entries are
/// overwritten by getrf's L and U factors. getrf is column-major, so a
/// buffer that is presumably row-major is factorised as its transpose.
/// @return 0-based pivot indices in LAPACK's IPIV convention: for
/// i = 0, 1, ... in order, row i was interchanged with row perm[i]. This is a
/// sequence of swaps, not a permutation vector.
/// @throws std::runtime_error carrying getrf's info code when the check on
/// info (not shown) fails.
284 template <std::
floating_po
int T>
285 std::vector<std::size_t>
288 std::size_t dim = A.second[0];
289 assert(dim == A.second[1]);
292 std::vector<int> lu_perm(dim);
295 if constexpr (std::is_same_v<T, float>)
296 sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
297 else if constexpr (std::is_same_v<T, double>)
298 dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
302 throw std::runtime_error(
"LU decomposition failed: "
303 + std::to_string(info));
// LAPACK pivot indices are 1-based (Fortran); shift them to 0-based.
306 std::vector<std::size_t> perm(dim);
307 for (std::size_t i = 0; i < dim; ++i)
308 perm[i] =
static_cast<std::size_t
>(lu_perm[i] - 1);
// NOTE(review): lossy Doxygen extraction. The following are not shown here:
// the braces, the declarations of _C and C0, the C(i, j) write-back, the
// `else`, and the head of the BLAS call (line 355).
/// @brief Compute C = alpha * A * B + beta * C.
/// @param[in] A Left factor (mdspan-like: extent(), operator(), data_handle()).
/// @param[in] B Right factor; B.extent(0) must equal A.extent(1).
/// @param[in,out] C Result of shape A.extent(0) x B.extent(1).
/// @param[in] alpha Scale applied to A * B (default 1).
/// @param[in] beta Scale applied to the incoming C (default 0).
320 template <
typename U,
typename V,
typename W>
321 void dot(
const U& A,
const V& B, W&& C,
322 typename std::decay_t<U>::value_type alpha = 1,
323 typename std::decay_t<U>::value_type beta = 0)
325 using T =
typename std::decay_t<U>::value_type;
327 assert(A.extent(1) == B.extent(0));
328 assert(C.extent(0) == A.extent(0));
329 assert(C.extent(1) == B.extent(1));
// Small products (fewer than 256 multiply-adds) use a naive triple loop;
// larger ones go through BLAS. The 256 cutoff is a heuristic that is not
// justified in the visible code.
330 if (A.extent(0) * B.extent(1) * A.extent(1) < 256)
332 for (std::size_t i = 0; i < A.extent(0); ++i)
334 for (std::size_t j = 0; j < B.extent(1); ++j)
339 for (std::size_t k = 0; k < A.extent(1); ++k)
340 _C += A(i, k) * B(k, j);
// NOTE(review): suppose C0 is the incoming C(i, j). Then with beta == 0 a
// NaN in C still propagates (NaN * 0 is NaN), whereas BLAS gemm does not
// read C when beta == 0, so the two paths can disagree. Also, _C is an
// identifier reserved for the implementation (underscore + uppercase).
341 _C = alpha * _C + beta * C0;
// Assuming these sit in the plain `else` of the runtime `if` above (lines
// 342-346 are not shown), they are checked on every instantiation of dot().
// That means every instantiation requires row-major (layout_right) mdspans
// with matching value types, even one only ever used on small products.
347 static_assert(std::is_same_v<
typename std::decay_t<U>::layout_type,
348 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_right>);
349 static_assert(std::is_same_v<
typename std::decay_t<V>::layout_type,
350 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_right>);
351 static_assert(std::is_same_v<
typename std::decay_t<W>::layout_type,
352 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_right>);
353 static_assert(std::is_same_v<
typename std::decay_t<V>::value_type, T>);
354 static_assert(std::is_same_v<
typename std::decay_t<W>::value_type, T>);
// Flat views plus shapes for the BLAS helper (call head on line 355).
356 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
357 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
358 std::span(C.data_handle(), C.size()), alpha, beta);
// NOTE(review): lossy Doxygen extraction. The following are not shown here:
// the namespace-alias head (line 369), the loop body (line 375) and the
// return.
/// @brief Build an n x n identity matrix.
/// @param[in] n Matrix dimension.
/// @return The n * n entries, zero-initialised. The diagonal is presumably
/// set to 1 through Iview(i, i) in the loop below, whose body is not shown.
365 template <std::
floating_po
int T>
366 std::vector<T>
eye(std::size_t n)
368 std::vector<T> I(n * n, 0);
// NOTE(review): this alias does not appear to be used in the visible lines.
370 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
// 2D view over the flat buffer, so the diagonal can be addressed as (i, i).
371 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
372 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
373 Iview(I.data(), n, n);
374 for (std::size_t i = 0; i < n; ++i)
// NOTE(review): lossy Doxygen extraction. The following are not shown here:
// the name and parameter line (`orthogonalise(` ... `wcoeffs,` per the
// generated docs), the braces, and the `norm` / `a` accumulator declarations.
/// @brief Orthogonalise the rows of a matrix in place (modified Gram-Schmidt).
///
/// For each row i >= start: normalise row i, then subtract its component
/// from every later row j > i. Rows with index < start are neither modified
/// nor projected out of the later rows.
/// @param[in,out] wcoeffs Matrix whose rows are orthonormalised.
/// @param[in] start Index of the first row to process (default 0).
/// @throws std::runtime_error if a row's remaining norm is below
/// 2 * machine epsilon, i.e. the rows are (numerically) linearly dependent.
383 template <std::
floating_po
int T>
385 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
386 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
388 std::size_t start = 0)
390 for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
393 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
394 norm += wcoeffs(i, k) * wcoeffs(i, k);
396 norm = std::sqrt(norm);
// Absolute threshold: the rows are not rescaled first, so a tiny but
// linearly independent row also triggers this error.
397 if (norm < 2 * std::numeric_limits<T>::epsilon())
399 throw std::runtime_error(
"Cannot orthogonalise the rows of a matrix "
400 "with incomplete row rank");
403 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
404 wcoeffs(i, k) /= norm;
// Remove the (now unit) row i's component from every later row right away.
// Doing this immediately is what makes it modified Gram-Schmidt.
406 for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
409 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
410 a += wcoeffs(i, k) * wcoeffs(j, k);
411 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
412 wcoeffs(j, k) -= a * wcoeffs(i, k);
Mathematical functions.
Definition: math.h:50
std::pair< std::vector< T >, std::vector< T > > eigh(std::span< const T > A, std::size_t n)
Compute the eigenvalues and eigenvectors of a real symmetric (Hermitian) matrix A.
Definition: math.h:126
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition: math.h:96
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Compute the cross product u x v.
Definition: math.h:110
std::vector< T > solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 >> A, MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 >> B)
Solve A X = B.
Definition: math.h:187
bool is_singular(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 >> A)
Check if A is a singular matrix.
Definition: math.h:239
void dot(const U &A, const V &B, W &&C, typename std::decay_t< U >::value_type alpha=1, typename std::decay_t< U >::value_type beta=0)
Compute C = alpha * A * B + beta * C.
Definition: math.h:321
void orthogonalise(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 >> wcoeffs, std::size_t start=0)
Orthogonalise the rows of a matrix (in place).
Definition: math.h:384
std::vector< T > eye(std::size_t n)
Build an identity matrix.
Definition: math.h:366
std::vector< std::size_t > transpose_lu(std::pair< std::vector< T >, std::array< std::size_t, 2 >> &A)
Compute the LU decomposition of the transpose of a square matrix A.
Definition: math.h:286