52 std::ranges::random_access_range R,
typename P = std::identity,
53 std::remove_cvref_t<std::invoke_result_t<P, std::iter_value_t<R>>> BITS
55 requires std::integral<
decltype(BITS)>
56 constexpr void operator()(R&& range, P proj = {})
const
59 using T = std::iter_value_t<R>;
62 using I = std::remove_cvref_t<std::invoke_result_t<P, T>>;
64 if (range.size() <= 1)
67 T max_value = proj(*std::ranges::max_element(range, std::less{}, proj));
70 constexpr I bucket_size = 1 << BITS;
71 T mask = (T(1) << BITS) - 1;
83 std::array<I, bucket_size> counter;
84 std::array<I, bucket_size + 1> offset;
87 std::vector<T> buffer(range.size());
88 std::span<T> current_perm = range;
89 std::span<T> next_perm = buffer;
90 for (I i = 0; i < its; i++)
93 std::ranges::fill(counter, 0);
96 for (
const auto& c : current_perm)
97 counter[(proj(c) & mask) >> mask_offset]++;
101 std::partial_sum(counter.begin(), counter.end(),
102 std::next(offset.begin()));
103 for (
const auto& c : current_perm)
105 I bucket = (proj(c) & mask) >> mask_offset;
106 I new_pos = offset[bucket + 1] - counter[bucket];
107 next_perm[new_pos] = c;
114 std::swap(current_perm, next_perm);
119 std::ranges::copy(buffer, range.begin());
135template <
typename T,
int BITS = 16>
136std::vector<std::int32_t>
sort_by_perm(std::span<const T> x, std::size_t shape1)
138 static_assert(std::is_integral_v<T>,
"Integral required.");
140 assert(x.size() % shape1 == 0);
141 const std::size_t shape0 = x.size() / shape1;
142 std::vector<std::int32_t> perm(shape0);
143 std::iota(perm.begin(), perm.end(), 0);
147 std::vector<T> column(shape0);
148 for (std::size_t i = 0; i < shape1; ++i)
150 int col = shape1 - 1 - i;
151 for (std::size_t j = 0; j < shape0; ++j)
152 column[j] = x[j * shape1 + col];
154 radix_sort(perm, [&column](
auto index) {
return column[index]; });
Top-level namespace.
Definition defines.h:12
std::vector< std::int32_t > sort_by_perm(std::span< const T > x, std::size_t shape1)
Compute the permutation array that sorts a 2D array by row.
Definition sort.h:136
constexpr __radix_sort radix_sort
Radix sort.
Definition sort.h:124