DOLFINx 0.9.0
DOLFINx C++ interface
Loading...
Searching...
No Matches
matrix_csr_impl.h
1// Copyright (C) 2021-2023 Garth N. Wells and Chris N. Richardson
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <numeric>
#include <span>
#include <stdexcept>
#include <utility>
#include <vector>
14
namespace dolfinx::la
{
namespace impl
{
/// @brief Apply `op` to blocks of a CSR matrix whose stored block size
/// equals the block size of the input (BS0 x BS1).
///
/// For every pair (`xrows[r]`, `xcols[c]`) the column `xcols[c]` is
/// located by binary search among the column indices of row `xrows[r]`,
/// and `op(data_entry, x_entry)` is called for each of the BS0*BS1
/// values of that block. Within a stored block the values are addressed
/// row-major (`i * BS1 + j`).
///
/// @tparam BS0 Number of rows in a block.
/// @tparam BS1 Number of columns in a block.
/// @tparam OP Callable invoked as `op(data[k], x[l])`.
/// @param[in,out] data Matrix values, BS0*BS1 consecutive values per
/// entry of `cols`.
/// @param[in] cols Column index of each stored block. The indices of
/// each row must be sorted, since they are searched with
/// `std::lower_bound`.
/// @param[in] row_ptr Offsets into `cols`: row `i` owns the range
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Input values, read as a row-major dense array with
/// `xrows.size() * BS0` rows and `xcols.size() * BS1` columns.
/// @param[in] xrows Block row index for each block row of `x`.
/// @param[in] xcols Block column index for each block column of `x`.
/// @param[in] op Operation applied to (matrix value, input value).
/// @param[in] num_rows Exclusive upper bound for the entries of `xrows`.
/// Only checked when `NDEBUG` is not defined.
/// @throws std::runtime_error If a requested column is not present in
/// the row ("Entry not in sparsity") or, when `NDEBUG` is not defined,
/// if a row index is not less than `num_rows`.
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                const Y& xrows, const Y& xcols, OP op,
                typename Y::value_type num_rows);

/// @brief Apply `op` to entries of a CSR matrix with block size 1 using
/// input that is indexed by blocks of size BS0 x BS1.
///
/// Block row `xrows[r]` is expanded to the scalar rows
/// `xrows[r] * BS0 + i` for `i` in `[0, BS0)`. In each such row the
/// scalar column `xcols[c] * BS1` is located by binary search and the
/// BS1 values stored from that position onwards are updated. The code
/// does not verify that those BS1 positions hold consecutive columns.
///
/// @tparam BS0 Number of rows in an input block.
/// @tparam BS1 Number of columns in an input block.
/// @tparam OP Callable invoked as `op(data[k], x[l])`.
/// @param[in,out] data Matrix values, one value per entry of `cols`.
/// @param[in] cols Scalar column index of each stored value. The
/// indices of each row must be sorted, since they are searched with
/// `std::lower_bound`.
/// @param[in] row_ptr Offsets into `cols`: row `i` owns the range
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Input values, read as a row-major dense array with
/// `xrows.size() * BS0` rows and `xcols.size() * BS1` columns.
/// @param[in] xrows Block row index for each block row of `x`.
/// @param[in] xcols Block column index for each block column of `x`.
/// @param[in] op Operation applied to (matrix value, input value).
/// @param[in] num_rows Exclusive upper bound for `xrows[r] * BS0`. Only
/// checked when `NDEBUG` is not defined.
/// @throws std::runtime_error If a requested column is not present in
/// the row ("Entry not in sparsity") or, when `NDEBUG` is not defined,
/// if `xrows[r] * BS0` is not less than `num_rows`.
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void insert_blocked_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                        const Y& xrows, const Y& xcols, OP op,
                        typename Y::value_type num_rows);

/// @brief Apply `op` to individual scalar entries of a CSR matrix that
/// is stored with block size bs0 x bs1.
///
/// Each scalar index is split into a block index (quotient) and a
/// position inside the block (remainder) by dividing by `bs0` (rows) or
/// `bs1` (columns). The block column is located by binary search in the
/// block row, and the value at row-major position
/// `row_rem * bs1 + col_rem` of that block is updated.
///
/// @tparam OP Callable invoked as `op(data[k], x[l])`.
/// @param[in,out] data Matrix values, bs0*bs1 consecutive values per
/// entry of `cols`.
/// @param[in] cols Block column index of each stored block. The indices
/// of each row must be sorted, since they are searched with
/// `std::lower_bound`.
/// @param[in] row_ptr Offsets into `cols`: block row `i` owns the range
/// `[row_ptr[i], row_ptr[i + 1])`.
/// @param[in] x Input values, read as a row-major dense array with
/// `xrows.size()` rows and `xcols.size()` columns.
/// @param[in] xrows Scalar row index for each row of `x`.
/// @param[in] xcols Scalar column index for each column of `x`.
/// @param[in] op Operation applied to (matrix value, input value).
/// @param[in] num_rows Exclusive upper bound for the block row index
/// `xrows[r] / bs0`. Only checked when `NDEBUG` is not defined.
/// @param[in] bs0 Number of rows in a stored block.
/// @param[in] bs1 Number of columns in a stored block.
/// @throws std::runtime_error If a requested block column is not
/// present in the block row ("Entry not in sparsity") or, when `NDEBUG`
/// is not defined, if a block row index is not less than `num_rows`.
template <typename OP, typename U, typename V, typename W, typename X,
          typename Y>
void insert_nonblocked_csr(U&& data, const V& cols, const W& row_ptr,
                           const X& x, const Y& xrows, const Y& xcols, OP op,
                           typename Y::value_type num_rows, int bs0, int bs1);

} // namespace impl

