DOLFINx 0.10.0.0
DOLFINx C++ interface
Loading...
Searching...
No Matches
matrix_csr_impl.h
1// Copyright (C) 2021-2023 Garth N. Wells and Chris N. Richardson
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
9#include <numeric>
10#include <span>
11#include <utility>
12#include <vector>
13
14namespace dolfinx::la
15{
16namespace impl
17{
/// @brief Apply `op` to stored CSR values and a dense array of input
/// values, where both use blocks of size BS0 x BS1.
///
/// For every pair (xrows[r], xcols[c]) the column index xcols[c] is
/// located by binary search in the column range of row xrows[r], i.e.
/// in positions [row_ptr[row], row_ptr[row + 1]) of `cols`. `op` is then
/// called as `op(stored_value, input_value)` for each of the BS0 * BS1
/// values of that block.
///
/// @tparam BS0 Number of rows in a block.
/// @tparam BS1 Number of columns in a block.
/// @param[in,out] data Stored values, passed as the first argument of
/// `op`. The entry at position `d` of `cols` owns the BS0 * BS1
/// consecutive values starting at `d * BS0 * BS1`, with BS1 values per
/// block row.
/// @param[in] cols Column index of each stored entry. The indices of a
/// row must be sorted, since they are searched with `std::lower_bound`.
/// @param[in] row_ptr Offset into `cols` of the first entry of each row;
/// `row_ptr[row + 1]` is read as the end of row `row`.
/// @param[in] x Input values, of size
/// `xrows.size() * xcols.size() * BS0 * BS1`. For each entry of `xrows`
/// there are BS0 consecutive lines of `xcols.size() * BS1` values.
/// @param[in] xrows Row indices to combine into.
/// @param[in] xcols Column indices to combine into.
/// @param[in] op Callable invoked as `op(data[k], x_value)`.
/// @param[in] num_rows Upper bound (exclusive) for entries of `xrows`;
/// only checked when `NDEBUG` is not defined.
/// @throws std::runtime_error If a column index is not present in the
/// row ("Entry not in sparsity") or, when `NDEBUG` is not defined, if a
/// row index is `>= num_rows` ("Local row out of range").
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                const Y& xrows, const Y& xcols, OP op,
                [[maybe_unused]] typename Y::value_type num_rows)
{
  using scalar_t = typename X::value_type;

  const std::size_t num_xrows = xrows.size();
  const std::size_t num_xcols = xcols.size();
  assert(x.size() == num_xrows * num_xcols * BS0 * BS1);

  for (std::size_t r = 0; r < num_xrows; ++r)
  {
    const auto row = xrows[r];

#ifndef NDEBUG
    if (row >= num_rows)
      throw std::runtime_error("Local row out of range");
#endif

    // Input values that belong to this row of blocks
    const scalar_t* x_row = x.data() + r * num_xcols * BS0 * BS1;

    // Range of column indices stored for this row
    const auto row_first = std::next(cols.begin(), row_ptr[row]);
    const auto row_last = std::next(cols.begin(), row_ptr[row + 1]);

    for (std::size_t c = 0; c < num_xcols; ++c)
    {
      // Binary search for the requested column within the row
      const auto pos = std::lower_bound(row_first, row_last, xcols[c]);
      const bool found = (pos != row_last) and (*pos == xcols[c]);
      if (!found)
        throw std::runtime_error("Entry not in sparsity");

      // Offsets of the first value of the block in `data` and in `x_row`
      const std::size_t entry = std::distance(cols.begin(), pos);
      const std::size_t dst0 = entry * BS0 * BS1;
      const std::size_t src0 = c * BS1;
      assert(dst0 < data.size());

      for (int i = 0; i < BS0; ++i)
      {
        // Block row i: contiguous in `data`, strided by a full line of
        // input values in `x_row`
        const std::size_t dst = dst0 + i * BS1;
        const std::size_t src = src0 + i * num_xcols * BS1;
        for (int j = 0; j < BS1; ++j)
          op(data[dst + j], x_row[src + j]);
      }
    }
  }
}
90
/// @brief Apply `op` to stored CSR values and a dense array of input
/// values given in BS0 x BS1 blocks, where the CSR storage is indexed
/// per scalar row and column (one value per stored entry).
///
/// Block row xrows[r] covers the rows `xrows[r] * BS0 + i` for
/// `0 <= i < BS0`. For each such row and each xcols[c], the column
/// `xcols[c] * BS1` is located by binary search in the row, and `op` is
/// applied to the BS1 consecutive stored values that start at the
/// located entry.
///
/// @note Only the first column of each block is searched for. The code
/// does not check that the following BS1 - 1 stored entries hold the
/// remaining columns of the block.
///
/// @tparam BS0 Number of rows in an input block.
/// @tparam BS1 Number of columns in an input block.
/// @param[in,out] data Stored values (one per entry of `cols`), passed
/// as the first argument of `op`.
/// @param[in] cols Column index of each stored entry. The indices of a
/// row must be sorted, since they are searched with `std::lower_bound`.
/// @param[in] row_ptr Offset into `cols` of the first entry of each row;
/// `row_ptr[row + 1]` is read as the end of row `row`.
/// @param[in] x Input values, of size
/// `xrows.size() * xcols.size() * BS0 * BS1`. For each entry of `xrows`
/// there are BS0 consecutive lines of `xcols.size() * BS1` values.
/// @param[in] xrows Block row indices.
/// @param[in] xcols Block column indices.
/// @param[in] op Callable invoked as `op(data[k], x_value)`.
/// @param[in] num_rows Upper bound (exclusive) for `xrows[r] * BS0`;
/// only checked when `NDEBUG` is not defined.
/// @throws std::runtime_error If the first column of a block is not
/// present in a row ("Entry not in sparsity") or, when `NDEBUG` is not
/// defined, if `xrows[r] * BS0 >= num_rows` ("Local row out of range").
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_blocked_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                        const Y& xrows, const Y& xcols, OP op,
                        [[maybe_unused]] typename Y::value_type num_rows)
{
  using scalar_t = typename X::value_type;

  const std::size_t num_xrows = xrows.size();
  const std::size_t num_xcols = xcols.size();
  assert(x.size() == num_xrows * num_xcols * BS0 * BS1);

  for (std::size_t r = 0; r < num_xrows; ++r)
  {
    // First row covered by this block row
    const auto row0 = xrows[r] * BS0;

#ifndef NDEBUG
    if (row0 >= num_rows)
      throw std::runtime_error("Local row out of range");
#endif

    for (int i = 0; i < BS0; ++i)
    {
      const auto row = row0 + i;

      // Line of input values for row i of the current block row
      const scalar_t* x_row = x.data() + (r * BS0 + i) * num_xcols * BS1;

      // Range of column indices stored for this row
      const auto row_first = std::next(cols.begin(), row_ptr[row]);
      const auto row_last = std::next(cols.begin(), row_ptr[row + 1]);

      for (std::size_t c = 0; c < num_xcols; ++c)
      {
        // Binary search for the first column of the block
        const auto col0 = xcols[c] * BS1;
        const auto pos = std::lower_bound(row_first, row_last, col0);
        const bool found = (pos != row_last) and (*pos == col0);
        if (!found)
          throw std::runtime_error("Entry not in sparsity");

        const std::size_t entry = std::distance(cols.begin(), pos);
        assert(entry < data.size());

        // The BS1 values of the block row go to consecutive entries
        const scalar_t* src = x_row + c * BS1;
        for (int j = 0; j < BS1; ++j)
          op(data[entry + j], src[j]);
      }
    }
  }
}
153
/// @brief Apply `op` to stored CSR values and a dense array of input
/// values given per scalar row and column, where the CSR storage uses
/// blocks of size bs0 x bs1.
///
/// Each row index xrows[r] is split with `std::div` into a block row
/// (quotient by bs0) and a row within the block (remainder); column
/// indices are split likewise by bs1. The block column is located by
/// binary search in the block row, and `op` is applied to the single
/// stored value at `d * bs0 * bs1 + row_rem * bs1 + col_rem`, where `d`
/// is the position of the located entry in `cols`.
///
/// @param[in,out] data Stored values, bs0 * bs1 per entry of `cols`,
/// passed as the first argument of `op`.
/// @param[in] cols Block column index of each stored entry. The indices
/// of a row must be sorted, since they are searched with
/// `std::lower_bound`.
/// @param[in] row_ptr Offset into `cols` of the first entry of each
/// block row; `row_ptr[row + 1]` is read as the end of block row `row`.
/// @param[in] x Input values, of size `xrows.size() * xcols.size()`,
/// with `xcols.size()` consecutive values per entry of `xrows`.
/// @param[in] xrows Scalar row indices.
/// @param[in] xcols Scalar column indices.
/// @param[in] op Callable invoked as `op(data[k], x_value)`.
/// @param[in] num_rows Upper bound (exclusive) for the block row
/// `xrows[r] / bs0`; only checked when `NDEBUG` is not defined.
/// @param[in] bs0 Number of rows in a stored block.
/// @param[in] bs1 Number of columns in a stored block.
/// @throws std::runtime_error If a block column is not present in the
/// block row ("Entry not in sparsity") or, when `NDEBUG` is not defined,
/// if the block row is `>= num_rows` ("Local row out of range").
template <typename OP, typename U, typename V, typename W, typename X,
          typename Y>
