27 void dot_blas(
const std::span<const double>& A,
28 std::array<std::size_t, 2> Ashape,
29 const std::span<const double>& B,
30 std::array<std::size_t, 2> Bshape,
const std::span<double>& C);
/// Compute the outer product of vectors u and v.
///
/// @param[in] u First vector; needs `size()`, `operator[]` and a nested
/// `value_type`
/// @param[in] v Second vector; needs `size()` and `operator[]`
/// @return A pair (data, shape): `data` holds the `u.size() * v.size()`
/// products stored row by row, so entry (i, j) is at
/// `data[i * v.size() + j]` and equals `u[i] * v[j]`; `shape` is
/// `{u.size(), v.size()}`
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  std::vector<typename U::value_type> result(u.size() * v.size());
  for (std::size_t i = 0; i < u.size(); ++i)
  {
    for (std::size_t j = 0; j < v.size(); ++j)
      result[i * v.size() + j] = u[i] * v[j];
  }

  return {std::move(result), {u.size(), v.size()}};
}
/// Compute the cross product u x v of two 3-component vectors.
///
/// @param[in] u First vector; must satisfy `u.size() == 3`
/// @param[in] v Second vector; must satisfy `v.size() == 3`
/// @return The three components of u x v, stored as `U::value_type`
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  // The size requirements are only enforced where assert is active.
  assert(u.size() == 3);
  assert(v.size() == 3);
  return {u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
          u[0] * v[1] - u[1] * v[0]};
}
68 std::pair<std::vector<double>, std::vector<double>>
69 eigh(
const std::span<const double>& A, std::size_t n);
76 solve(
const std::experimental::mdspan<
77 const double, std::experimental::dextents<std::size_t, 2>>& A,
78 const std::experimental::mdspan<
79 const double, std::experimental::dextents<std::size_t, 2>>& B);
85 const double, std::experimental::dextents<std::size_t, 2>>& A);
/// LU-related routine operating on a (data, shape) pair. Only the
/// declaration is visible here; the definition lives out of line.
///
/// @param[in,out] A Pair of flattened `double` storage and a two-entry
/// shape, taken by non-const reference so the callee may modify it
/// @return A vector of indices
/// @note NOTE(review): by its name this presumably factorises the
/// transpose of `A` in place and returns the row permutation -- confirm
/// against the definition.
std::vector<std::size_t>
transpose_lu(std::pair<std::vector<double>, std::array<std::size_t, 2>>& A);
99 template <
typename U,
typename V,
typename W>
100 void dot(
const U& A,
const V& B, W&& C)
102 assert(A.extent(1) == B.extent(0));
103 assert(C.extent(0) == C.extent(0));
104 assert(C.extent(1) == B.extent(1));
105 if (A.extent(0) * B.extent(1) * A.extent(1) < 4096)
107 std::fill_n(C.data_handle(), C.extent(0) * C.extent(1), 0);
108 for (std::size_t i = 0; i < A.extent(0); ++i)
109 for (std::size_t j = 0; j < B.extent(1); ++j)
110 for (std::size_t k = 0; k < A.extent(1); ++k)
111 C(i, j) += A(i, k) * B(k, j);
116 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
117 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
118 std::span(C.data_handle(), C.size()));
/// Matrix factory taking a single size argument. Only the declaration is
/// visible here; the definition lives out of line.
///
/// @param[in] n Size parameter
/// @return A vector of `double`
/// @note NOTE(review): by its name this presumably returns the n x n
/// identity matrix in flattened storage -- confirm against the
/// definition.
std::vector<double> eye(std::size_t n);
std::pair< std::vector< double >, std::vector< double > > eigh(const std::span< const double > &A, std::size_t n)
Definition: math.cpp:55
void dot(const U &A, const V &B, W &&C)
Definition: math.h:100
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition: math.h:39
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Definition: math.h:54
std::vector< double > solve(const std::experimental::mdspan< const double, std::experimental::dextents< std::size_t, 2 >> &A, const std::experimental::mdspan< const double, std::experimental::dextents< std::size_t, 2 >> &B)
Definition: math.cpp:92
bool is_singular(const std::experimental::mdspan< const double, std::experimental::dextents< std::size_t, 2 >> &A)
Definition: math.cpp:130
std::vector< std::size_t > transpose_lu(std::pair< std::vector< double >, std::array< std::size_t, 2 >> &A)
Definition: math.cpp:162
std::vector< double > eye(std::size_t n)
Definition: math.cpp:186