/// @brief Incorporate data into a CSR matrix.
///
/// The matrix and the input data share the same block size
/// `BS0 x BS1`. Every affected matrix entry is updated in place by
/// calling `op(data_entry, x_entry)`, so the same routine serves for
/// both "set" and "add" style insertion.
///
/// @tparam BS0 Row block size of the matrix and of the input data.
/// @tparam BS1 Column block size of the matrix and of the input data.
/// @tparam OP Binary update functor, called as `op(data[k], value)`.
/// @param[in,out] data Matrix values, stored block by block
/// (`BS0 * BS1` values per stored block, row-major within a block).
/// @param[in] cols Block column index of each stored block. Must be
/// sorted within each block row because lookup uses `std::lower_bound`.
/// @param[in] row_ptr Offsets into `cols`; block row `i` occupies
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Dense values to insert, row-major with shape
/// `(xrows.size() * BS0) x (xcols.size() * BS1)`.
/// @param[in] xrows Block row indices of the block rows of `x`.
/// @param[in] xcols Block column indices of the block columns of `x`.
/// @param[in] op Update operation applied to each destination entry.
/// @param[in] num_rows Number of block rows in the matrix. Only used for
/// a range check in debug builds.
/// @throws std::runtime_error if a (row, column) block of `x` is not in
/// the sparsity pattern, or (debug builds only) if a row is out of range.
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                const Y& xrows, const Y& xcols, OP op,
                [[maybe_unused]] typename Y::value_type num_rows)
{
  using T = typename X::value_type;
  const std::size_t nc = xcols.size();
  assert(x.size() == xrows.size() * xcols.size() * BS0 * BS1);
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // Block row index, and the first of the BS0 rows of x it maps to
    auto row = xrows[r];
    const T* xr = x.data() + r * nc * BS0 * BS1;

#ifndef NDEBUG
    if (row >= num_rows)
      throw std::runtime_error("Local row out of range");
#endif

    // Range of stored block columns in this block row
    auto cit0 = std::next(cols.begin(), row_ptr[row]);
    auto cit1 = std::next(cols.begin(), row_ptr[row + 1]);
    for (std::size_t c = 0; c < nc; ++c)
    {
      // Locate the block column (cols is sorted within each row)
      auto it = std::lower_bound(cit0, cit1, xcols[c]);
      if (it == cit1 or *it != xcols[c])
        throw std::runtime_error("Entry not in sparsity");

      std::size_t d = std::distance(cols.begin(), it);
      std::size_t di = d * BS0 * BS1;
      std::size_t xi = c * BS1;
      assert(di < data.size());
      for (int i = 0; i < BS0; ++i)
      {
        for (int j = 0; j < BS1; ++j)
          op(data[di + j], xr[xi + j]);

        // Move to the next row of the block: stride BS1 inside the stored
        // block, stride nc * BS1 in the row-major input x
        di += BS1;
        xi += nc * BS1;
      }
    }
  }
}
/// @brief Incorporate blocked data with given block sizes into a
/// non-blocked MatrixCSR.
///
/// Each `BS0 x BS1` block of `x` is scattered entry by entry into a
/// scalar (block size 1) CSR matrix. Every affected entry is updated by
/// calling `op(data_entry, x_entry)`.
///
/// @tparam BS0 Row block size of the input data.
/// @tparam BS1 Column block size of the input data.
/// @tparam OP Binary update functor, called as `op(data[k], value)`.
/// @param[in,out] data Values of the scalar CSR matrix.
/// @param[in] cols Column index of each stored value. Must be sorted
/// within each row because lookup uses `std::lower_bound`.
/// @param[in] row_ptr Offsets into `cols`; row `i` occupies
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Dense values, row-major with shape
/// `(xrows.size() * BS0) x (xcols.size() * BS1)`.
/// @param[in] xrows Block row indices; block row `b` covers matrix rows
/// `[b * BS0, (b + 1) * BS0)`.
/// @param[in] xcols Block column indices; block column `b` covers matrix
/// columns `[b * BS1, (b + 1) * BS1)`.
/// @param[in] op Update operation applied to each destination entry.
/// @param[in] num_rows Number of rows of the matrix. Only used for a range
/// check in debug builds.
/// @throws std::runtime_error if the first column of a block is not in
/// the sparsity pattern, or (debug builds only) if a block row extends
/// past the last matrix row.
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_blocked_csr(U&& data, const V& cols, const W& row_ptr,
                        const X& x, const Y& xrows, const Y& xcols, OP op,
                        [[maybe_unused]] typename Y::value_type num_rows)
{
  using T = typename X::value_type;
  const std::size_t nc = xcols.size();
  assert(x.size() == xrows.size() * xcols.size() * BS0 * BS1);
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // First (unblocked) matrix row covered by block row xrows[r]
    auto row = xrows[r] * BS0;

#ifndef NDEBUG
    // The block touches matrix rows [row, row + BS0); all must exist,
    // since row_ptr[row + i + 1] is read for every i < BS0 below
    if (row + BS0 > num_rows)
      throw std::runtime_error("Local row out of range");
#endif

    for (int i = 0; i < BS0; ++i)
    {
      // Row (r * BS0 + i) of the row-major input x
      const T* xr = x.data() + (r * BS0 + i) * nc * BS1;

      // Range of stored columns in matrix row (row + i)
      auto cit0 = std::next(cols.begin(), row_ptr[row + i]);
      auto cit1 = std::next(cols.begin(), row_ptr[row + i + 1]);
      for (std::size_t c = 0; c < nc; ++c)
      {
        // Locate the first scalar column of block column xcols[c]
        auto it = std::lower_bound(cit0, cit1, xcols[c] * BS1);
        if (it == cit1 or *it != xcols[c] * BS1)
          throw std::runtime_error("Entry not in sparsity");

        std::size_t d = std::distance(cols.begin(), it);
        assert(d < data.size());
        std::size_t xi = c * BS1;

        // NOTE: only the first column of the block is looked up; the
        // remaining BS1 - 1 columns are assumed to follow it contiguously
        // in this row of the sparsity pattern
        for (int j = 0; j < BS1; ++j)
          op(data[d + j], xr[xi + j]);
      }
    }
  }
}
/// @brief Incorporate non-blocked data into a blocked matrix (data
/// block size=1).
///
/// Scalar row/column indices of `x` are split into (block, offset) pairs
/// using the matrix block sizes `bs0` and `bs1`. Each value is applied to
/// the matching entry of the stored block with `op(data_entry, value)`.
///
/// @tparam OP Binary update functor, called as `op(data[k], value)`.
/// @param[in,out] data Matrix values, `bs0 * bs1` per stored block,
/// row-major within a block.
/// @param[in] cols Block column index of each stored block. Must be
/// sorted within each block row because lookup uses `std::lower_bound`.
/// @param[in] row_ptr Offsets into `cols`; block row `i` occupies
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Dense values, row-major with shape
/// `xrows.size() x xcols.size()`.
/// @param[in] xrows Scalar (unblocked) row indices.
/// @param[in] xcols Scalar (unblocked) column indices.
/// @param[in] op Update operation applied to each destination entry.
/// @param[in] num_rows Number of block rows in the matrix. Only used for a
/// range check in debug builds.
/// @param[in] bs0 Row block size of the matrix.
/// @param[in] bs1 Column block size of the matrix.
/// @throws std::runtime_error if an entry is not in the sparsity pattern,
/// or (debug builds only) if a row is out of range.
template <typename OP, typename U, typename V, typename W, typename X,
          typename Y>