void insert_nonblocked_csr(U&& data, const V& cols, const W& row_ptr,
                           const X& x, const Y& xrows, const Y& xcols, OP op,
                           [[maybe_unused]] typename Y::value_type num_rows,
                           int bs0, int bs1)
{
  using scalar_t = typename X::value_type;

  const std::size_t num_xrows = xrows.size();
  const std::size_t num_xcols = xcols.size();
  const int block_size = bs0 * bs1;
  assert(x.size() == num_xrows * num_xcols);

  for (std::size_t r = 0; r < num_xrows; ++r)
  {
    // Block row (quot) and row within the block (rem)
    const auto row_split = std::div(xrows[r], bs0);

#ifndef NDEBUG
    if (row_split.quot >= num_rows)
      throw std::runtime_error("Local row out of range");
#endif

    // Input values for this row
    const scalar_t* x_row = x.data() + r * num_xcols;

    // Range of block column indices stored for this block row
    const auto row_first = std::next(cols.begin(), row_ptr[row_split.quot]);
    const auto row_last
        = std::next(cols.begin(), row_ptr[row_split.quot + 1]);

    for (std::size_t c = 0; c < num_xcols; ++c)
    {
      // Block column (quot) and column within the block (rem)
      const auto col_split = std::div(xcols[c], bs1);

      // Binary search for the block column within the block row
      const auto pos = std::lower_bound(row_first, row_last, col_split.quot);
      const bool found = (pos != row_last) and (*pos == col_split.quot);
      if (!found)
        throw std::runtime_error("Entry not in sparsity");

      // Position of the scalar value inside the located block
      const std::size_t entry = std::distance(cols.begin(), pos);
      const std::size_t offset
          = entry * block_size + row_split.rem * bs1 + col_split.rem;
      assert(offset < data.size());
      op(data[offset], x_row[c]);
    }
  }
}
213
225template <typename T, int BS1>
226void spmv(std::span<const T> values, std::span<const std::int64_t> row_begin,
227 std::span<const std::int64_t> row_end,
228 std::span<const std::int32_t> indices, std::span<const T> x,
229 std::span<T> y, int bs0, int bs1)
230{
231 assert(row_begin.size() == row_end.size());
232 for (int k0 = 0; k0 < bs0; ++k0)
233 {
234 for (std::size_t i = 0; i < row_begin.size(); i++)
235 {
236 T vi{0};
237 for (std::int32_t j = row_begin[i]; j < row_end[i]; j++)
238 {
239 if constexpr (BS1 == -1)
240 {
241 for (int k1 = 0; k1 < bs1; ++k1)
242 {
243 vi += values[j * bs1 * bs0 + k1 * bs0 + k0]
244 * x[indices[j] * bs1 + k1];
245 }
246 }
247 else
248 {
249 for (int k1 = 0; k1 < BS1; ++k1)
250 {
251 vi += values[j * BS1 * bs0 + k1 * bs0 + k0]
252 * x[indices[j] * BS1 + k1];
253 }
254 }
255 }
256
257 y[i * bs0 + k0] += vi;
258 }
259 }
260}
261
262} // namespace impl
263} // namespace dolfinx::la
Linear algebra interface.
Definition sparsitybuild.h:15