DOLFINx 0.10.0.0
DOLFINx C++ interface
Loading...
Searching...
No Matches
math.h
1// Copyright (C) 2021 Igor Baratta
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
9#include "types.h"
10#include <algorithm>
11#include <array>
12#include <basix/mdspan.hpp>
13#include <cmath>
14#include <string>
15#include <type_traits>
16
17namespace dolfinx::math
18{
19
/// @brief Cross product w = u x v of two 3-component vectors.
///
/// @param[in] u First vector; must have size 3 (checked by assert).
/// @param[in] v Second vector; must have size 3 (checked by assert).
/// @return The three components of u x v, using the value type of `u`.
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  assert(u.size() == 3);
  assert(v.size() == 3);

  using result_t = std::array<typename U::value_type, 3>;
  const auto x = u[1] * v[2] - u[2] * v[1];
  const auto y = u[2] * v[0] - u[0] * v[2];
  const auto z = u[0] * v[1] - u[1] * v[0];
  return result_t{x, y, z};
}
32
/// @brief Evaluate a*d - b*c using Kahan's fused multiply-add scheme.
///
/// The rounding error of the product b*c is recovered exactly with an
/// FMA and added back. This limits cancellation compared with the naive
/// expression.
///
/// @return An accurate approximation of a*d - b*c.
template <typename T>
T difference_of_products(T a, T b, T c, T d) noexcept
{
  const T bc = b * c;
  // Error committed when rounding b*c: bc - (b*c) computed with one rounding
  const T bc_error = std::fma(-b, c, bc);
  const T ad_minus_bc = std::fma(a, d, -bc);
  return ad_minus_bc + bc_error;
}
43
/// @brief Determinant of a small square matrix stored contiguously in
/// row-major order, i.e. A(i, j) == A[i * shape[1] + j].
///
/// Only 1x1, 2x2 and 3x3 matrices are supported. The 2x2 and 3x3 cases
/// use Kahan's FMA-based difference of products to reduce cancellation
/// error.
///
/// @param[in] A Pointer to the shape[0] * shape[1] matrix entries.
/// @param[in] shape Number of rows and columns (must be equal; checked
/// by assert).
/// @return Determinant of the matrix.
/// @throws std::runtime_error for any size other than 1x1, 2x2 or 3x3.
template <typename T>
auto det(const T* A, std::array<std::size_t, 2> shape)
{
  assert(shape[0] == shape[1]);

  switch (shape[0])
  {
  case 1:
    return *A;
  case 2:
    // Entries are A(0, 0), A(0, 1), A(1, 0), A(1, 1)
    return difference_of_products(A[0], A[1], A[2], A[3]);
  case 3:
  {
    // Leibniz formula combined with Kahan's method for accurate
    // computation of 3 x 3 determinants (cofactor expansion along the
    // first row). Index k corresponds to A(k / 3, k % 3).
    T w0 = difference_of_products(A[4], A[5], A[7], A[8]); // minor of A(0, 0)
    T w1 = difference_of_products(A[3], A[5], A[6], A[8]); // minor of A(0, 1)
    T w2 = difference_of_products(A[3], A[4], A[6], A[7]); // minor of A(0, 2)
    T w3 = difference_of_products(A[0], A[1], w1, w0);
    T w4 = std::fma(A[2], w2, w3);
    return w4;
  }
  default:
    // Report the matrix dimensions (not its entries). Formatting entries
    // was misleading and failed to compile for non-arithmetic T.
    throw std::runtime_error("math::det is not implemented for "
                             + std::to_string(shape[0]) + "x"
                             + std::to_string(shape[1]) + " matrices.");
  }
}
81
/// @brief Determinant of a square matrix held in a rank-2 mdspan-like view.
///
/// Supported sizes are 1x1, 2x2 and 3x3. The 2x2 and 3x3 cases use
/// Kahan's FMA-based difference of products to limit cancellation
/// error.
///
/// @param[in] A Square matrix view providing rank(), extent(),
/// operator()(i, j) and value_type. Squareness is checked by assert.
/// @return The determinant of `A`.
/// @throws std::runtime_error for any other matrix size.
template <typename Matrix>
auto det(Matrix A)
{
  static_assert(Matrix::rank() == 2, "Must be rank 2");
  assert(A.extent(0) == A.extent(1));

  using value_type = typename Matrix::value_type;
  const int n = A.extent(0);

  if (n == 1)
    return A(0, 0);

  if (n == 2)
    return difference_of_products(A(0, 0), A(0, 1), A(1, 0), A(1, 1));

  if (n != 3)
  {
    throw std::runtime_error("math::det is not implemented for "
                             + std::to_string(A.extent(0)) + "x"
                             + std::to_string(A.extent(1)) + " matrices.");
  }

  // 3x3: expand along the first row,
  //   det = A(0,0) M00 - A(0,1) M01 + A(0,2) M02,
  // where each 2x2 minor Mij is evaluated with Kahan's method.
  const value_type m00
      = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
  const value_type m01
      = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
  const value_type m02
      = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
  const value_type partial = difference_of_products(A(0, 0), A(0, 1), m01, m00);
  const value_type result = std::fma(A(0, 2), m02, partial);
  return result;
}
117
/// @brief Inverse of a square 1x1, 2x2 or 3x3 matrix.
///
/// The 2x2 and 3x3 cases compute adj(A) / det(A). The 3x3 cofactors are
/// evaluated with Kahan's FMA-based difference of products.
///
/// @param[in] A Square matrix to invert (rank-2 mdspan-like view).
/// @param[out] B View receiving the inverse; entries (i, j) with
/// i, j < A.extent(0) are overwritten.
/// @throws std::runtime_error for any other matrix size.
/// @pre `A` is square and non-singular. Both are checked by assert in
/// debug builds only.
template <typename U, typename V>
void inv(U A, V B)
{
  static_assert(U::rank() == 2, "Must be rank 2");
  static_assert(V::rank() == 2, "Must be rank 2");
  // Only square matrices have an inverse
  assert(A.extent(0) == A.extent(1));

  using value_type = typename U::value_type;
  const std::size_t nrows = A.extent(0);
  switch (nrows)
  {
  case 1:
    assert(A(0, 0) != 0.);
    B(0, 0) = 1 / A(0, 0);
    break;
  case 2:
  {
    const value_type detA = det(A);
    assert(detA != 0.);
    const value_type idet = 1. / detA;
    B(0, 0) = idet * A(1, 1);
    B(0, 1) = -idet * A(0, 1);
    B(1, 0) = -idet * A(1, 0);
    B(1, 1) = idet * A(0, 0);
    break;
  }
  case 3:
  {
    // Cofactors of the first row; they also give det(A) and the first
    // column of the inverse
    value_type w0 = difference_of_products(A(1, 1), A(1, 2), A(2, 1), A(2, 2));
    value_type w1 = difference_of_products(A(1, 0), A(1, 2), A(2, 0), A(2, 2));
    value_type w2 = difference_of_products(A(1, 0), A(1, 1), A(2, 0), A(2, 1));
    value_type w3 = difference_of_products(A(0, 0), A(0, 1), w1, w0);
    // Named detA to avoid shadowing the function det
    value_type detA = std::fma(A(0, 2), w2, w3);
    assert(detA != 0.);
    value_type idet = 1 / detA;

    // B = adj(A) / det(A); each signed cofactor is a single
    // difference of products
    B(0, 0) = w0 * idet;
    B(1, 0) = -w1 * idet;
    B(2, 0) = w2 * idet;
    B(0, 1) = difference_of_products(A(0, 2), A(0, 1), A(2, 2), A(2, 1)) * idet;
    B(0, 2) = difference_of_products(A(0, 1), A(0, 2), A(1, 1), A(1, 2)) * idet;
    B(1, 1) = difference_of_products(A(0, 0), A(0, 2), A(2, 0), A(2, 2)) * idet;
    B(1, 2) = difference_of_products(A(1, 0), A(0, 0), A(1, 2), A(0, 2)) * idet;
    B(2, 1) = difference_of_products(A(2, 0), A(0, 0), A(2, 1), A(0, 1)) * idet;
    B(2, 2) = difference_of_products(A(0, 0), A(1, 0), A(0, 1), A(1, 1)) * idet;
    break;
  }
  default:
    throw std::runtime_error("math::inv is not implemented for "
                             + std::to_string(A.extent(0)) + "x"
                             + std::to_string(A.extent(1)) + " matrices.");
  }
}
173
/// @brief Accumulate a matrix-matrix product into `C` (C is not zeroed).
///
/// With `transpose == false` this performs C += A B and requires
/// A.extent(1) == B.extent(0). With `transpose == true` it performs
/// C += A^T B^T (both operands transposed) and requires
/// A.extent(0) == B.extent(1). The requirements are checked by assert.
///
/// @param[in] A Left operand (rank-2 view).
/// @param[in] B Right operand (rank-2 view).
/// @param[in,out] C Result view, incremented in place.
/// @param[in] transpose If true, use the transposes of both `A` and `B`.
template <typename U, typename V, typename P>
void dot(U A, V B, P C, bool transpose = false)
{
  static_assert(U::rank() == 2, "Must be rank 2");
  static_assert(V::rank() == 2, "Must be rank 2");
  static_assert(P::rank() == 2, "Must be rank 2");

  if (!transpose)
  {
    // C(row, col) += sum_k A(row, k) * B(k, col)
    assert(A.extent(1) == B.extent(0));
    const std::size_t rows = A.extent(0);
    const std::size_t cols = B.extent(1);
    const std::size_t inner = A.extent(1);
    for (std::size_t row = 0; row < rows; ++row)
      for (std::size_t col = 0; col < cols; ++col)
        for (std::size_t k = 0; k < inner; ++k)
          C(row, col) += A(row, k) * B(k, col);
    return;
  }

  // C(row, col) += sum_k A(k, row) * B(col, k)
  assert(A.extent(0) == B.extent(1));
  const std::size_t rows = A.extent(1);
  const std::size_t cols = B.extent(0);
  const std::size_t inner = A.extent(0);
  for (std::size_t row = 0; row < rows; ++row)
    for (std::size_t col = 0; col < cols; ++col)
      for (std::size_t k = 0; k < inner; ++k)
        C(row, col) += A(k, row) * B(col, k);
}
204
211template <typename U, typename V>
212void pinv(U A, V P)
213{
214 static_assert(U::rank() == 2, "Must be rank 2");
215 static_assert(V::rank() == 2, "Must be rank 2");
216
217 assert(A.extent(0) > A.extent(1));
218 assert(P.extent(1) == A.extent(0));
219 assert(P.extent(0) == A.extent(1));
220 using T = typename U::value_type;
221 if (A.extent(1) == 2)
222 {
223 std::array<T, 6> ATb;
224 std::array<T, 4> ATAb, Invb;
225 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
226 T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 3>>
227 AT(ATb.data(), 2, 3);
228 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
229 T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 2>>
230 ATA(ATAb.data(), 2, 2);
231 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
232 T, MDSPAN_IMPL_STANDARD_NAMESPACE::extents<std::size_t, 2, 2>>
233 Inv(Invb.data(), 2, 2);
234
235 for (std::size_t i = 0; i < AT.extent(0); ++i)
236 for (std::size_t j = 0; j < AT.extent(1); ++j)
237 AT(i, j) = A(j, i);
238
239 std::ranges::fill(ATAb, 0.0);
240 for (std::size_t i = 0; i < P.extent(0); ++i)
241 for (std::size_t j = 0; j < P.extent(1); ++j)
242 P(i, j) = 0;
243
244 // pinv(A) = (A^T * A)^-1 * A^T
245 dot(AT, A, ATA);
246 inv(ATA, Inv);
247 dot(Inv, AT, P);
248 }
249 else if (A.extent(1) == 1)
250 {
251 T res = 0;
252 for (std::size_t i = 0; i < A.extent(0); ++i)
253 for (std::size_t j = 0; j < A.extent(1); ++j)
254 res += A(i, j) * A(i, j);
255
256 for (std::size_t i = 0; i < A.extent(0); ++i)
257 for (std::size_t j = 0; j < A.extent(1); ++j)
258 P(j, i) = (1 / res) * A(i, j);
259 }
260 else
261 {
262 throw std::runtime_error("math::pinv is not implemented for "
263 + std::to_string(A.extent(0)) + "x"
264 + std::to_string(A.extent(1)) + " matrices.");
265 }
266}
267
268} // namespace dolfinx::math