DOLFINx 0.8.0
DOLFINx C++ interface
Loading...
Searching...
No Matches
sort.h
1// Copyright (C) 2021 Igor Baratta
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <dolfinx/common/Timer.h>
#include <numeric>
#include <span>
#include <type_traits>
#include <vector>
17
18namespace dolfinx
19{
20
26template <typename T, int BITS = 8>
27void radix_sort(std::span<T> array)
28{
29 static_assert(std::is_integral<T>(), "This function only sorts integers.");
30
31 if (array.size() <= 1)
32 return;
33
34 T max_value = *std::max_element(array.begin(), array.end());
35
36 // Sort N bits at a time
37 constexpr int bucket_size = 1 << BITS;
38 T mask = (T(1) << BITS) - 1;
39
40 // Compute number of iterations, most significant digit (N bits) of
41 // maxvalue
42 int its = 0;
43 while (max_value)
44 {
45 max_value >>= BITS;
46 its++;
47 }
48
49 // Adjacency list arrays for computing insertion position
50 std::array<std::int32_t, bucket_size> counter;
51 std::array<std::int32_t, bucket_size + 1> offset;
52
53 std::int32_t mask_offset = 0;
54 std::vector<T> buffer(array.size());
55 std::span<T> current_perm = array;
56 std::span<T> next_perm = buffer;
57 for (int i = 0; i < its; i++)
58 {
59 // Zero counter array
60 std::fill(counter.begin(), counter.end(), 0);
61
62 // Count number of elements per bucket
63 for (T c : current_perm)
64 counter[(c & mask) >> mask_offset]++;
65
66 // Prefix sum to get the inserting position
67 offset[0] = 0;
68 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
69 for (T c : current_perm)
70 {
71 std::int32_t bucket = (c & mask) >> mask_offset;
72 std::int32_t new_pos = offset[bucket + 1] - counter[bucket];
73 next_perm[new_pos] = c;
74 counter[bucket]--;
75 }
76
77 mask = mask << BITS;
78 mask_offset += BITS;
79
80 std::swap(current_perm, next_perm);
81 }
82
83 // Copy data back to array
84 if (its % 2 != 0)
85 std::copy(buffer.begin(), buffer.end(), array.begin());
86}
87
96template <typename T, int BITS = 16>
97void argsort_radix(std::span<const T> array, std::span<std::int32_t> perm)
98{
99 static_assert(std::is_integral_v<T>, "Integral required.");
100
101 if (array.size() <= 1)
102 return;
103
104 const auto [min, max] = std::minmax_element(array.begin(), array.end());
105 T range = *max - *min + 1;
106
107 // Sort N bits at a time
108 constexpr int bucket_size = 1 << BITS;
109 T mask = (T(1) << BITS) - 1;
110 std::int32_t mask_offset = 0;
111
112 // Compute number of iterations, most significant digit (N bits) of
113 // maxvalue
114 int its = 0;
115 while (range)
116 {
117 range >>= BITS;
118 its++;
119 }
120
121 // Adjacency list arrays for computing insertion position
122 std::array<std::int32_t, bucket_size> counter;
123 std::array<std::int32_t, bucket_size + 1> offset;
124
125 std::vector<std::int32_t> perm2(perm.size());
126 std::span<std::int32_t> current_perm = perm;
127 std::span<std::int32_t> next_perm = perm2;
128 for (int i = 0; i < its; i++)
129 {
130 // Zero counter
131 std::fill(counter.begin(), counter.end(), 0);
132
133 // Count number of elements per bucket
134 for (auto cp : current_perm)
135 {
136 T value = array[cp] - *min;
137 std::int32_t bucket = (value & mask) >> mask_offset;
138 counter[bucket]++;
139 }
140
141 // Prefix sum to get the inserting position
142 offset[0] = 0;
143 std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
144
145 // Sort py permutation
146 for (auto cp : current_perm)
147 {
148 T value = array[cp] - *min;
149 std::int32_t bucket = (value & mask) >> mask_offset;
150 std::int32_t pos = offset[bucket + 1] - counter[bucket];
151 next_perm[pos] = cp;
152 counter[bucket]--;
153 }
154
155 std::swap(current_perm, next_perm);
156
157 mask = mask << BITS;
158 mask_offset += BITS;
159 }
160
161 if (its % 2 == 1)
162 std::copy(perm2.begin(), perm2.end(), perm.begin());
163}
164
174template <typename T, int BITS = 16>
175std::vector<std::int32_t> sort_by_perm(std::span<const T> x, std::size_t shape1)
176{
177 static_assert(std::is_integral_v<T>, "Integral required.");
178 assert(shape1 > 0);
179 assert(x.size() % shape1 == 0);
180 const std::size_t shape0 = x.size() / shape1;
181 std::vector<std::int32_t> perm(shape0);
182 std::iota(perm.begin(), perm.end(), 0);
183
184 // Sort by each column, right to left. Col 0 has the most significant
185 // "digit".
186 std::vector<T> column(shape0);
187 for (std::size_t i = 0; i < shape1; ++i)
188 {
189 int col = shape1 - 1 - i;
190 for (std::size_t j = 0; j < shape0; ++j)
191 column[j] = x[j * shape1 + col];
192 argsort_radix<T, BITS>(column, perm);
193 }
194
195 return perm;
196}
197
198} // namespace dolfinx
Top-level namespace.
Definition defines.h:12
std::vector< std::int32_t > sort_by_perm(std::span< const T > x, std::size_t shape1)
Compute the permutation array that sorts a 2D array by row.
Definition sort.h:175
void argsort_radix(std::span< const T > array, std::span< std::int32_t > perm)
Definition sort.h:97
void radix_sort(std::span< T > array)
Definition sort.h:27