/// @brief Apply `op` to the stored entries of a blocked CSR matrix that are
/// addressed by a dense array of input values.
///
/// Each stored column index owns `BS0 * BS1` consecutive values in `data`
/// (entry `(i, j)` of the block lives at offset `i * BS1 + j`). The input `x`
/// is read as `xrows.size() * BS0` scalar rows of `xcols.size() * BS1`
/// values each.
///
/// NOTE(review): this block was restored from a listing in which braces and
/// some statements were missing; the declaration of `row`, the condition of
/// the row-range check, and the per-`i` strides were reconstructed from the
/// surrounding indexing and should be confirmed against the upstream source.
///
/// @tparam BS0 Number of rows in each block.
/// @tparam BS1 Number of columns in each block.
/// @param[in,out] data Matrix values; element references are passed to `op`.
/// @param[in] cols Column index of each stored block. Searched with
/// `std::lower_bound`, so indices within a row must be sorted ascending.
/// @param[in] row_ptr Offset into `cols` of the first entry of each row;
/// `row_ptr[row + 1]` is one past the last entry of `row`.
/// @param[in] x Input values, `xrows.size() * xcols.size() * BS0 * BS1` of
/// them.
/// @param[in] xrows Block row indices to visit.
/// @param[in] xcols Block column indices to visit.
/// @param[in] op Callable invoked as `op(data[k], x[l])` for each entry.
/// @param[in] num_rows Exclusive upper bound for the values in `xrows`.
/// @throws std::runtime_error If a row index is `>= num_rows` or a requested
/// column is not stored in the row.
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                const Y& xrows, const Y& xcols, OP op,
                [[maybe_unused]] typename Y::value_type num_rows)
{
  const std::size_t nc = xcols.size();
  assert(x.size() == xrows.size() * xcols.size() * BS0 * BS1);
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // Row index and start of the input values belonging to this block row
    auto row = xrows[r];
    using T = typename X::value_type;
    const T* xr = x.data() + r * nc * BS0 * BS1;

    // Reject the row before it is used to index row_ptr
    if (row >= num_rows)
      throw std::runtime_error("Local row out of range");

    // Range of column indices stored for this row
    auto cit0 = std::next(cols.begin(), row_ptr[row]);
    auto cit1 = std::next(cols.begin(), row_ptr[row + 1]);
    for (std::size_t c = 0; c < nc; ++c)
    {
      // Binary search for the column index within the row
      auto it = std::lower_bound(cit0, cit1, xcols[c]);
      if (it == cit1 || *it != xcols[c])
        throw std::runtime_error("Entry not in sparsity");

      const auto d = static_cast<std::size_t>(std::distance(cols.begin(), it));
      const std::size_t di = d * BS0 * BS1; // first value of the stored block
      const std::size_t xi = c * BS1;       // first input value of the block
      assert(di < data.size());
      for (int i = 0; i < BS0; ++i)
      {
        // Block row i: BS1 values further on in data, one full scalar row
        // (nc * BS1 values) further on in the input
        const std::size_t drow = di + static_cast<std::size_t>(i) * BS1;
        const std::size_t xrow = xi + static_cast<std::size_t>(i) * nc * BS1;
        for (int j = 0; j < BS1; ++j)
          op(data[drow + j], xr[xrow + j]);
      }
    }
  }
}
/// @brief Apply `op` to the entries of a CSR matrix with scalar (unblocked)
/// storage, using input that is indexed in `BS0` x `BS1` blocks.
///
/// Block row `xrows[r]` covers matrix rows `xrows[r] * BS0 + i` for
/// `i = 0..BS0-1`. In each of those rows the column `xcols[c] * BS1` is
/// looked up, and the `BS1` values stored from that position onwards in
/// `data` are updated.
///
/// NOTE(review): this block was restored from a listing in which braces were
/// missing and the condition of the row-range check was lost; the condition
/// `row >= num_rows` follows the equivalent check in
/// `insert_nonblocked_csr` and should be confirmed against the upstream
/// source.
///
/// @tparam BS0 Number of rows in each input block.
/// @tparam BS1 Number of columns in each input block.
/// @param[in,out] data Matrix values, one per stored column index; element
/// references are passed to `op`.
/// @param[in] cols Column index of each stored entry. Searched with
/// `std::lower_bound`, so indices within a row must be sorted ascending.
/// @param[in] row_ptr Offset into `cols` of the first entry of each row.
/// @param[in] x Input values, `xrows.size() * xcols.size() * BS0 * BS1` of
/// them, read as `xrows.size() * BS0` rows of `xcols.size() * BS1` values.
/// @param[in] xrows Block row indices to visit.
/// @param[in] xcols Block column indices to visit.
/// @param[in] op Callable invoked as `op(data[k], x[l])` for each entry.
/// @param[in] num_rows Exclusive upper bound for `xrows[r] * BS0`.
/// @throws std::runtime_error If `xrows[r] * BS0 >= num_rows` or the first
/// column of a requested block is not stored in the row.
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_blocked_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                        const Y& xrows, const Y& xcols, OP op,
                        [[maybe_unused]] typename Y::value_type num_rows)
{
  const std::size_t nc = xcols.size();
  assert(x.size() == xrows.size() * xcols.size() * BS0 * BS1);
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // First matrix row covered by this block row
    auto row = xrows[r] * BS0;

    // Reject the row before it is used to index row_ptr
    if (row >= num_rows)
      throw std::runtime_error("Local row out of range");

    for (int i = 0; i < BS0; ++i)
    {
      // Input values for matrix row (row + i)
      using T = typename X::value_type;
      const T* xr = x.data() + (r * BS0 + i) * nc * BS1;

      // Range of column indices stored for matrix row (row + i)
      auto cit0 = std::next(cols.begin(), row_ptr[row + i]);
      auto cit1 = std::next(cols.begin(), row_ptr[row + i + 1]);
      for (std::size_t c = 0; c < nc; ++c)
      {
        // Binary search for the first column of the block
        const auto first_col = xcols[c] * BS1;
        auto it = std::lower_bound(cit0, cit1, first_col);
        if (it == cit1 || *it != first_col)
          throw std::runtime_error("Entry not in sparsity");

        const auto d
            = static_cast<std::size_t>(std::distance(cols.begin(), it));
        assert(d < data.size());
        const std::size_t xi = c * BS1;
        // The BS1 values starting at d are updated as one run; only the
        // first column of the block is checked against the sparsity.
        for (int j = 0; j < BS1; ++j)
          op(data[d + j], xr[xi + j]);
      }
    }
  }
}
/// @brief Apply `op` to individual scalar entries of a blocked CSR matrix,
/// using scalar (unblocked) row and column indices.
///
/// A scalar row index is split into a block row (`index / bs0`) and a row
/// within the block (`index % bs0`); columns are split by `bs1` in the same
/// way. The value updated is
/// `data[d * bs0 * bs1 + (row % bs0) * bs1 + (col % bs1)]`, where `d` is the
/// position of the block column in `cols`.
///
/// NOTE(review): this block was restored from a listing in which the
/// `typename Y` template parameter, the trailing `bs0`/`bs1` parameters, the
/// braces and the final call to `op` were missing. They were reconstructed
/// from how `Y`, `bs0`, `bs1`, `di` and `xr` are used in the visible lines;
/// confirm parameter types and order against the upstream source.
///
/// @param[in,out] data Matrix values, `bs0 * bs1` per stored block; element
/// references are passed to `op`.
/// @param[in] cols Block column index of each stored block. Searched with
/// `std::lower_bound`, so indices within a row must be sorted ascending.
/// @param[in] row_ptr Offset into `cols` of the first block of each block
/// row.
/// @param[in] x Input values, `xrows.size() * xcols.size()` of them, one row
/// of `xcols.size()` values per entry of `xrows`.
/// @param[in] xrows Scalar row indices to visit.
/// @param[in] xcols Scalar column indices to visit.
/// @param[in] op Callable invoked as `op(data[k], x[l])` for each entry.
/// @param[in] num_rows Exclusive upper bound for the block row index
/// `xrows[r] / bs0`.
/// @param[in] bs0 Number of rows in each stored block.
/// @param[in] bs1 Number of columns in each stored block.
/// @throws std::runtime_error If a block row index is `>= num_rows` or a
/// requested block column is not stored in the block row.
template <typename OP, typename U, typename V, typename W, typename X,
          typename Y>
