12 #include <dolfinx/common/Timer.h>
15 #include <xtensor/xtensor.hpp>
16 #include <xtensor/xview.hpp>
17 #include <xtl/xspan.hpp>
27 template <
typename T,
int BITS = 8>
28 void radix_sort(xtl::span<T> array)
30 static_assert(std::is_integral<T>(),
"This function only sorts integers.");
32 if (array.size() <= 1)
35 T max_value = *std::max_element(array.begin(), array.end());
38 constexpr
int bucket_size = 1 << BITS;
39 T mask = (T(1) << BITS) - 1;
51 std::array<std::int32_t, bucket_size> counter;
52 std::array<std::int32_t, bucket_size + 1> offset;
54 std::int32_t mask_offset = 0;
55 std::vector<T> buffer(array.size());
56 xtl::span<T> current_perm = array;
57 xtl::span<T> next_perm = buffer;
58 for (
int i = 0; i < its; i++)
61 std::fill(counter.begin(), counter.end(), 0);
64 for (T c : current_perm)
65 counter[(c & mask) >> mask_offset]++;
69 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
70 for (T c : current_perm)
72 std::int32_t bucket = (c & mask) >> mask_offset;
73 std::int32_t new_pos = offset[bucket + 1] - counter[bucket];
74 next_perm[new_pos] = c;
81 std::swap(current_perm, next_perm);
86 std::copy(buffer.begin(), buffer.end(), array.begin());
98 template <
typename T,
int BITS = 16>
99 void argsort_radix(
const xtl::span<const T>& array,
100 xtl::span<std::int32_t> perm)
102 if (array.size() <= 1)
105 const auto [min, max] = std::minmax_element(array.begin(), array.end());
106 T range = *max - *min + 1;
109 constexpr
int bucket_size = 1 << BITS;
110 T mask = (T(1) << BITS) - 1;
111 std::int32_t mask_offset = 0;
123 std::array<std::int32_t, bucket_size> counter;
124 std::array<std::int32_t, bucket_size + 1> offset;
126 std::vector<std::int32_t> perm2(perm.size());
127 xtl::span<std::int32_t> current_perm = perm;
128 xtl::span<std::int32_t> next_perm = perm2;
129 for (
int i = 0; i < its; i++)
132 std::fill(counter.begin(), counter.end(), 0);
135 for (
auto cp : current_perm)
137 T value = array[cp] - *min;
138 std::int32_t bucket = (value & mask) >> mask_offset;
144 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
147 for (
auto cp : current_perm)
149 T value = array[cp] - *min;
150 std::int32_t bucket = (value & mask) >> mask_offset;
151 std::int32_t pos = offset[bucket + 1] - counter[bucket];
156 std::swap(current_perm, next_perm);
163 std::copy(perm2.begin(), perm2.end(), perm.begin());
166 template <
typename T,
int BITS = 16>
167 std::vector<std::int32_t> sort_by_perm(
const xt::xtensor<T, 2>& array)
170 const int cols = array.shape(1);
171 const int size = array.shape(0);
172 std::vector<std::int32_t> perm(size);
173 std::iota(perm.begin(), perm.end(), 0);
177 for (
int i = 0; i < cols; i++)
179 int col = cols - 1 - i;
180 xt::xtensor<std::int32_t, 1> column = xt::view(array, xt::all(), col);
181 argsort_radix<std::int32_t, BITS>(xtl::span<const std::int32_t>(column),
int size(MPI_Comm comm)
Return size of the group (number of processes) associated with the communicator.
Definition: MPI.cpp:82