//-----------------------------------------------------------------------------
// Insert blocks into a CSR matrix with matching block size
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void impl::insert_csr(U&& data, const V& cols, const W& row_ptr, const X& x,
                      const Y& xrows, const Y& xcols, OP op,
                      [[maybe_unused]] typename Y::value_type num_rows)
{
  using T = typename X::value_type;
  const std::size_t nc = xcols.size();
  assert(x.size() == xrows.size() * xcols.size() * BS0 * BS1);
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // Row index and start of the BS0 input rows belonging to it
    auto row = xrows[r];
    const T* xr = x.data() + r * nc * BS0 * BS1;

#ifndef NDEBUG
    if (row >= num_rows)
      throw std::runtime_error("Local row out of range");
#endif
    // Column indices for row
    auto cit0 = std::next(cols.begin(), row_ptr[row]);
    auto cit1 = std::next(cols.begin(), row_ptr[row + 1]);
    for (std::size_t c = 0; c < nc; ++c)
    {
      // Find position of column index
      auto it = std::lower_bound(cit0, cit1, xcols[c]);

      if (it == cit1 or *it != xcols[c])
        throw std::runtime_error("Entry not in sparsity");

      // Offsets are kept as std::size_t: the data array may hold more
      // values than an int can index
      const std::size_t d = std::distance(cols.begin(), it);
      std::size_t di = d * BS0 * BS1;
      std::size_t xi = c * BS1;
      assert(di + BS0 * BS1 <= data.size());
      for (int i = 0; i < BS0; ++i)
      {
        for (int j = 0; j < BS1; ++j)
          op(data[di + j], xr[xi + j]);
        // Next row of the stored block / next dense row of the input
        di += BS1;
        xi += nc * BS1;
      }
    }
  }
}
//-----------------------------------------------------------------------------
// Insert with block insertion into a regular CSR (block size 1)
template <int BS0, int BS1, typename OP, typename U, typename V, typename W,
          typename X, typename Y>
void impl::insert_blocked_csr(U&& data, const V& cols, const W& row_ptr,
                              const X& x, const Y& xrows, const Y& xcols, OP op,
                              [[maybe_unused]] typename Y::value_type num_rows)
{
  using T = typename X::value_type;
  const std::size_t nc = xcols.size();
  assert(x.size() == xrows.size() * xcols.size() * BS0 * BS1);
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // First scalar row covered by this block row
    auto row = xrows[r] * BS0;

#ifndef NDEBUG
    if (row >= num_rows)
      throw std::runtime_error("Local row out of range");
#endif

    for (int i = 0; i < BS0; ++i)
    {
      // Dense input row matching scalar row (row + i)
      const T* xr = x.data() + (r * BS0 + i) * nc * BS1;
      // Column indices for row
      auto cit0 = std::next(cols.begin(), row_ptr[row + i]);
      auto cit1 = std::next(cols.begin(), row_ptr[row + i + 1]);
      for (std::size_t c = 0; c < nc; ++c)
      {
        // Find position of the first scalar column of the block
        auto it = std::lower_bound(cit0, cit1, xcols[c] * BS1);

        if (it == cit1 or *it != xcols[c] * BS1)
          throw std::runtime_error("Entry not in sparsity");

        const std::size_t d = std::distance(cols.begin(), it);
        // BS1 consecutive values starting at d are updated
        assert(d + BS1 <= data.size());
        const std::size_t xi = c * BS1;
        for (int j = 0; j < BS1; ++j)
          op(data[d + j], xr[xi + j]);
      }
    }
  }
}
//-----------------------------------------------------------------------------
// Add individual entries in block-CSR storage
template <typename OP, typename U, typename V, typename W, typename X,
          typename Y>
void impl::insert_nonblocked_csr(U&& data, const V& cols, const W& row_ptr,
                                 const X& x, const Y& xrows, const Y& xcols,
                                 OP op,
                                 [[maybe_unused]]
                                 typename Y::value_type num_rows,
                                 int bs0, int bs1)
{
  using T = typename X::value_type;
  const std::size_t nc = xcols.size();
  const int nbs = bs0 * bs1;

  assert(x.size() == xrows.size() * xcols.size());
  for (std::size_t r = 0; r < xrows.size(); ++r)
  {
    // Split scalar row into block row (quot) and row within block (rem)
    auto rdiv = std::div(xrows[r], bs0);
    const T* xr = x.data() + r * nc;

#ifndef NDEBUG
    if (rdiv.quot >= num_rows)
      throw std::runtime_error("Local row out of range");
#endif
    // Column indices for block row
    auto cit0 = std::next(cols.begin(), row_ptr[rdiv.quot]);
    auto cit1 = std::next(cols.begin(), row_ptr[rdiv.quot + 1]);
    for (std::size_t c = 0; c < nc; ++c)
    {
      // Split scalar column into block column (quot) and column within
      // block (rem), then find position of the block column
      auto cdiv = std::div(xcols[c], bs1);
      auto it = std::lower_bound(cit0, cit1, cdiv.quot);

      if (it == cit1 or *it != cdiv.quot)
        throw std::runtime_error("Entry not in sparsity");

      // Offset is kept as std::size_t: the data array may hold more
      // values than an int can index
      const std::size_t d = std::distance(cols.begin(), it);
      const std::size_t di
          = d * static_cast<std::size_t>(nbs)
            + static_cast<std::size_t>(rdiv.rem * bs1 + cdiv.rem);
      assert(di < data.size());
      op(data[di], xr[c]);
    }
  }
}
//-----------------------------------------------------------------------------
} // namespace dolfinx::la
Linear algebra interface.
Definition sparsitybuild.h:15