12#include <basix/mdspan.hpp> 
   17namespace dolfinx::math
 
   /// @brief Compute the cross product `u x v` of two 3-vectors.
   ///
   /// @param[in] u First vector. Its size must be 3 (checked with `assert`).
   /// @param[in] v Second vector. Its size must be 3 (checked with `assert`).
   /// @return The three components of `u x v`, stored with the value type
   /// of `U`.
   template <typename U, typename V>
   std::array<typename U::value_type, 3> cross(const U& u, const V& v)
   {
     assert(u.size() == 3);
     assert(v.size() == 3);
     return {u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
             u[0] * v[1] - u[1] * v[0]};
   }
 
   /// @brief Compute `a * d - b * c` with fused multiply-adds (Kahan's
   /// method).
   ///
   /// The product `b * c` is rounded once into `w`; one FMA recovers the
   /// rounding error of that product and a second FMA forms `a * d - w`
   /// with a single rounding. Adding the two gives a result that is far
   /// less sensitive to cancellation than the naive expression.
   ///
   /// @param[in] a,b,c,d Operands of `a * d - b * c`.
   /// @return The difference of products.
   template <typename T>
   T difference_of_products(T a, T b, T c, T d) noexcept
   {
     // NOTE(review): the definition of `w` and the final return were restored
     // from the standard form of Kahan's algorithm - confirm against upstream.
     const T w = b * c;
     const T err = std::fma(-b, c, w);  // w - b * c: rounding error of w
     const T diff = std::fma(a, d, -w); // a * d - w, rounded once
     return diff + err;
   }
 
   /// @brief Determinant of a small square matrix stored in a flat array.
   ///
   /// The 2x2 case reads `A[0..3]` and the 3x3 case reads entry (i, j) at
   /// `A[3 * i + j]`. Both are evaluated with `difference_of_products` to
   /// limit cancellation error.
   ///
   /// @param[in] A Pointer to the matrix entries.
   /// @param[in] shape Number of rows and columns; the two must be equal
   /// (checked with `assert`).
   /// @return The determinant of `A`.
   /// @throws std::runtime_error If the matrix is not 1x1, 2x2 or 3x3.
   template <typename T>
   auto det(const T* A, std::array<std::size_t, 2> shape)
   {
     assert(shape[0] == shape[1]);

     switch (shape[0])
     {
     case 1:
       // NOTE(review): 1x1 case restored; it was not visible in the excerpt.
       return *A;
     case 2:
       return difference_of_products(A[0], A[1], A[2], A[3]);
     case 3:
     {
       // Cofactor expansion along the first row, each 2x2 minor computed
       // with Kahan's method.
       const T w0 = difference_of_products(A[3 + 1], A[3 + 2], A[3 * 2 + 1],
                                           A[3 * 2 + 2]);
       const T w1 = difference_of_products(A[3], A[3 + 2], A[3 * 2], A[3 * 2 + 2]);
       const T w2 = difference_of_products(A[3], A[3 + 1], A[3 * 2], A[3 * 2 + 1]);
       const T w3 = difference_of_products(A[0], A[1], w1, w0);
       const T w4 = std::fma(A[2], w2, w3);
       return w4;
     }
     default:
       // Report the matrix dimensions (previously the first two matrix
       // entries were printed instead of the shape).
       throw std::runtime_error("math::det is not implemented for "
                                + std::to_string(shape[0]) + "x"
                                + std::to_string(shape[1]) + " matrices.");
     }
   }
 
   /// @brief Determinant of a small square rank-2 matrix view.
   ///
   /// @param[in] A Matrix accessed as `A(i, j)`; `Matrix::rank()` must be 2
   /// and `A.extent(0)` must equal `A.extent(1)` (checked with `assert`).
   /// @return The determinant of `A`.
   /// @throws std::runtime_error If the matrix is not 1x1, 2x2 or 3x3.
   template <typename Matrix>
   auto det(Matrix A)
   {
     static_assert(Matrix::rank() == 2, "Must be rank 2");
     assert(A.extent(0) == A.extent(1));

     using value_type = typename Matrix::value_type;
     switch (A.extent(0))
     {
     case 1:
       // NOTE(review): 1x1 case restored; it was not visible in the excerpt.
       return value_type(A(0, 0));
     case 2:
       return difference_of_products(A(0, 0), A(0, 1), A(1, 0), A(1, 1));
     case 3:
     {
       // Cofactor expansion along the first row, each 2x2 minor computed
       // with Kahan's method.
       const value_type w0
           = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
       const value_type w1
           = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
       const value_type w2
           = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
       const value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
       const value_type w4 = std::fma(A(0, 2), w2, w3);
       return w4;
     }
     default:
       throw std::runtime_error("math::det is not implemented for "
                                + std::to_string(A.extent(0)) + "x"
                                + std::to_string(A.extent(1)) + " matrices.");
     }
   }
 
  /// @brief Invert a small square matrix (1x1, 2x2 or 3x3).
  ///
  /// @param[in] A Matrix to invert, accessed as `A(i, j)`; rank 2.
  /// @param[out] B Receives the inverse of `A`; rank 2. Every entry of the
  /// leading `n x n` part is overwritten.
  /// @throws std::runtime_error If `A.extent(0)` is not 1, 2 or 3.
  /// @note A singular `A` is not reported; the 3x3 path only asserts that
  /// the determinant is non-zero.
  template <typename U, typename V>
  void inv(U A, V B)
  {
    static_assert(U::rank() == 2, "Must be rank 2");
    static_assert(V::rank() == 2, "Must be rank 2");

    using value_type = typename U::value_type;
    const std::size_t nrows = A.extent(0);
    switch (nrows)
    {
    case 1:
      B(0, 0) = 1 / A(0, 0);
      break;
    case 2:
    {
      // Adjugate of a 2x2 matrix scaled by the reciprocal determinant.
      const value_type idet = 1 / det(A);
      B(0, 0) = idet * A(1, 1);
      B(0, 1) = -idet * A(0, 1);
      B(1, 0) = -idet * A(1, 0);
      B(1, 1) = idet * A(0, 0);
      break;
    }
    case 3:
    {
      // w0, w1, w2 are the minors of the first row; they give both the
      // determinant and the first column of the adjugate.
      const value_type w0
          = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
      const value_type w1
          = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
      const value_type w2
          = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
      const value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
      const value_type determinant = std::fma(A(0, 2), w2, w3);
      // NOTE(review): assertion restored; it was not visible in the excerpt.
      assert(determinant != 0.);
      const value_type idet = 1 / determinant;

      // NOTE(review): B(0, 0) and B(2, 0) restored from the adjugate
      // formula; they were not visible in the excerpt.
      B(0, 0) = w0 * idet;
      B(1, 0) = -w1 * idet;
      B(2, 0) = w2 * idet;
      B(0, 1) = difference_of_products(A(0, 2), A(0, 1), A(2, 2), A(2, 1)) * idet;
      B(0, 2) = difference_of_products(A(0, 1), A(0, 2), A(1, 1), A(1, 2)) * idet;
      B(1, 1) = difference_of_products(A(0, 0), A(0, 2), A(2, 0), A(2, 2)) * idet;
      B(1, 2) = difference_of_products(A(1, 0), A(0, 0), A(1, 2), A(0, 2)) * idet;
      B(2, 1) = difference_of_products(A(2, 0), A(0, 0), A(2, 1), A(0, 1)) * idet;
      B(2, 2) = difference_of_products(A(0, 0), A(1, 0), A(0, 1), A(1, 1)) * idet;
      break;
    }
    default:
      throw std::runtime_error("math::inv is not implemented for "
                               + std::to_string(A.extent(0)) + "x"
                               + std::to_string(A.extent(1)) + " matrices.");
    }
  }
 
  /// @brief Accumulate a matrix product into `C`.
  ///
  /// Computes `C += A * B`, or `C += A^T * B^T` when `transpose` is true.
  /// The result is *added* to the existing contents of `C`, so the caller
  /// must zero `C` first if a plain product is wanted.
  ///
  /// @param[in] A Left operand, rank 2, accessed as `A(i, j)`.
  /// @param[in] B Right operand, rank 2, accessed as `B(i, j)`.
  /// @param[in,out] C Accumulator, rank 2. Its shape is not checked here.
  /// @param[in] transpose If true, both `A` and `B` are read transposed.
  template <typename U, typename V, typename P>
  void dot(U A, V B, P C, bool transpose = false)
  {
    static_assert(U::rank() == 2, "Must be rank 2");
    static_assert(V::rank() == 2, "Must be rank 2");
    static_assert(P::rank() == 2, "Must be rank 2");

    if (transpose)
    {
      // C(i, j) += sum_k A(k, i) * B(j, k)
      assert(A.extent(0) == B.extent(1));
      for (std::size_t i = 0; i < A.extent(1); i++)
        for (std::size_t j = 0; j < B.extent(0); j++)
          for (std::size_t k = 0; k < A.extent(0); k++)
            C(i, j) += A(k, i) * B(j, k);
    }
    else
    {
      // C(i, j) += sum_k A(i, k) * B(k, j)
      assert(A.extent(1) == B.extent(0));
      for (std::size_t i = 0; i < A.extent(0); i++)
        for (std::size_t j = 0; j < B.extent(1); j++)
          for (std::size_t k = 0; k < A.extent(1); k++)
            C(i, j) += A(i, k) * B(k, j);
    }
  }
 
  211template <
