DOLFINx 0.9.0
DOLFINx C++ interface
Loading...
Searching...
No Matches
sort.h
1// Copyright (C) 2021 Igor Baratta
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include <algorithm>
#include <array>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <ranges>
#include <span>
#include <type_traits>
#include <utility>
#include <vector>
20
21namespace dolfinx
22{
23
24struct __radix_sort
25{
26
51 template <
52 std::ranges::random_access_range R, typename P = std::identity,
53 std::remove_cvref_t<std::invoke_result_t<P, std::iter_value_t<R>>> BITS
54 = 8>
55 requires std::integral<decltype(BITS)>
56 constexpr void operator()(R&& range, P proj = {}) const
57 {
58 // value type
59 using T = std::iter_value_t<R>;
60
61 // index type (if no projection is provided it holds I == T)
62 using I = std::remove_cvref_t<std::invoke_result_t<P, T>>;
63
64 if (range.size() <= 1)
65 return;
66
67 T max_value = proj(*std::ranges::max_element(range, std::less{}, proj));
68
69 // Sort N bits at a time
70 constexpr I bucket_size = 1 << BITS;
71 T mask = (T(1) << BITS) - 1;
72
73 // Compute number of iterations, most significant digit (N bits) of
74 // maxvalue
75 I its = 0;
76 while (max_value)
77 {
78 max_value >>= BITS;
79 its++;
80 }
81
82 // Adjacency list arrays for computing insertion position
83 std::array<I, bucket_size> counter;
84 std::array<I, bucket_size + 1> offset;
85
86 I mask_offset = 0;
87 std::vector<T> buffer(range.size());
88 std::span<T> current_perm = range;
89 std::span<T> next_perm = buffer;
90 for (I i = 0; i < its; i++)
91 {
92 // Zero counter array
93 std::ranges::fill(counter, 0);
94
95 // Count number of elements per bucket
96 for (const auto& c : current_perm)
97 counter[(proj(c) & mask) >> mask_offset]++;
98
99 // Prefix sum to get the inserting position
100 offset[0] = 0;
101 std::partial_sum(counter.begin(), counter.end(),
102 std::next(offset.begin()));
103 for (const auto& c : current_perm)
104 {
105 I bucket = (proj(c) & mask) >> mask_offset;
106 I new_pos = offset[bucket + 1] - counter[bucket];
107 next_perm[new_pos] = c;
108 counter[bucket]--;
109 }
110
111 mask = mask << BITS;
112 mask_offset += BITS;
113
114 std::swap(current_perm, next_perm);
115 }
116
117 // Copy data back to array
118 if (its % 2 != 0)
119 std::ranges::copy(buffer, range.begin());
120 }
121};
122
124inline constexpr __radix_sort radix_sort{};
125
135template <typename T, int BITS = 16>
136std::vector<std::int32_t> sort_by_perm(std::span<const T> x, std::size_t shape1)
137{
138 static_assert(std::is_integral_v<T>, "Integral required.");
139 assert(shape1 > 0);
140 assert(x.size() % shape1 == 0);
141 const std::size_t shape0 = x.size() / shape1;
142 std::vector<std::int32_t> perm(shape0);
143 std::iota(perm.begin(), perm.end(), 0);
144
145 // Sort by each column, right to left. Col 0 has the most significant
146 // "digit".
147 std::vector<T> column(shape0);
148 for (std::size_t i = 0; i < shape1; ++i)
149 {
150 int col = shape1 - 1 - i;
151 for (std::size_t j = 0; j < shape0; ++j)
152 column[j] = x[j * shape1 + col];
153
154 radix_sort(perm, [&column](auto index) { return column[index]; });
155 }
156
157 return perm;
158}
159
160} // namespace dolfinx
Top-level namespace.
Definition defines.h:12
std::vector< std::int32_t > sort_by_perm(std::span< const T > x, std::size_t shape1)
Compute the permutation array that sorts a 2D array by row.
Definition sort.h:136
constexpr __radix_sort radix_sort
Radix sort.
Definition sort.h:124