/// @brief Apply `op` to entries of a block CSR value array using a dense
/// block of input values.
///
/// The matrix is addressed block-wise: each stored column index owns
/// `BS0 * BS1` consecutive values in `data`.
///
/// @tparam BS0 Number of rows in each block.
/// @tparam BS1 Number of columns in each block.
/// @tparam OP Callable invoked as `op(data_entry, x_entry)`.
/// @param[in,out] data Matrix values, updated through `op`.
/// @param[in] cols Column indices; the indices of each row are searched
/// with a binary search, so they are expected to be sorted within a row.
/// @param[in] row_ptr Offsets into `cols`; row `i` occupies
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Dense input values; its size must equal
/// `xrows.size() * xcols.size() * BS0 * BS1`.
/// @param[in] xrows Row indices associated with `x`.
/// @param[in] xcols Column indices associated with `x`.
/// @param[in] op Operation combining a stored value with an input value.
/// @param[in] num_rows Upper bound (exclusive) on the row indices that
/// may be modified.
/// @throws std::runtime_error If a requested column is not present in
/// the row ("Entry not in sparsity") or a row is out of range ("Local
/// row out of range").
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                const Y& xrows, const Y& xcols, OP op,
                typename Y::value_type num_rows);
/// @brief Apply `op` to entries of a scalar (unblocked) CSR value array
/// using block-indexed dense input.
///
/// Block row `xrows[r]` maps to matrix rows `xrows[r] * BS0 + i` for
/// `i` in `[0, BS0)`, and block column `xcols[c]` maps to `BS1`
/// consecutive stored entries starting at column `xcols[c] * BS1`.
///
/// @tparam BS0 Row block size of the input indices.
/// @tparam BS1 Column block size of the input indices.
/// @tparam OP Callable invoked as `op(data_entry, x_entry)`.
/// @param[in,out] data Matrix values, updated through `op`.
/// @param[in] cols Column indices; the indices of each row are searched
/// with a binary search, so they are expected to be sorted within a row.
/// @param[in] row_ptr Offsets into `cols`; row `i` occupies
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Dense input values; its size must equal
/// `xrows.size() * xcols.size() * BS0 * BS1`.
/// @param[in] xrows Block row indices associated with `x`.
/// @param[in] xcols Block column indices associated with `x`.
/// @param[in] op Operation combining a stored value with an input value.
/// @param[in] num_rows Upper bound (exclusive) on the row indices that
/// may be modified.
/// @throws std::runtime_error If a requested column is not present in
/// the row ("Entry not in sparsity") or a row is out of range ("Local
/// row out of range").
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_blocked_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                        const Y& xrows, const Y& xcols, OP op,
                        typename Y::value_type num_rows);
/// @brief Apply `op` to entries of a block CSR value array using dense
/// input addressed by scalar (unblocked) row and column indices.
///
/// Each scalar index is split into a block index (quotient) and an
/// offset inside the block (remainder) using the run-time block sizes.
///
/// @tparam OP Callable invoked as `op(data_entry, x_entry)`.
/// @param[in,out] data Matrix values, updated through `op`.
/// @param[in] cols Block column indices; the indices of each row are
/// searched with a binary search, so they are expected to be sorted
/// within a row.
/// @param[in] row_ptr Offsets into `cols`; block row `i` occupies
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Dense input values; its size must equal
/// `xrows.size() * xcols.size()`.
/// @param[in] xrows Scalar row indices associated with `x`.
/// @param[in] xcols Scalar column indices associated with `x`.
/// @param[in] op Operation combining a stored value with an input value.
/// @param[in] num_rows Upper bound (exclusive) on the block row indices
/// that may be modified.
/// @param[in] bs0 Row block size of the matrix storage.
/// @param[in] bs1 Column block size of the matrix storage.
/// @throws std::runtime_error If a requested column is not present in
/// the row ("Entry not in sparsity") or a row is out of range ("Local
/// row out of range").
template <typename OP, typename U, typename V, typename W, typename X,
          typename Y>
void insert_nonblocked_csr(U&& data, const V& cols, const W& row_ptr,
                           const X& x, const Y& xrows, const Y& xcols, OP op,
                           typename Y::value_type num_rows, int bs0, int bs1);
