#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <xtensor/xfixed.hpp>
#include <xtensor/xtensor.hpp>
15 namespace dolfinx::math
22 template <
typename U,
typename V>
23 xt::xtensor_fixed<typename U::value_type, xt::xshape<3>> cross(
const U& u,
26 assert(u.size() == 3);
27 assert(v.size() == 3);
28 return {u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
29 u[0] * v[1] - u[1] * v[0]};
/// Compute a*d - b*c using Kahan's algorithm with fused multiply-adds.
///
/// The rounding error made when forming b*c is recovered exactly with an
/// FMA and added back, and a*d is formed exactly inside a second FMA.
/// This avoids the catastrophic cancellation of the naive expression when
/// a*d is close to b*c.
/// @param[in] a First factor of the positive product
/// @param[in] b First factor of the negative product
/// @param[in] c Second factor of the negative product
/// @param[in] d Second factor of the positive product
/// @return a*d - b*c
template <typename T>
T difference_of_products(T a, T b, T c, T d) noexcept
{
  const T w = b * c;
  // err == w - b*c exactly: the rounding error committed when forming w
  const T err = std::fma(-b, c, w);
  // diff == round(a*d - w), with a*d evaluated without intermediate rounding
  const T diff = std::fma(a, d, -w);
  return diff + err;
}
/// Compute the determinant of a small square matrix (1x1, 2x2 or 3x3).
///
/// The 2x2 minors are evaluated with difference_of_products (Kahan's
/// FMA-based a*d - b*c), and the 3x3 case combines them by cofactor
/// expansion along the first row.
/// @param[in] A The matrix to compute the determinant of. It must be
/// two-dimensional and square.
/// @return The determinant of @p A
/// @throws std::runtime_error if @p A is not 1x1, 2x2 or 3x3
template <typename Matrix>
auto det(const Matrix& A)
{
  using value_type = typename Matrix::value_type;
  // Check the rank first: shape(1) is only meaningful for a 2D array
  assert(A.dimension() == 2);
  assert(A.shape(0) == A.shape(1));

  const std::size_t nrows = A.shape(0);
  switch (nrows)
  {
  case 1:
    return A(0, 0);
  case 2:
    return difference_of_products(A(0, 0), A(0, 1), A(1, 0), A(1, 1));
  case 3:
  {
    // Minors of the first row:
    //   w0 = A11*A22 - A12*A21, w1 = A10*A22 - A12*A20,
    //   w2 = A10*A21 - A11*A20
    const value_type w0
        = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
    const value_type w1
        = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
    const value_type w2
        = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
    // det = A00*w0 - A01*w1 + A02*w2
    const value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
    const value_type w4 = std::fma(A(0, 2), w2, w3);
    return w4;
  }
  default:
    throw std::runtime_error("math::det is not implemented for "
                             + std::to_string(A.shape(0)) + "x"
                             + std::to_string(A.shape(1)) + " matrices.");
  }
}
/// Compute the inverse of a small square matrix (1x1, 2x2 or 3x3).
/// @param[in] A The matrix to invert. It must be two-dimensional and
/// square.
/// @param[out] B The inverse of @p A. It must be pre-allocated with the
/// same shape as @p A.
/// @warning Invertibility of @p A is only checked (by assert) in the 3x3
/// case. For floating-point types a singular 1x1 or 2x2 matrix yields
/// non-finite entries.
/// @throws std::runtime_error if @p A is not 1x1, 2x2 or 3x3
template <typename U, typename V>
void inv(const U& A, V&& B)
{
  using value_type = typename U::value_type;
  assert(A.dimension() == 2);
  assert(A.shape(0) == A.shape(1));

  const std::size_t nrows = A.shape(0);
  switch (nrows)
  {
  case 1:
    B(0, 0) = value_type(1) / A(0, 0);
    break;
  case 2:
  {
    // inv(A) = 1/det(A) * [[A11, -A01], [-A10, A00]]
    const value_type idet = value_type(1) / det(A);
    B(0, 0) = idet * A(1, 1);
    B(0, 1) = -idet * A(0, 1);
    B(1, 0) = -idet * A(1, 0);
    B(1, 1) = idet * A(0, 0);
    break;
  }
  case 3:
  {
    // First-row minors, reused both for det(A) and for the first column
    // of the adjugate
    const value_type w0
        = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
    const value_type w1
        = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
    const value_type w2
        = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
    const value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
    // Named detA so it does not shadow the function det
    const value_type detA = std::fma(A(0, 2), w2, w3);
    assert(detA != 0.);
    const value_type idet = value_type(1) / detA;

    // inv(A) = adj(A) / det(A), where adj(A)(i, j) is the (j, i) cofactor
    B(0, 0) = w0 * idet;
    B(1, 0) = -w1 * idet;
    B(2, 0) = w2 * idet;
    B(0, 1) = difference_of_products(A(0, 2), A(0, 1), A(2, 2), A(2, 1)) * idet;
    B(0, 2) = difference_of_products(A(0, 1), A(0, 2), A(1, 1), A(1, 2)) * idet;
    B(1, 1) = difference_of_products(A(0, 0), A(0, 2), A(2, 0), A(2, 2)) * idet;
    B(1, 2) = difference_of_products(A(1, 0), A(0, 0), A(1, 2), A(0, 2)) * idet;
    B(2, 1) = difference_of_products(A(2, 0), A(0, 0), A(2, 1), A(0, 1)) * idet;
    B(2, 2) = difference_of_products(A(0, 0), A(1, 0), A(0, 1), A(1, 1)) * idet;
    break;
  }
  default:
    throw std::runtime_error("math::inv is not implemented for "
                             + std::to_string(A.shape(0)) + "x"
                             + std::to_string(A.shape(1)) + " matrices.");
  }
}
/// Accumulate a matrix product into C.
///
/// Computes C += A * B when @p transpose is false, and C += A^T * B^T
/// when @p transpose is true. C is not cleared first, so it must be
/// zero-initialised if a plain product is wanted.
/// @param[in] A Input matrix
/// @param[in] B Input matrix
/// @param[in,out] C Pre-sized output matrix that the product is added to
/// @param[in] transpose If true, use the transposes of @p A and @p B
template <typename U, typename V, typename P>
void dot(const U& A, const V& B, P&& C, bool transpose = false)
{
  if (transpose)
  {
    // C(i, j) += sum_k A^T(i, k) * B^T(k, j) = sum_k A(k, i) * B(j, k)
    assert(A.shape(0) == B.shape(1));
    const std::size_t m = A.shape(1);
    const std::size_t n = B.shape(0);
    const std::size_t inner = A.shape(0);
    for (std::size_t i = 0; i < m; i++)
      for (std::size_t j = 0; j < n; j++)
        for (std::size_t k = 0; k < inner; k++)
          C(i, j) += A(k, i) * B(j, k);
  }
  else
  {
    // C(i, j) += sum_k A(i, k) * B(k, j)
    assert(A.shape(1) == B.shape(0));
    const std::size_t m = A.shape(0);
    const std::size_t n = B.shape(1);
    const std::size_t inner = A.shape(1);
    for (std::size_t i = 0; i < m; i++)
      for (std::size_t j = 0; j < n; j++)
        for (std::size_t k = 0; k < inner; k++)
          C(i, j) += A(i, k) * B(k, j);
  }
}
165 template <
typename U,
typename V>
166 void pinv(
const U& A, V&& P)
168 auto shape = A.shape();
169 assert(shape[0] > shape[1]);
170 assert(P.shape(1) == shape[0]);
171 assert(P.shape(0) == shape[1]);
172 using T =
typename U::value_type;
175 xt::xtensor_fixed<T, xt::xshape<2, 3>> AT;
176 xt::xtensor_fixed<T, xt::xshape<2, 2>> ATA;
177 xt::xtensor_fixed<T, xt::xshape<2, 2>> Inv;
178 AT = xt::transpose(A);
179 std::fill(ATA.begin(), ATA.end(), 0.0);
180 std::fill(P.begin(), P.end(), 0.0);
183 dolfinx::math::dot(AT, A, ATA);
184 dolfinx::math::inv(ATA, Inv);
185 dolfinx::math::dot(Inv, AT, P);
187 else if (shape[1] == 1)
189 auto res = std::transform_reduce(A.begin(), A.end(), 0., std::plus<T>(),
190 [](
const auto v) { return v * v; });
191 P = (1 / res) * xt::transpose(A);
195 throw std::runtime_error(
"math::pinv is not implemented for "
196 + std::to_string(A.shape(0)) +
"x"
197 + std::to_string(A.shape(1)) +
" matrices.");