12 #include <dolfinx/common/Timer.h>
14 #include <type_traits>
16 #include <xtensor/xtensor.hpp>
17 #include <xtensor/xview.hpp>
18 #include <xtl/xspan.hpp>
28 template <
typename T,
int BITS = 8>
29 void radix_sort(
const xtl::span<T>& array)
31 static_assert(std::is_integral<T>(),
"This function only sorts integers.");
33 if (array.size() <= 1)
36 T max_value = *std::max_element(array.begin(), array.end());
39 constexpr
int bucket_size = 1 << BITS;
40 T mask = (T(1) << BITS) - 1;
52 std::array<std::int32_t, bucket_size> counter;
53 std::array<std::int32_t, bucket_size + 1> offset;
55 std::int32_t mask_offset = 0;
56 std::vector<T> buffer(array.size());
57 xtl::span<T> current_perm = array;
58 xtl::span<T> next_perm = buffer;
59 for (
int i = 0; i < its; i++)
62 std::fill(counter.begin(), counter.end(), 0);
65 for (T c : current_perm)
66 counter[(c & mask) >> mask_offset]++;
70 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
71 for (T c : current_perm)
73 std::int32_t bucket = (c & mask) >> mask_offset;
74 std::int32_t new_pos = offset[bucket + 1] - counter[bucket];
75 next_perm[new_pos] = c;
82 std::swap(current_perm, next_perm);
87 std::copy(buffer.begin(), buffer.end(), array.begin());
98 template <
typename T,
int BITS = 16>
99 void argsort_radix(
const xtl::span<const T>& array,
100 xtl::span<std::int32_t> perm)
102 static_assert(std::is_integral<T>::value,
"Integral required.");
104 if (array.size() <= 1)
107 const auto [min, max] = std::minmax_element(array.begin(), array.end());
108 T range = *max - *min + 1;
111 constexpr
int bucket_size = 1 << BITS;
112 T mask = (T(1) << BITS) - 1;
113 std::int32_t mask_offset = 0;
125 std::array<std::int32_t, bucket_size> counter;
126 std::array<std::int32_t, bucket_size + 1> offset;
128 std::vector<std::int32_t> perm2(perm.size());
129 xtl::span<std::int32_t> current_perm = perm;
130 xtl::span<std::int32_t> next_perm = perm2;
131 for (
int i = 0; i < its; i++)
134 std::fill(counter.begin(), counter.end(), 0);
137 for (
auto cp : current_perm)
139 T value = array[cp] - *min;
140 std::int32_t bucket = (value & mask) >> mask_offset;
146 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
149 for (
auto cp : current_perm)
151 T value = array[cp] - *min;
152 std::int32_t bucket = (value & mask) >> mask_offset;
153 std::int32_t pos = offset[bucket + 1] - counter[bucket];
158 std::swap(current_perm, next_perm);
165 std::copy(perm2.begin(), perm2.end(), perm.begin());
177 template <
typename T,
int BITS = 16>
178 std::vector<std::int32_t> sort_by_perm(
const xtl::span<const T>& x,
181 static_assert(std::is_integral<T>::value,
"Integral required.");
183 assert(x.size() % shape1 == 0);
184 const std::size_t shape0 = x.size() / shape1;
185 std::vector<std::int32_t> perm(shape0);
186 std::iota(perm.begin(), perm.end(), 0);
190 std::vector<T> column(shape0);
191 for (std::size_t i = 0; i < shape1; ++i)
193 int col = shape1 - 1 - i;
194 for (std::size_t j = 0; j < shape0; ++j)
195 column[j] = x[j * shape1 + col];
196 argsort_radix<T, BITS>(column, perm);