11#include <basix/mdspan.hpp>
16namespace dolfinx::math
/// @brief Compute the cross product `u x v` of two 3-vectors.
/// @param[in] u The first vector. It must have size 3 and support
/// `operator[]`.
/// @param[in] v The second vector. It must have size 3 and support
/// `operator[]`.
/// @return The cross product `u x v`. The value type is that of `u`;
/// each component is explicitly converted to it, so mixed-precision
/// inputs (e.g. float and double) are accepted.
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  assert(u.size() == 3);
  assert(v.size() == 3);
  using T = typename U::value_type;
  // Explicit conversion avoids an ill-formed narrowing conversion in the
  // braced return when V's value type is wider than U's
  return {static_cast<T>(u[1] * v[2] - u[2] * v[1]),
          static_cast<T>(u[2] * v[0] - u[0] * v[2]),
          static_cast<T>(u[0] * v[1] - u[1] * v[0])};
}
/// @brief Compute `a*d - b*c` accurately using Kahan's algorithm with
/// fused multiply-adds.
///
/// The naive expression suffers catastrophic cancellation when `a*d`
/// and `b*c` are close. Here the rounding error of `b*c` is recovered
/// exactly with an FMA and added back to an FMA-evaluated `a*d - w`.
/// The error is bounded by about 1.5 ulp (Jeannerod, Louvet and Muller,
/// Math. Comp. 2013).
/// @return `a*d - b*c`.
template <typename T>
T difference_of_products(T a, T b, T c, T d) noexcept
{
  T w = b * c;                 // rounded product b*c
  T err = std::fma(-b, c, w);  // exact rounding error of w: w - b*c
  T diff = std::fma(a, d, -w); // a*d - w with a single rounding
  return diff + err;
}
/// @brief Compute the determinant of a small square matrix (1x1, 2x2
/// or 3x3) stored contiguously.
///
/// Entries are read in row-major order. The result is identical for
/// column-major storage because det(A^T) = det(A).
/// @param[in] A Pointer to the `shape[0] * shape[1]` matrix entries.
/// @param[in] shape The matrix shape; it must be square.
/// @return The determinant of `A`.
/// @throws std::runtime_error if the matrix is larger than 3x3.
template <typename T>
auto det(const T* A, std::array<std::size_t, 2> shape)
{
  assert(shape[0] == shape[1]);
  switch (shape[0])
  {
  case 1:
    return *A;
  case 2:
    // A(0, 0) * A(1, 1) - A(0, 1) * A(1, 0)
    return difference_of_products(A[0], A[1], A[2], A[3]);
  case 3:
  {
    // Expansion along the first row (cofactors w0, -w1, w2). Every 2x2
    // minor uses Kahan's method for accuracy
    T w0 = difference_of_products(A[4], A[5], A[7], A[8]);
    T w1 = difference_of_products(A[3], A[5], A[6], A[8]);
    T w2 = difference_of_products(A[3], A[4], A[6], A[7]);
    T w3 = difference_of_products(A[0], A[1], w1, w0);
    T w4 = std::fma(A[2], w2, w3);
    return w4;
  }
  default:
    // Report the matrix shape (not its entries) in the error
    throw std::runtime_error("math::det is not implemented for "
                             + std::to_string(shape[0]) + "x"
                             + std::to_string(shape[1]) + " matrices.");
  }
}
/// @brief Compute the determinant of a small square matrix (1x1, 2x2
/// or 3x3).
/// @param[in] A The rank-2 matrix to compute the determinant of. It
/// must be square.
/// @return The determinant of `A`, as `Matrix::value_type`.
/// @throws std::runtime_error if `A` is larger than 3x3.
template <typename Matrix>
auto det(Matrix A)
{
  static_assert(Matrix::rank() == 2, "Must be rank 2");
  assert(A.extent(0) == A.extent(1));

  using value_type = typename Matrix::value_type;
  const std::size_t nrows = A.extent(0);
  switch (nrows)
  {
  case 1:
    // Cast so all return paths deduce the same type
    return static_cast<value_type>(A(0, 0));
  case 2:
    return static_cast<value_type>(
        difference_of_products(A(0, 0), A(0, 1), A(1, 0), A(1, 1)));
  case 3:
  {
    // Expansion along the first row (cofactors w0, -w1, w2). Every 2x2
    // minor uses Kahan's method for accuracy
    value_type w0 = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
    value_type w1 = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
    value_type w2 = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
    value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
    value_type w4 = std::fma(A(0, 2), w2, w3);
    return w4;
  }
  default:
    throw std::runtime_error("math::det is not implemented for "
                             + std::to_string(A.extent(0)) + "x"
                             + std::to_string(A.extent(1)) + " matrices.");
  }
}
/// @brief Compute the inverse of a small square matrix `A` (1x1, 2x2
/// or 3x3) and write it into the pre-allocated matrix `B`.
/// @param[in] A The rank-2 matrix to invert. It must be square.
/// @param[out] B The inverse of `A`. It must have the same shape as
/// `A`; every entry is overwritten.
/// @throws std::runtime_error if `A` is larger than 3x3.
/// @warning Invertibility of `A` is not checked. A singular `A` divides
/// by zero, which for IEEE floating-point types yields non-finite
/// entries in `B`.
template <typename U, typename V>
void inv(U A, V B)
{
  static_assert(U::rank() == 2, "Must be rank 2");
  static_assert(V::rank() == 2, "Must be rank 2");
  assert(A.extent(0) == A.extent(1));
  assert(B.extent(0) == A.extent(0) && B.extent(1) == A.extent(1));

  using value_type = typename U::value_type;
  const std::size_t nrows = A.extent(0);
  switch (nrows)
  {
  case 1:
    B(0, 0) = 1 / A(0, 0);
    break;
  case 2:
  {
    value_type idet = 1. / det(A);
    B(0, 0) = idet * A(1, 1);
    B(0, 1) = -idet * A(0, 1);
    B(1, 0) = -idet * A(1, 0);
    B(1, 1) = idet * A(0, 0);
    break;
  }
  case 3:
  {
    // First-row cofactors are w0, -w1 and w2 (w1 is the minor)
    value_type w0 = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
    value_type w1 = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
    value_type w2 = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
    value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
    // Named detA so it does not shadow the det() function
    value_type detA = std::fma(A(0, 2), w2, w3);
    value_type idet = 1 / detA;

    // B = adj(A) / det(A). The first column of adj(A) is the first-row
    // cofactors computed above
    B(0, 0) = w0 * idet;
    B(1, 0) = -w1 * idet;
    B(2, 0) = w2 * idet;
    B(0, 1) = difference_of_products(A(0, 2), A(0, 1), A(2, 2), A(2, 1)) * idet;
    B(0, 2) = difference_of_products(A(0, 1), A(0, 2), A(1, 1), A(1, 2)) * idet;
    B(1, 1) = difference_of_products(A(0, 0), A(0, 2), A(2, 0), A(2, 2)) * idet;
    B(1, 2) = difference_of_products(A(1, 0), A(0, 0), A(1, 2), A(0, 2)) * idet;
    B(2, 1) = difference_of_products(A(2, 0), A(0, 0), A(2, 1), A(0, 1)) * idet;
    B(2, 2) = difference_of_products(A(0, 0), A(1, 0), A(0, 1), A(1, 1)) * idet;
    break;
  }
  default:
    throw std::runtime_error("math::inv is not implemented for "
                             + std::to_string(A.extent(0)) + "x"
                             + std::to_string(A.extent(1)) + " matrices.");
  }
}
/// @brief Accumulate a matrix product into `C`: C += A * B, or
/// C += A^T * B^T when `transpose` is true.
/// @param[in] A Input matrix (rank-2).
/// @param[in] B Input matrix (rank-2).
/// @param[in,out] C Accumulator. The product is added to its existing
/// entries, so the caller must zero it to obtain a plain product.
/// @param[in] transpose If true compute C += A^T * B^T, otherwise
/// C += A * B.
template <typename U, typename V, typename P>
void dot(U A, V B, P C, bool transpose = false)
{
  static_assert(U::rank() == 2, "Must be rank 2");
  static_assert(V::rank() == 2, "Must be rank 2");
  static_assert(P::rank() == 2, "Must be rank 2");

  if (transpose)
  {
    // (A^T B^T)(i, j) = sum_k A(k, i) * B(j, k)
    assert(A.extent(0) == B.extent(1));
    assert(C.extent(0) == A.extent(1) && C.extent(1) == B.extent(0));
    for (std::size_t i = 0; i < A.extent(1); i++)
      for (std::size_t j = 0; j < B.extent(0); j++)
        for (std::size_t k = 0; k < A.extent(0); k++)
          C(i, j) += A(k, i) * B(j, k);
  }
  else
  {
    // (A B)(i, j) = sum_k A(i, k) * B(k, j)
    assert(A.extent(1) == B.extent(0));
    assert(C.extent(0) == A.extent(0) && C.extent(1) == B.extent(1));
    for (std::size_t i = 0; i < A.extent(0); i++)
      for (std::size_t j = 0; j < B.extent(1); j++)
        for (std::size_t k = 0; k < A.extent(1); k++)
          C(i, j) += A(i, k) * B(k, j);
  }
}
210template <
typename U,
typename V>
213 static_assert(U::rank() == 2,
"Must be rank 2");
214 static_assert(V::rank() == 2,
"Must be rank 2");
216 assert(A.extent(0) > A.extent(1));
217 assert(P.extent(1) == A.extent(0));
218 assert(P.extent(0) == A.extent(1));
219 using T =
typename U::value_type;
220 if (A.extent(1) == 2)
222 std::array<T, 6> ATb;
223 std::array<T, 4> ATAb, Invb;
224 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
225 T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 3>>
226 AT(ATb.data(), 2, 3);
227 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
228 T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 2>>
229 ATA(ATAb.data(), 2, 2);
230 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
231 T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 2>>
232 Inv(Invb.data(), 2, 2);
234 for (std::size_t i = 0; i < AT.extent(0); ++i)
235 for (std::size_t j = 0; j < AT.extent(1); ++j)
238 std::fill(ATAb.begin(), ATAb.end(), 0.0);
239 for (std::size_t i = 0; i < P.extent(0); ++i)
240 for (std::size_t j = 0; j < P.extent(1); ++j)
248 else if (A.extent(1) == 1)
251 for (std::size_t i = 0; i < A.extent(0); ++i)
252 for (std::size_t j = 0; j < A.extent(1); ++j)
253 res += A(i, j) * A(i, j);
255 for (std::size_t i = 0; i < A.extent(0); ++i)
256 for (std::size_t j = 0; j < A.extent(1); ++j)
257 P(j, i) = (1 / res) * A(i, j);
261 throw std::runtime_error(
"math::pinv is not implemented for "
262 + std::to_string(A.extent(0)) +
"x"
263 + std::to_string(A.extent(1)) +
" matrices.");