/// LAPACK `ssyevd` (single precision), declared here so that eigh() can call
/// it for `float` data. Every argument is passed by pointer.
///
/// eigh() calls this routine twice with the same argument list: a first call
/// with one-element `work` / `iwork` buffers, after which it reads the
/// required buffer lengths from `work[0]` and `iwork[0]`, and a second call
/// with the resized buffers that performs the actual computation.
///
/// The stray listing line numbers that were fused into this declaration have
/// been removed so that it is well-formed again.
void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
             float* work, int* lwork, int* iwork, int* liwork, int* info);
/// LAPACK `dsyevd` (double precision), declared here so that eigh() can call
/// it for `double` data. Every argument is passed by pointer.
///
/// eigh() calls this routine twice with the same argument list: a first call
/// with one-element `work` / `iwork` buffers, after which it reads the
/// required buffer lengths from `work[0]` and `iwork[0]`, and a second call
/// with the resized buffers that performs the actual computation.
///
/// The stray listing line numbers that were fused into this declaration have
/// been removed so that it is well-formed again.
void dsyevd_(char* jobz, char* uplo, int* n, double* a, int* lda, double* w,
             double* work, int* lwork, int* iwork, int* liwork, int* info);
/// External linear-solver routine for `float` data (presumably LAPACK's
/// SGESV — confirm). solve() and is_singular() call it as
/// `sgesv_(&N, &nrhs, A, &lda, piv, B, &ldb, &info)`.
///
/// NOTE(review): this listing cuts the declaration off after `B`. The call
/// sites pass two further arguments (`&ldb`, `&info`), so the trailing
/// parameters and the closing `);` are missing here and must be restored
/// from the real header. The leading "28" is a listing line number, not code.
28 void sgesv_(
int* N,
int* NRHS,
float* A,
int* LDA,
int* IPIV,
float* B,
/// External linear-solver routine for `double` data (presumably LAPACK's
/// DGESV — confirm). solve() and is_singular() call it as
/// `dgesv_(&N, &nrhs, A, &lda, piv, B, &ldb, &info)`.
///
/// NOTE(review): this listing cuts the declaration off after `B`. The call
/// sites pass two further arguments (`&ldb`, `&info`), so the trailing
/// parameters and the closing `);` are missing here and must be restored
/// from the real header. The leading "30" is a listing line number, not code.
30 void dgesv_(
int* N,
int* NRHS,
double* A,
int* LDA,
int* IPIV,
double* B,
/// External matrix-multiply routine for `float` data (presumably BLAS
/// SGEMM — confirm). dot_blas() calls it with thirteen arguments:
/// two transpose flags, three sizes, alpha, two (matrix, leading dimension)
/// pairs, beta, the output matrix and `&ldc`.
///
/// NOTE(review): only twelve parameters are visible in this listing; the
/// final one (matching the `&ldc` argument at the call site) and the closing
/// `);` are missing. The leading "33" / "34" are listing line numbers.
33 void sgemm_(
char* transa,
char* transb,
int* m,
int* n,
int* k,
float* alpha,
34 float* a,
int* lda,
float* b,
int* ldb,
float* beta,
float* c,
/// External matrix-multiply routine for `double` data (presumably BLAS
/// DGEMM — confirm). dot_blas() calls it with thirteen arguments:
/// two transpose flags, three sizes, alpha, two (matrix, leading dimension)
/// pairs, beta, the output matrix and `&ldc`.
///
/// NOTE(review): only twelve parameters are visible in this listing; the
/// final one (matching the `&ldc` argument at the call site) and the closing
/// `);` are missing. The leading "36" / "37" are listing line numbers.
36 void dgemm_(
char* transa,
char* transb,
int* m,
int* n,
int* k,
double* alpha,
37 double* a,
int* lda,
double* b,
int* ldb,
double* beta,
double* c,
/// External LU-factorisation routine for `float` data (presumably LAPACK's
/// SGETRF — confirm). transpose_lu() calls it as
/// `sgetrf_(&N, &N, data, &N, lu_perm.data(), &info)`.
///
/// NOTE(review): the declaration is cut off after `lpiv`; the call site
/// passes a sixth argument (`&info`), so the last parameter and the closing
/// `);` are missing here. The leading "40" is a listing line number.
40 int sgetrf_(
const int* m,
const int* n,
float* a,
const int* lda,
int* lpiv,
/// External LU-factorisation routine for `double` data (presumably LAPACK's
/// DGETRF — confirm). transpose_lu() calls it as
/// `dgetrf_(&N, &N, data, &N, lu_perm.data(), &info)`.
///
/// NOTE(review): the declaration is cut off after `lpiv`; the call site
/// passes a sixth argument (`&info`), so the last parameter and the closing
/// `);` are missing here. The leading "42" is a listing line number.
42 int dgetrf_(
const int* m,
const int* n,
double* a,
const int* lda,
int* lpiv,
/// Matrix-product helper that forwards to sgemm_ (for `float`) or dgemm_
/// (for `double`). A and B arrive as flat spans with separate 2D shapes, the
/// result is written through the span C; `alpha` defaults to 1 and `beta`
/// to 0.
///
/// Any other element type is rejected at compile time by the static_assert.
/// Debug asserts require Ashape[1] == Bshape[0] and
/// C.size() == Ashape[0] * Bshape[1].
///
/// Note that the gemm call receives B before A and N before M.
/// NOTE(review): presumably this is the usual way of feeding row-major data
/// to a column-major routine — confirm against the full header.
/// NOTE(review): the definitions of `trans`, `M`, `N`, `K`, `lda`, `ldb`,
/// `ldc` and all braces are missing from this listing; the numbers at the
/// start of lines are listing line numbers, not code.
59 template <std::
floating_po
int T>
60 void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
61 std::span<const T> B, std::array<std::size_t, 2> Bshape,
62 std::span<T> C, T alpha = 1, T beta = 0)
64 static_assert(std::is_same_v<T, float> or std::is_same_v<T, double>);
66 assert(Ashape[1] == Bshape[0]);
67 assert(C.size() == Ashape[0] * Bshape[1]);
// Single-precision path. The const_cast is needed because the external
// routine is declared with non-const pointers.
77 if constexpr (std::is_same_v<T, float>)
79 sgemm_(&trans, &trans, &N, &M, &K, &alpha,
const_cast<T*
>(B.data()), &ldb,
80 const_cast<T*
>(A.data()), &lda, &beta, C.data(), &ldc);
// Double-precision path, same argument order.
82 else if constexpr (std::is_same_v<T, double>)
84 dgemm_(&trans, &trans, &N, &M, &K, &alpha,
const_cast<T*
>(B.data()), &ldb,
85 const_cast<T*
>(A.data()), &lda, &beta, C.data(), &ldc);
/// @brief Compute the outer product of vectors u and v.
///
/// The signature line and braces, which this listing had dropped, are
/// restored here, and the fused listing line numbers are removed.
///
/// @tparam U Indexable container exposing `value_type`, `size()` and
/// `operator[]`.
/// @tparam V Indexable container exposing `size()` and `operator[]`.
/// @param[in] u Left-hand vector; supplies the rows of the result.
/// @param[in] v Right-hand vector; supplies the columns of the result.
/// @return A pair holding the products stored row by row (entry (i, j) is at
/// index `i * v.size() + j` and equals `u[i] * v[j]`) and the shape
/// `{u.size(), v.size()}`.
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  const std::size_t rows = u.size();
  const std::size_t cols = v.size();

  std::vector<typename U::value_type> result(rows * cols);
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      result[i * cols + j] = u[i] * v[j];

  // Spell out the array type: with a nested braced list the pair can only be
  // built through its (const T1&, const T2&) constructor, which would copy
  // `result` despite the std::move.
  return {std::move(result), std::array<std::size_t, 2>{rows, cols}};
}
/// @brief Compute the cross product u x v.
///
/// The braces, which this listing had dropped, are restored here, and the
/// fused listing line numbers are removed.
///
/// @tparam U Indexable container exposing `value_type`, `size()` and
/// `operator[]`.
/// @tparam V Indexable container exposing `size()` and `operator[]`.
/// @param[in] u First vector; must have exactly three entries (asserted).
/// @param[in] v Second vector; must have exactly three entries (asserted).
/// @return The three components of u x v.
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  assert(u.size() == 3);
  assert(v.size() == 3);
  return {u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
          u[0] * v[1] - u[1] * v[0]};
}
/// Eigen-decomposition helper built on ssyevd_ (for `float`) and dsyevd_
/// (for `double`).
///
/// The input span A is copied into `M`, which is the buffer handed to the
/// external routine, and `w` is created with `n` zero entries. The routine is
/// invoked twice with the same argument list: first with one-element
/// `work` / `iwork` buffers, then again after both buffers have been resized
/// to the values found in `work[0]` and `iwork[0]`.
///
/// @return The pair `{w, M}` (both moved out).
/// @throws std::runtime_error "Could not find workspace size for syevd."
/// after the first call, or "Eigenvalue computation did not converge." after
/// the second.
///
/// NOTE(review): this listing is missing the rest of the parameter list (the
/// index at the end of the page gives
/// `eigh(std::span<const T> A, std::size_t n)`), the definitions of `N`,
/// `jobz`, `uplo`, `ldA`, `lwork`, `liwork` and `info`, the conditions that
/// guard both throws, and all braces. Leading numbers are listing line
/// numbers, not code.
126 template <std::
floating_po
int T>
127 std::pair<std::vector<T>, std::vector<T>>
eigh(std::span<const T> A,
// Working copy of the input; this is what the external routine receives.
131 std::vector<T> M(A.begin(), A.end());
134 std::vector<T> w(n, 0);
// One-element buffers for the first (sizing) call.
143 std::vector<T> work(1);
144 std::vector<int> iwork(1);
147 if constexpr (std::is_same_v<T, float>)
149 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
150 iwork.data(), &liwork, &info);
152 else if constexpr (std::is_same_v<T, double>)
154 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
155 iwork.data(), &liwork, &info);
159 throw std::runtime_error(
"Could not find workspace size for syevd.");
// Grow both buffers to the sizes read back from their first elements.
162 work.resize(work[0]);
163 iwork.resize(iwork[0]);
165 liwork = iwork.size();
// Second call with the resized buffers.
166 if constexpr (std::is_same_v<T, float>)
168 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
169 iwork.data(), &liwork, &info);
171 else if constexpr (std::is_same_v<T, double>)
173 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
174 iwork.data(), &liwork, &info);
177 throw std::runtime_error(
"Eigenvalue computation did not converge.");
179 return {std::move(w), std::move(M)};
/// Solve A X = B using sgesv_ (for `float`) or dgesv_ (for `double`).
///
/// `_A` is a `layout_left` mdarray working copy; the sizes and leading
/// dimensions handed to the solver are taken from the extents of `_A` and
/// `_B`. After the solver call the result buffer `rb` is allocated with
/// `_B.extent(0) * _B.extent(1)` entries and viewed through `r` with the
/// same extents as `_B`.
///
/// @throws std::runtime_error "Call to dgesv failed: <info>". The same text
/// is used on the `float` (sgesv_) path.
///
/// NOTE(review): the bodies of all three copy loops, the declaration of
/// `_B` and of `info`, the condition guarding the throw, the return
/// statement and all braces are missing from this listing. Presumably the
/// loops copy A -> _A, B -> _B and _B -> r, and `rb` is returned — confirm
/// against the real header. Leading numbers are listing line numbers.
186 template <std::
floating_po
int T>
187 std::vector<T>
solve(md::mdspan<
const T, md::dextents<std::size_t, 2>> A,
188 md::mdspan<
const T, md::dextents<std::size_t, 2>> B)
191 mdex::mdarray<T, md::dextents<std::size_t, 2>, md::layout_left> _A(
194 for (std::size_t i = 0; i < A.extent(0); ++i)
195 for (std::size_t j = 0; j < A.extent(1); ++j)
197 for (std::size_t i = 0; i < B.extent(0); ++i)
198 for (std::size_t j = 0; j < B.extent(1); ++j)
// Solver arguments derived from the working copies.
201 int N = _A.extent(0);
202 int nrhs = _B.extent(1);
203 int lda = _A.extent(0);
204 int ldb = _B.extent(0);
// Integer pivot buffer with one entry per row, passed to the solver.
207 std::vector<int> piv(N);
209 if constexpr (std::is_same_v<T, float>)
210 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
211 else if constexpr (std::is_same_v<T, double>)
212 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
214 throw std::runtime_error(
"Call to dgesv failed: " + std::to_string(info));
// Result storage, viewed with the same extents as _B.
217 std::vector<T> rb(_B.extent(0) * _B.extent(1));
218 md::mdspan<T, md::dextents<std::size_t, 2>> r(rb.data(), _B.extents());
219 for (std::size_t i = 0; i < _B.extent(0); ++i)
220 for (std::size_t j = 0; j < _B.extent(1); ++j)
/// Check if A is a singular matrix, by running sgesv_ (for `float`) or
/// dgesv_ (for `double`) on a `layout_left` working copy `_A` with a
/// right-hand side `B` of `A.extent(1)` ones.
///
/// @throws std::runtime_error "dgesv failed due to invalid value: <info>".
/// The same text is used on the `float` (sgesv_) path.
///
/// NOTE(review): the copy-loop body, the definitions of `nrhs`, `ldb` and
/// `info`, the condition guarding the throw, the statements that turn the
/// solver status into the boolean result, and all braces are missing from
/// this listing. Leading numbers are listing line numbers, not code.
229 template <std::
floating_po
int T>
230 bool is_singular(md::mdspan<
const T, md::dextents<std::size_t, 2>> A)
233 mdex::mdarray<T, md::dextents<std::size_t, 2>, md::layout_left> _A(
235 for (std::size_t i = 0; i < A.extent(0); ++i)
236 for (std::size_t j = 0; j < A.extent(1); ++j)
// Right-hand side filled with ones.
239 std::vector<T> B(A.extent(1), 1);
240 int N = _A.extent(0);
242 int lda = _A.extent(0);
246 std::vector<int> piv(N);
248 if constexpr (std::is_same_v<T, float>)
249 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
250 else if constexpr (std::is_same_v<T, double>)
251 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
255 throw std::runtime_error(
"dgesv failed due to invalid value: "
256 + std::to_string(info));
/// LU-factorisation helper built on sgetrf_ (for `float`) and dgetrf_ (for
/// `double`). The index at the end of this page names it
/// `transpose_lu(std::pair<std::vector<T>, std::array<std::size_t, 2>>& A)`
/// and describes it as computing the LU decomposition of the transpose of a
/// square matrix A.
///
/// `A.first` holds the matrix data and `A.second` its shape; the matrix must
/// be square (asserted). `A.first.data()` is handed to the external routine
/// as a non-const pointer, with `N` used for the row count, column count and
/// leading dimension.
///
/// Each pivot index written to `lu_perm` is reduced by one and converted to
/// `std::size_t` when it is stored in `perm`.
/// NOTE(review): presumably this converts 1-based indices to 0-based —
/// confirm against the routine's documentation.
///
/// @throws std::runtime_error "LU decomposition failed: <info>".
///
/// NOTE(review): the function-name line, the definitions of `N` and `info`,
/// the condition guarding the throw, the return statement and all braces are
/// missing from this listing. Leading numbers are listing line numbers.
269 template <std::
floating_po
int T>
270 std::vector<std::size_t>
273 std::size_t dim = A.second[0];
274 assert(dim == A.second[1]);
277 std::vector<int> lu_perm(dim);
280 if constexpr (std::is_same_v<T, float>)
281 sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
282 else if constexpr (std::is_same_v<T, double>)
283 dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
287 throw std::runtime_error(
"LU decomposition failed: "
288 + std::to_string(info));
291 std::vector<std::size_t> perm(dim);
292 for (std::size_t i = 0; i < dim; ++i)
293 perm[i] =
static_cast<std::size_t
>(lu_perm[i] - 1);
/// Compute C = alpha A * B + beta C for two-dimensional views A, B and C.
/// `alpha` defaults to 1 and `beta` to 0; the scalar type `T` is the value
/// type of A.
///
/// Debug asserts require A.extent(1) == B.extent(0),
/// C.extent(0) == A.extent(0) and C.extent(1) == B.extent(1).
///
/// Two paths are visible:
///  - when A.extent(0) * B.extent(1) * A.extent(1) < 256, a hand-written
///    triple loop accumulates A(i, k) * B(k, j) into `_C` and then stores
///    `alpha * _C + beta * C0`;
///  - otherwise compile-time checks on the layout and value types of A, B
///    and C are followed by a call that receives flat spans over the three
///    views, the shapes of A and B, and alpha / beta.
///
/// NOTE(review): the declarations of `_C` and `C0`, the second operand of
/// the three layout static_asserts, the name of the function called on the
/// large-matrix path (presumably dot_blas — confirm) and all braces are
/// missing from this listing. Leading numbers are listing line numbers.
305 template <
typename U,
typename V,
typename W>
306 void dot(
const U& A,
const V& B, W&& C,
307 typename std::decay_t<U>::value_type alpha = 1,
308 typename std::decay_t<U>::value_type beta = 0)
310 using T =
typename std::decay_t<U>::value_type;
312 assert(A.extent(1) == B.extent(0));
313 assert(C.extent(0) == A.extent(0));
314 assert(C.extent(1) == B.extent(1));
// Small products are handled by the explicit loops below.
315 if (A.extent(0) * B.extent(1) * A.extent(1) < 256)
317 for (std::size_t i = 0; i < A.extent(0); ++i)
319 for (std::size_t j = 0; j < B.extent(1); ++j)
324 for (std::size_t k = 0; k < A.extent(1); ++k)
325 _C += A(i, k) * B(k, j);
326 _C = alpha * _C + beta * C0;
// Larger products: validate layouts and value types, then forward spans.
332 static_assert(std::is_same_v<
typename std::decay_t<U>::layout_type,
334 static_assert(std::is_same_v<
typename std::decay_t<V>::layout_type,
336 static_assert(std::is_same_v<
typename std::decay_t<W>::layout_type,
338 static_assert(std::is_same_v<
typename std::decay_t<V>::value_type, T>);
339 static_assert(std::is_same_v<
typename std::decay_t<W>::value_type, T>);
341 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
342 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
343 std::span(C.data_handle(), C.size()), alpha, beta);
/// Build an identity matrix.
///
/// Storage `I` holds `n * n` zeros and is viewed as an n-by-n matrix through
/// `Iview`; the loop visits each index `i < n`.
///
/// @param[in] n Number of rows and columns.
/// @return std::vector<T> (return statement not visible here).
///
/// NOTE(review): the loop body (presumably setting `Iview(i, i)` to one),
/// the return statement and the braces are missing from this listing.
/// Leading numbers are listing line numbers, not code.
350 template <std::
floating_po
int T>
351 std::vector<T>
eye(std::size_t n)
353 std::vector<T> I(n * n, 0);
354 md::mdspan<T, md::dextents<std::size_t, 2>> Iview(I.data(), n, n);
355 for (std::size_t i = 0; i < n; ++i)
/// Orthogonalise the rows of a matrix (in place), starting at row `start`
/// (default 0). The index at the end of this page gives the signature as
/// `orthogonalise(md::mdspan<T, md::dextents<std::size_t, 2>> wcoeffs,
/// std::size_t start = 0)`.
///
/// For each row i >= start the visible code:
///  1. sums the squares of the row's entries and takes the square root;
///  2. throws if that norm is below 2 * epsilon of T;
///  3. divides the row by its norm;
///  4. for every later row j, accumulates the product of rows i and j
///     entry by entry into `a` and subtracts `a` times row i from row j.
///
/// @throws std::runtime_error "Cannot orthogonalise the rows of a matrix
/// with incomplete row rank".
///
/// NOTE(review): the function-name line, the declarations of `norm` and `a`
/// (presumably zero-initialised per row — confirm) and all braces are
/// missing from this listing. Leading numbers are listing line numbers.
364 template <std::
floating_po
int T>
366 std::size_t start = 0)
368 for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
// Norm of row i.
371 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
372 norm += wcoeffs(i, k) * wcoeffs(i, k);
374 norm = std::sqrt(norm);
375 if (norm < 2 * std::numeric_limits<T>::epsilon())
377 throw std::runtime_error(
"Cannot orthogonalise the rows of a matrix "
378 "with incomplete row rank");
// Scale row i by 1 / norm.
381 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
382 wcoeffs(i, k) /= norm;
// Subtract from every later row its component along row i.
384 for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
387 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
388 a += wcoeffs(i, k) * wcoeffs(j, k);
389 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
390 wcoeffs(j, k) -= a * wcoeffs(i, k);
Mathematical functions.
Definition: math.h:51
std::pair< std::vector< T >, std::vector< T > > eigh(std::span< const T > A, std::size_t n)
Compute the eigenvalues and eigenvectors of a square Hermitian matrix A.
Definition: math.h:127
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition: math.h:97
std::vector< T > solve(md::mdspan< const T, md::dextents< std::size_t, 2 >> A, md::mdspan< const T, md::dextents< std::size_t, 2 >> B)
Solve A X = B.
Definition: math.h:187
void orthogonalise(md::mdspan< T, md::dextents< std::size_t, 2 >> wcoeffs, std::size_t start=0)
Orthogonalise the rows of a matrix (in place).
Definition: math.h:365
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Compute the cross product u x v.
Definition: math.h:111
void dot(const U &A, const V &B, W &&C, typename std::decay_t< U >::value_type alpha=1, typename std::decay_t< U >::value_type beta=0)
Compute C = alpha A * B + beta C.
Definition: math.h:306
std::vector< T > eye(std::size_t n)
Build an identity matrix.
Definition: math.h:351
std::vector< std::size_t > transpose_lu(std::pair< std::vector< T >, std::array< std::size_t, 2 >> &A)
Compute the LU decomposition of the transpose of a square matrix A.
Definition: math.h:271
bool is_singular(md::mdspan< const T, md::dextents< std::size_t, 2 >> A)
Check if A is a singular matrix.
Definition: math.h:230