106template <
int BS0,
int BS1,
typename OP,
typename U,
typename V,
typename W,
107 typename X,
typename Y>
108void impl::insert_csr(U&& data,
const V& cols,
const W& row_ptr,
const X& x,
109 const Y& xrows,
const Y& xcols, OP op,
110 [[maybe_unused]]
typename Y::value_type num_rows)
112 const std::size_t nc = xcols.size();
113 assert(x.size() == xrows.size() * xcols.size() * BS0 * BS1);
114 for (std::size_t r = 0; r < xrows.size(); ++r)
118 using T =
typename X::value_type;
119 const T* xr = x.data() + r * nc * BS0 * BS1;
123 throw std::runtime_error(
"Local row out of range");
126 auto cit0 = std::next(cols.begin(), row_ptr[row]);
127 auto cit1 = std::next(cols.begin(), row_ptr[row + 1]);
128 for (std::size_t c = 0; c < nc; ++c)
131 auto it = std::lower_bound(cit0, cit1, xcols[c]);
133 if (it == cit1 or *it != xcols[c])
134 throw std::runtime_error(
"Entry not in sparsity");
136 std::size_t d = std::distance(cols.begin(), it);
137 int di = d * BS0 * BS1;
139 assert(di < data.size());
140 for (
int i = 0; i < BS0; ++i)
142 for (
int j = 0; j < BS1; ++j)
143 op(data[di + j], xr[xi + j]);
152template <
int BS0,
int BS1,
typename OP,
typename U,
typename V,
typename W,
153 typename X,
typename Y>
154void impl::insert_blocked_csr(U&& data,
const V& cols,
const W& row_ptr,
155 const X& x,
const Y& xrows,
const Y& xcols, OP op,
156 [[maybe_unused]]
typename Y::value_type num_rows)
158 const std::size_t nc = xcols.size();
159 assert(x.size() == xrows.size() * xcols.size() * BS0 * BS1);
160 for (std::size_t r = 0; r < xrows.size(); ++r)
163 auto row = xrows[r] * BS0;
167 throw std::runtime_error(
"Local row out of range");
170 for (
int i = 0; i < BS0; ++i)
172 using T =
typename X::value_type;
173 const T* xr = x.data() + (r * BS0 + i) * nc * BS1;
175 auto cit0 = std::next(cols.begin(), row_ptr[row + i]);
176 auto cit1 = std::next(cols.begin(), row_ptr[row + i + 1]);
177 for (std::size_t c = 0; c < nc; ++c)
180 auto it = std::lower_bound(cit0, cit1, xcols[c] * BS1);
182 if (it == cit1 or *it != xcols[c] * BS1)
183 throw std::runtime_error(
"Entry not in sparsity");
185 std::size_t d = std::distance(cols.begin(), it);
186 assert(d < data.size());
188 for (
int j = 0; j < BS1; ++j)
189 op(data[d + j], xr[xi + j]);
196template <
typename OP,
typename U,
typename V,
typename W,
typename X,
198void impl::insert_nonblocked_csr(U&& data,
const V& cols,
const W& row_ptr,
199 const X& x,
const Y& xrows,
const Y& xcols,
202 typename Y::value_type num_rows,
205 const std::size_t nc = xcols.size();
206 const int nbs = bs0 * bs1;
208 assert(x.size() == xrows.size() * xcols.size());
209 for (std::size_t r = 0; r < xrows.size(); ++r)
212 auto rdiv = std::div(xrows[r], bs0);
213 using T =
typename X::value_type;
214 const T* xr = x.data() + r * nc;
217 if (rdiv.quot >= num_rows)
218 throw std::runtime_error(
"Local row out of range");
221 auto cit0 = std::next(cols.begin(), row_ptr[rdiv.quot]);
222 auto cit1 = std::next(cols.begin(), row_ptr[rdiv.quot + 1]);
223 for (std::size_t c = 0; c < nc; ++c)
226 auto cdiv = std::div(xcols[c], bs1);
227 auto it = std::lower_bound(cit0, cit1, cdiv.quot);
229 if (it == cit1 or *it != cdiv.quot)
230 throw std::runtime_error(
"Entry not in sparsity");
232 std::size_t d = std::distance(cols.begin(), it);
233 const int di = d * nbs + rdiv.rem * bs1 + cdiv.rem;
234 assert(di < data.size());
// Linear algebra interface.
// Definition: sparsitybuild.h:15