void insert_nonblocked_csr(U&& data, const V& cols, const W& row_ptr,
                           const X& x, const Y& xrows, const Y& xcols, OP op,
                           [[maybe_unused]] typename Y::value_type num_rows,
                           int bs0, int bs1)
{
  const std::size_t nc = xcols.size();
  const int nbs = bs0 * bs1;

  assert(x.size() == xrows.size() * xcols.size());
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // Block row (quot) and row within the block (rem)
    auto rdiv = std::div(xrows[r], bs0);
    using T = typename X::value_type;
    const T* xr = x.data() + r * nc;

    // Reject the block row before it is used to index row_ptr
    if (rdiv.quot >= num_rows)
      throw std::runtime_error("Local row out of range");

    // Range of block column indices stored for this block row
    auto cit0 = std::next(cols.begin(), row_ptr[rdiv.quot]);
    auto cit1 = std::next(cols.begin(), row_ptr[rdiv.quot + 1]);
    for (std::size_t c = 0; c < nc; ++c)
    {
      // Block column (quot) and column within the block (rem)
      auto cdiv = std::div(xcols[c], bs1);
      auto it = std::lower_bound(cit0, cit1, cdiv.quot);
      if (it == cit1 || *it != cdiv.quot)
        throw std::runtime_error("Entry not in sparsity");

      const auto d = static_cast<std::size_t>(std::distance(cols.begin(), it));
      const std::size_t di = d * nbs + rdiv.rem * bs1 + cdiv.rem;
      assert(di < data.size());
      op(data[di], xr[c]);
    }
  }
}
225template <
typename T,
int BS1>
226void spmv(std::span<const T> values, std::span<const std::int64_t> row_begin,
227 std::span<const std::int64_t> row_end,
228 std::span<const std::int32_t> indices, std::span<const T> x,
229 std::span<T> y,
int bs0,
int bs1)
231 assert(row_begin.size() == row_end.size());
232 for (
int k0 = 0; k0 < bs0; ++k0)
234 for (std::size_t i = 0; i < row_begin.size(); i++)
237 for (std::int32_t j = row_begin[i]; j < row_end[i]; j++)
239 if constexpr (BS1 == -1)
241 for (
int k1 = 0; k1 < bs1; ++k1)
243 vi += values[j * bs1 * bs0 + k1 * bs0 + k0]
244 * x[indices[j] * bs1 + k1];
249 for (
int k1 = 0; k1 < BS1; ++k1)
251 vi += values[j * BS1 * bs0 + k1 * bs0 + k0]
252 * x[indices[j] * BS1 + k1];
257 y[i * bs0 + k0] += vi;
// Cross-reference (Doxygen tooltip): "Linear algebra interface."
// Definition: sparsitybuild.h:15