DOLFINx 0.11.0.0
DOLFINx C++ interface
Loading...
Searching...
No Matches
sort.h
1// Copyright (C) 2021-2025 Igor Baratta and Paul T. Kühner
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include <algorithm>
#include <bit>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <numeric>
#include <ranges>
#include <span>
#include <type_traits>
#include <utility>
#include <vector>
22
23namespace dolfinx
24{
25struct __unsigned_projection
26{
27 // Transforms the projected value to an unsigned int (if signed),
28 // while maintaining relative order by
29 // x ↦ x + |std::numeric_limits<I>::min()|
30 template <std::signed_integral T>
31 constexpr std::make_unsigned_t<T> operator()(T e) const noexcept
32 {
33 using uT = std::make_unsigned_t<T>;
34
35 // Assert binary structure for bit shift
36 static_assert(static_cast<uT>(std::numeric_limits<T>::min())
37 + static_cast<uT>(std::numeric_limits<T>::max())
38 == static_cast<uT>(T(-1)));
39 static_assert(std::numeric_limits<uT>::digits
40 == std::numeric_limits<T>::digits + 1);
41 static_assert(std::bit_cast<uT>(std::numeric_limits<T>::min())
42 == (uT(1) << (sizeof(T) * 8 - 1)));
43
44 return std::bit_cast<uT>(std::forward<T>(e))
45 ^ (uT(1) << (sizeof(T) * 8 - 1));
46 }
47};
48
50inline constexpr __unsigned_projection unsigned_projection{};
51
76template <int BITS = 8, typename P = std::identity,
77 std::ranges::random_access_range R>
78constexpr void radix_sort(R&& range, P proj = {})
79{
80 using bits_t = std::make_unsigned_t<
81 std::remove_cvref_t<std::invoke_result_t<P, std::iter_value_t<R>>>>;
82 constexpr bits_t _BITS = BITS;
83
84 // Value type
85 using T = std::iter_value_t<R>;
86
87 // Index type (if no projection is provided it holds I == T)
88 using I = std::remove_cvref_t<std::invoke_result_t<P, T>>;
89 using uI = std::make_unsigned_t<I>;
90
91 if constexpr (!std::is_same_v<uI, I>)
92 {
93 radix_sort<_BITS>(std::forward<R>(range), [&](const T& e) -> uI
94 { return unsigned_projection(proj(e)); });
95 return;
96 }
97
98 if (range.size() <= 1)
99 return;
100
101 uI max_value = proj(*std::ranges::max_element(range, std::less{}, proj));
102
103 // Sort N bits at a time
104 constexpr uI bucket_size = 1 << _BITS;
105 uI mask = (uI(1) << _BITS) - 1;
106
107 // Compute number of iterations, most significant digit (N bits) of
108 // maxvalue
109 I its = 0;
110
111 // Optimize for case where all first bits are set - then order will
112 // not depend on it
113 if (bool all_first_bit = std::ranges::all_of(
114 range, [&proj](const auto& e)
115 { return proj(e) & (uI(1) << (sizeof(uI) * 8 - 1)); });
116 all_first_bit)
117 {
118 max_value = max_value & ~(uI(1) << (sizeof(uI) * 8 - 1));
119 }
120
121 while (max_value)
122 {
123 max_value >>= _BITS;
124 its++;
125 }
126
127 // Adjacency list arrays for computing insertion position
128 std::array<I, bucket_size> counter;
129 std::array<I, bucket_size + 1> offset;
130
131 uI mask_offset = 0;
132 std::vector<T> buffer(range.size());
133 std::span<T> current_perm = range;
134 std::span<T> next_perm = buffer;
135 for (I i = 0; i < its; i++)
136 {
137 // Zero counter array
138 std::ranges::fill(counter, 0);
139
140 // Count number of elements per bucket
141 for (auto c : current_perm)
142 counter[(proj(c) & mask) >> mask_offset]++;
143
144 // Prefix sum to get the inserting position
145 offset[0] = 0;
146 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
147 for (auto c : current_perm)
148 {
149 uI bucket = (proj(c) & mask) >> mask_offset;
150 uI new_pos = offset[bucket + 1] - counter[bucket];
151 next_perm[new_pos] = c;
152 counter[bucket]--;
153 }
154
155 mask = mask << _BITS;
156 mask_offset += _BITS;
157
158 std::swap(current_perm, next_perm);
159 }
160
161 // Copy data back to array
162 if (its % 2 != 0)
163 std::ranges::copy(buffer, range.begin());
164}
165
176template <typename T, int BITS = 16>
177std::vector<std::int32_t> sort_by_perm(std::span<const T> x, std::size_t shape1)
178{
179 static_assert(std::is_integral_v<T>, "Integral required.");
180
181 if (x.empty())
182 return std::vector<std::int32_t>{};
183
184 assert(shape1 > 0);
185 assert(x.size() % shape1 == 0);
186 const std::size_t shape0 = x.size() / shape1;
187 std::vector<std::int32_t> perm(shape0);
188 std::iota(perm.begin(), perm.end(), 0);
189
190 // Sort by each column, right to left. Col 0 has the most significant
191 // "digit".
192 std::vector<T> column(shape0);
193 for (std::size_t i = 0; i < shape1; ++i)
194 {
195 std::size_t col = shape1 - 1 - i;
196 for (std::size_t j = 0; j < shape0; ++j)
197 column[j] = x[j * shape1 + col];
198 radix_sort<BITS>(perm, [column = std::cref(column)](auto index)
199 { return column.get()[index]; });
200 }
201
202 return perm;
203}
204
205} // namespace dolfinx
Top-level namespace.
Definition defines.h:12
constexpr __unsigned_projection unsigned_projection
Projection from signed to unsigned int.
Definition sort.h:50
constexpr void radix_sort(R &&range, P proj={})
Sort a range with radix sorting algorithm. The bucket size is determined by the number of bits to sort at a time.
Definition sort.h:78
std::vector< std::int32_t > sort_by_perm(std::span< const T > x, std::size_t shape1)
Compute the permutation array that sorts a 2D array by row.
Definition sort.h:177