DOLFINx 0.10.0.0
DOLFINx C++ interface
Loading...
Searching...
No Matches
sort.h
1// Copyright (C) 2021-2025 Igor Baratta and Paul T. Kühner
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
9#include <algorithm>
10#include <bit>
11#include <cassert>
12#include <concepts>
13#include <cstdint>
14#include <functional>
15#include <iterator>
16#include <limits>
17#include <numeric>
18#include <span>
19#include <type_traits>
20#include <utility>
21#include <vector>
22
23namespace dolfinx
24{
25
26struct __unsigned_projection
27{
28 // Transforms the projected value to an unsigned int (if signed), while
29 // maintaining relative order by
30 // x ↦ x + |std::numeric_limits<I>::min()|
31 template <std::signed_integral T>
32 constexpr std::make_unsigned_t<T> operator()(T e) const noexcept
33 {
34 using uT = std::make_unsigned_t<T>;
35
36 // Assert binary structure for bit shift
37 static_assert(static_cast<uT>(std::numeric_limits<T>::min())
38 + static_cast<uT>(std::numeric_limits<T>::max())
39 == static_cast<uT>(T(-1)));
40 static_assert(std::numeric_limits<uT>::digits
41 == std::numeric_limits<T>::digits + 1);
42 static_assert(std::bit_cast<uT>(std::numeric_limits<T>::min())
43 == (uT(1) << (sizeof(T) * 8 - 1)));
44
45 return std::bit_cast<uT>(std::forward<T>(e))
46 ^ (uT(1) << (sizeof(T) * 8 - 1));
47 }
48};
49
51inline constexpr __unsigned_projection unsigned_projection{};
52
53struct __radix_sort
54{
79 template <std::ranges::random_access_range R, typename P = std::identity,
80 std::make_unsigned_t<std::remove_cvref_t<
81 std::invoke_result_t<P, std::iter_value_t<R>>>>
82 BITS
83 = 8>
84 requires std::integral<decltype(BITS)>
85 constexpr void operator()(R&& range, P proj = {}) const
86 {
87 // value type
88 using T = std::iter_value_t<R>;
89
90 // index type (if no projection is provided it holds I == T)
91 using I = std::remove_cvref_t<std::invoke_result_t<P, T>>;
92 using uI = std::make_unsigned_t<I>;
93
94 if constexpr (!std::is_same_v<uI, I>)
95 {
96 __radix_sort()(std::forward<R>(range), [&](const T& e) -> uI
97 { return unsigned_projection(proj(e)); });
98 return;
99 }
100
101 if (range.size() <= 1)
102 return;
103
104 uI max_value = proj(*std::ranges::max_element(range, std::less{}, proj));
105
106 // Sort N bits at a time
107 constexpr uI bucket_size = 1 << BITS;
108 uI mask = (uI(1) << BITS) - 1;
109
110 // Compute number of iterations, most significant digit (N bits) of
111 // maxvalue
112 I its = 0;
113
114 // optimize for case where all first bits are set - then order will not
115 // depend on it
116 bool all_first_bit = std::ranges::all_of(
117 range, [&](const auto& e)
118 { return proj(e) & (uI(1) << (sizeof(uI) * 8 - 1)); });
119 if (all_first_bit)
120 max_value = max_value & ~(uI(1) << (sizeof(uI) * 8 - 1));
121
122 while (max_value)
123 {
124 max_value >>= BITS;
125 its++;
126 }
127
128 // Adjacency list arrays for computing insertion position
129 std::array<I, bucket_size> counter;
130 std::array<I, bucket_size + 1> offset;
131
132 uI mask_offset = 0;
133 std::vector<T> buffer(range.size());
134 std::span<T> current_perm = range;
135 std::span<T> next_perm = buffer;
136 for (I i = 0; i < its; i++)
137 {
138 // Zero counter array
139 std::ranges::fill(counter, 0);
140
141 // Count number of elements per bucket
142 for (const auto& c : current_perm)
143 counter[(proj(c) & mask) >> mask_offset]++;
144
145 // Prefix sum to get the inserting position
146 offset[0] = 0;
147 std::partial_sum(counter.begin(), counter.end(),
148 std::next(offset.begin()));
149 for (const auto& c : current_perm)
150 {
151 uI bucket = (proj(c) & mask) >> mask_offset;
152 uI new_pos = offset[bucket + 1] - counter[bucket];
153 next_perm[new_pos] = c;
154 counter[bucket]--;
155 }
156
157 mask = mask << BITS;
158 mask_offset += BITS;
159
160 std::swap(current_perm, next_perm);
161 }
162
163 // Copy data back to array
164 if (its % 2 != 0)
165 std::ranges::copy(buffer, range.begin());
166 }
167};
168
170inline constexpr __radix_sort radix_sort{};
171
182template <typename T, int BITS = 16>
183std::vector<std::int32_t> sort_by_perm(std::span<const T> x, std::size_t shape1)
184{
185 static_assert(std::is_integral_v<T>, "Integral required.");
186
187 if (x.empty())
188 return std::vector<std::int32_t>{};
189
190 assert(shape1 > 0);
191 assert(x.size() % shape1 == 0);
192 const std::size_t shape0 = x.size() / shape1;
193 std::vector<std::int32_t> perm(shape0);
194 std::iota(perm.begin(), perm.end(), 0);
195
196 // Sort by each column, right to left. Col 0 has the most significant
197 // "digit".
198 std::vector<T> column(shape0);
199 for (std::size_t i = 0; i < shape1; ++i)
200 {
201 std::size_t col = shape1 - 1 - i;
202 for (std::size_t j = 0; j < shape0; ++j)
203 column[j] = x[j * shape1 + col];
204
205 radix_sort(perm, [&column](auto index) { return column[index]; });
206 }
207
208 return perm;
209}
210
211} // namespace dolfinx
Top-level namespace.
Definition defines.h:12
constexpr __unsigned_projection unsigned_projection
Projection from signed to unsigned int.
Definition sort.h:51
std::vector< std::int32_t > sort_by_perm(std::span< const T > x, std::size_t shape1)
Compute the permutation array that sorts a 2D array by row.
Definition sort.h:183
constexpr __radix_sort radix_sort
Radix sort.
Definition sort.h:170