typename U, 
typename V>
 
  214  static_assert(U::rank() == 2, 
"Must be rank 2");
 
  215  static_assert(V::rank() == 2, 
"Must be rank 2");
 
  217  assert(A.extent(0) > A.extent(1));
 
  218  assert(P.extent(1) == A.extent(0));
 
  219  assert(P.extent(0) == A.extent(1));
 
  220  using T = 
typename U::value_type;
 
  221  if (A.extent(1) == 2)
 
  223    std::array<T, 6> ATb;
 
  224    std::array<T, 4> ATAb, Invb;
 
  225    MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
 
  226        T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 3>>
 
  227        AT(ATb.data(), 2, 3);
 
  228    MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
 
  229        T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 2>>
 
  230        ATA(ATAb.data(), 2, 2);
 
  231    MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
 
  232        T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 2>>
 
  233        Inv(Invb.data(), 2, 2);
 
  235    for (std::size_t i = 0; i < AT.extent(0); ++i)
 
  236      for (std::size_t j = 0; j < AT.extent(1); ++j)
 
  239    std::ranges::fill(ATAb, 0.0);
 
  240    for (std::size_t i = 0; i < P.extent(0); ++i)
 
  241      for (std::size_t j = 0; j < P.extent(1); ++j)
 
  249  else if (A.extent(1) == 1)
 
  252    for (std::size_t i = 0; i < A.extent(0); ++i)
 
  253      for (std::size_t j = 0; j < A.extent(1); ++j)
 
  254        res += A(i, j) * A(i, j);
 
  256    for (std::size_t i = 0; i < A.extent(0); ++i)
 
  257      for (std::size_t j = 0; j < A.extent(1); ++j)
 
  258        P(j, i) = (1 / res) * A(i, j);
 
  262    throw std::runtime_error(
"math::pinv is not implemented for " 
  263                             + std::to_string(A.extent(0)) + 
"x" 
  264                             + std::to_string(A.extent(1)) + 
" matrices.");