void insert_nonblocked_csr(U&& data, const V& cols, const W& row_ptr,
                           const X& x, const Y& xrows, const Y& xcols, OP op,
                           [[maybe_unused]] typename Y::value_type num_rows,
                           int bs0, int bs1)
{
  using T = typename X::value_type;
  const std::size_t nc = xcols.size();
  const int nbs = bs0 * bs1;

  assert(x.size() == xrows.size() * xcols.size());
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // Split the scalar row into (block row, row offset within block)
    auto rdiv = std::div(xrows[r], bs0);
    const T* xr = x.data() + r * nc;

#ifndef NDEBUG
    if (rdiv.quot >= num_rows)
      throw std::runtime_error("Local row out of range");
#endif

    // Range of stored block columns in this block row
    auto cit0 = std::next(cols.begin(), row_ptr[rdiv.quot]);
    auto cit1 = std::next(cols.begin(), row_ptr[rdiv.quot + 1]);
    for (std::size_t c = 0; c < nc; ++c)
    {
      // Split the scalar column into (block column, column offset)
      auto cdiv = std::div(xcols[c], bs1);
      auto it = std::lower_bound(cit0, cit1, cdiv.quot);
      if (it == cit1 or *it != cdiv.quot)
        throw std::runtime_error("Entry not in sparsity");

      // Position of the entry inside the row-major stored block
      std::size_t d = std::distance(cols.begin(), it);
      std::size_t di = d * nbs + rdiv.rem * bs1 + cdiv.rem;
      assert(di < data.size());
      op(data[di], xr[c]);
    }
  }
}
225template <
typename T,
int BS1>
226void spmv(std::span<const T> values, std::span<const std::int64_t> row_begin,
227 std::span<const std::int64_t> row_end,
228 std::span<const std::int32_t> indices, std::span<const T> x,
229 std::span<T> y,
int bs0,
int bs1)
231 assert(row_begin.size() == row_end.size());
235 for (
int k0 = 0; k0 < bs0; ++k0)
237 for (std::size_t i = 0; i < row_begin.size(); i++)
240 for (std::int64_t j = row_begin[i]; j < row_end[i]; j++)
242 if constexpr (BS1 == -1)
244 for (
int k1 = 0; k1 < bs1; ++k1)
246 vi += values[j * bs0 * bs1 + k0 * bs1 + k1]
247 * x[indices[j] * bs1 + k1];
252 for (
int k1 = 0; k1 < BS1; ++k1)
254 vi += values[j * bs0 * BS1 + k0 * BS1 + k1]
255 * x[indices[j] * BS1 + k1];
260 y[i * bs0 + k0] += vi;
298template <
typename T,
int BS1>
299void spmvT(std::span<const T> values, std::span<const std::int64_t> row_begin,
300 std::span<const std::int64_t> row_end,
301 std::span<const std::int32_t> indices, std::span<const T> x,
302 std::span<T> y,
int bs0,
int bs1)
304 assert(row_begin.size() == row_end.size());
307 for (
int k0 = 0; k0 < bs0; ++k0)
309 for (std::size_t i = 0; i < row_begin.size(); i++)
311 const T xval = x[i * bs0 + k0];
312 for (std::int64_t j = row_begin[i]; j < row_end[i]; j++)
314 if constexpr (BS1 == -1)
316 for (
int k1 = 0; k1 < bs1; ++k1)
318 y[indices[j] * bs1 + k1]
319 += values[j * bs0 * bs1 + k0 * bs1 + k1] * xval;
324 for (
int k1 = 0; k1 < BS1; ++k1)
326 y[indices[j] * BS1 + k1]
327 += values[j * bs0 * BS1 + k0 * BS1 + k1] * xval;
Fetch the rows of B that correspond to the ghost columns of A.
Definition matmul.h:36
void insert_csr(U &&data, const V &cols, const W &row_ptr, const X &x, const Y &xrows, const Y &xcols, OP op, typename Y::value_type num_rows)
Incorporate data into a CSR matrix.
Definition matrix_csr_impl.h:49
void spmvT(std::span< const T > values, std::span< const std::int64_t > row_begin, std::span< const std::int64_t > row_end, std::span< const std::int32_t > indices, std::span< const T > x, std::span< T > y, int bs0, int bs1)
Sparse matrix-vector transpose product implementation.
Definition matrix_csr_impl.h:299
void insert_blocked_csr(U &&data, const V &cols, const W &row_ptr, const X &x, const Y &xrows, const Y &xcols, OP op, typename Y::value_type num_rows)
Incorporate blocked data with given block sizes into a non-blocked MatrixCSR.
Definition matrix_csr_impl.h:114
void insert_nonblocked_csr(U &&data, const V &cols, const W &row_ptr, const X &x, const Y &xrows, const Y &xcols, OP op, typename Y::value_type num_rows, int bs0, int bs1)
Incorporate non-blocked data into a blocked matrix (data block size=1).
Definition matrix_csr_impl.h:175
void spmv(std::span< const T > values, std::span< const std::int64_t > row_begin, std::span< const std::int64_t > row_end, std::span< const std::int32_t > indices, std::span< const T > x, std::span< T > y, int bs0, int bs1)
Sparse matrix-vector product implementation.
Definition matrix_csr_impl.h:226
Linear algebra interface.
Definition dolfinx_la.h:7