11#include <basix/mdspan.hpp> 
   16namespace dolfinx::math
 
   23template <
typename U, 
typename V>
 
   24std::array<typename U::value_type, 3> cross(
const U& u, 
const V& v)
 
   26  assert(u.size() == 3);
 
   27  assert(v.size() == 3);
 
   28  return {u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
 
   29          u[0] * v[1] - u[1] * v[0]};
 
   35T difference_of_products(T a, T b, T c, T d) 
noexcept 
   38  T err = std::fma(-b, c, w);
 
   39  T diff = std::fma(a, d, -w);
 
   50auto det(
const T* A, std::array<std::size_t, 2> shape)
 
   52  assert(shape[0] == shape[1]);
 
   61    return difference_of_products(A[0], A[1], A[2], A[3]);
 
   66    T w0 = difference_of_products(A[3 + 1], A[3 + 2], A[3 * 2 + 1],
 
   68    T w1 = difference_of_products(A[3], A[3 + 2], A[3 * 2], A[3 * 2 + 2]);
 
   69    T w2 = difference_of_products(A[3], A[3 + 1], A[3 * 2], A[3 * 2 + 1]);
 
   70    T w3 = difference_of_products(A[0], A[1], w1, w0);
 
   71    T w4 = std::fma(A[2], w2, w3);
 
   75    throw std::runtime_error(
"math::det is not implemented for " 
   76                             + std::to_string(A[0]) + 
"x" + std::to_string(A[1])
 
   85template <
typename Matrix>
 
   88  static_assert(Matrix::rank() == 2, 
"Must be rank 2");
 
   89  assert(A.extent(0) == A.extent(1));
 
   91  using value_type = 
typename Matrix::value_type;
 
   92  const int nrows = A.extent(0);
 
   98    return difference_of_products(A(0, 0), A(0, 1), A(1, 0), A(1, 1));
 
  103    value_type w0 = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
 
  104    value_type w1 = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
 
  105    value_type w2 = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
 
  106    value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
 
  107    value_type w4 = std::fma(A(0, 2), w2, w3);
 
  111    throw std::runtime_error(
"math::det is not implemented for " 
  112                             + std::to_string(A.extent(0)) + 
"x" 
  113                             + std::to_string(A.extent(1)) + 
" matrices.");
 
  123template <
typename U, 
typename V>
 
  126  static_assert(U::rank() == 2, 
"Must be rank 2");
 
  127  static_assert(V::rank() == 2, 
"Must be rank 2");
 
  129  using value_type = 
typename U::value_type;
 
  130  const std::size_t nrows = A.extent(0);
 
  134    B(0, 0) = 1 / A(0, 0);
 
  138    value_type idet = 1. / det(A);
 
  139    B(0, 0) = idet * A(1, 1);
 
  140    B(0, 1) = -idet * A(0, 1);
 
  141    B(1, 0) = -idet * A(1, 0);
 
  142    B(1, 1) = idet * A(0, 0);
 
  147    value_type w0 = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
 
  148    value_type w1 = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
 
  149    value_type w2 = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
 
  150    value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
 
  151    value_type det = std::fma(A(0, 2), w2, w3);
 
  153    value_type idet = 1 / det;
 
  156    B(1, 0) = -w1 * idet;
 
  158    B(0, 1) = difference_of_products(A(0, 2), A(0, 1), A(2, 2), A(2, 1)) * idet;
 
  159    B(0, 2) = difference_of_products(A(0, 1), A(0, 2), A(1, 1), A(1, 2)) * idet;
 
  160    B(1, 1) = difference_of_products(A(0, 0), A(0, 2), A(2, 0), A(2, 2)) * idet;
 
  161    B(1, 2) = difference_of_products(A(1, 0), A(0, 0), A(1, 2), A(0, 2)) * idet;
 
  162    B(2, 1) = difference_of_products(A(2, 0), A(0, 0), A(2, 1), A(0, 1)) * idet;
 
  163    B(2, 2) = difference_of_products(A(0, 0), A(1, 0), A(0, 1), A(1, 1)) * idet;
 
  167    throw std::runtime_error(
"math::inv is not implemented for " 
  168                             + std::to_string(A.extent(0)) + 
"x" 
  169                             + std::to_string(A.extent(1)) + 
" matrices.");
 
  179template <
typename U, 
typename V, 
typename P>
 
  180void dot(U A, V B, P C, 
bool transpose = 
false)
 
  182  static_assert(U::rank() == 2, 
"Must be rank 2");
 
  183  static_assert(V::rank() == 2, 
"Must be rank 2");
 
  184  static_assert(P::rank() == 2, 
"Must be rank 2");
 
  188    assert(A.extent(0) == B.extent(1));
 
  189    for (std::size_t i = 0; i < A.extent(1); i++)
 
  190      for (std::size_t j = 0; j < B.extent(0); j++)
 
  191        for (std::size_t k = 0; k < A.extent(0); k++)
 
  192          C(i, j) += A(k, i) * B(j, k);
 
  196    assert(A.extent(1) == B.extent(0));
 
  197    for (std::size_t i = 0; i < A.extent(0); i++)
 
  198      for (std::size_t j = 0; j < B.extent(1); j++)
 
  199        for (std::size_t k = 0; k < A.extent(1); k++)
 
  200          C(i, j) += A(i, k) * B(k, j);
 
  210template <
typename U, 
typename V>
 
  213  static_assert(U::rank() == 2, 
"Must be rank 2");
 
  214  static_assert(V::rank() == 2, 
"Must be rank 2");
 
  216  assert(A.extent(0) > A.extent(1));
 
  217  assert(P.extent(1) == A.extent(0));
 
  218  assert(P.extent(0) == A.extent(1));
 
  219  using T = 
typename U::value_type;
 
  220  if (A.extent(1) == 2)
 
  222    std::array<T, 6> ATb;
 
  223    std::array<T, 4> ATAb, Invb;
 
  224    MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
 
  225        T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 3>>
 
  226        AT(ATb.data(), 2, 3);
 
  227    MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
 
  228        T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 2>>
 
  229        ATA(ATAb.data(), 2, 2);
 
  230    MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
 
  231        T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 2>>
 
  232        Inv(Invb.data(), 2, 2);
 
  234    for (std::size_t i = 0; i < AT.extent(0); ++i)
 
  235      for (std::size_t j = 0; j < AT.extent(1); ++j)
 
  238    std::fill(ATAb.begin(), ATAb.end(), 0.0);
 
  239    for (std::size_t i = 0; i < P.extent(0); ++i)
 
  240      for (std::size_t j = 0; j < P.extent(1); ++j)
 
  248  else if (A.extent(1) == 1)
 
  251    for (std::size_t i = 0; i < A.extent(0); ++i)
 
  252      for (std::size_t j = 0; j < A.extent(1); ++j)
 
  253        res += A(i, j) * A(i, j);
 
  255    for (std::size_t i = 0; i < A.extent(0); ++i)
 
  256      for (std::size_t j = 0; j < A.extent(1); ++j)
 
  257        P(j, i) = (1 / res) * A(i, j);
 
  261    throw std::runtime_error(
"math::pinv is not implemented for " 
  262                             + std::to_string(A.extent(0)) + 
"x" 
  263                             + std::to_string(A.extent(1)) + 
" matrices.");