Note: this is documentation for an old release. View the latest documentation at docs.fenicsproject.org/dolfinx/v0.9.0/cpp/doxygen/d7/d50/sort_8h_source.html
DOLFINx  0.5.1
DOLFINx C++ interface
sort.h
1 // Copyright (C) 2021 Igor Baratta
2 //
3 // This file is part of DOLFINx (https://www.fenicsproject.org)
4 //
5 // SPDX-License-Identifier: LGPL-3.0-or-later
6 
7 #pragma once
8 
#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <cstdint>
#include <dolfinx/common/Timer.h>
#include <numeric>
#include <span>
#include <type_traits>
#include <vector>
17 
18 namespace dolfinx
19 {
20 
26 template <typename T, int BITS = 8>
27 void radix_sort(const std::span<T>& array)
28 {
29  static_assert(std::is_integral<T>(), "This function only sorts integers.");
30 
31  if (array.size() <= 1)
32  return;
33 
34  T max_value = *std::max_element(array.begin(), array.end());
35 
36  // Sort N bits at a time
37  constexpr int bucket_size = 1 << BITS;
38  T mask = (T(1) << BITS) - 1;
39 
40  // Compute number of iterations, most significant digit (N bits) of
41  // maxvalue
42  int its = 0;
43  while (max_value)
44  {
45  max_value >>= BITS;
46  its++;
47  }
48 
49  // Adjacency list arrays for computing insertion position
50  std::array<std::int32_t, bucket_size> counter;
51  std::array<std::int32_t, bucket_size + 1> offset;
52 
53  std::int32_t mask_offset = 0;
54  std::vector<T> buffer(array.size());
55  std::span<T> current_perm = array;
56  std::span<T> next_perm = buffer;
57  for (int i = 0; i < its; i++)
58  {
59  // Zero counter array
60  std::fill(counter.begin(), counter.end(), 0);
61 
62  // Count number of elements per bucket
63  for (T c : current_perm)
64  counter[(c & mask) >> mask_offset]++;
65 
66  // Prefix sum to get the inserting position
67  offset[0] = 0;
68  std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
69  for (T c : current_perm)
70  {
71  std::int32_t bucket = (c & mask) >> mask_offset;
72  std::int32_t new_pos = offset[bucket + 1] - counter[bucket];
73  next_perm[new_pos] = c;
74  counter[bucket]--;
75  }
76 
77  mask = mask << BITS;
78  mask_offset += BITS;
79 
80  std::swap(current_perm, next_perm);
81  }
82 
83  // Copy data back to array
84  if (its % 2 != 0)
85  std::copy(buffer.begin(), buffer.end(), array.begin());
86 }
87 
96 template <typename T, int BITS = 16>
97 void argsort_radix(const std::span<const T>& array,
98  std::span<std::int32_t> perm)
99 {
100  static_assert(std::is_integral_v<T>, "Integral required.");
101 
102  if (array.size() <= 1)
103  return;
104 
105  const auto [min, max] = std::minmax_element(array.begin(), array.end());
106  T range = *max - *min + 1;
107 
108  // Sort N bits at a time
109  constexpr int bucket_size = 1 << BITS;
110  T mask = (T(1) << BITS) - 1;
111  std::int32_t mask_offset = 0;
112 
113  // Compute number of iterations, most significant digit (N bits) of
114  // maxvalue
115  int its = 0;
116  while (range)
117  {
118  range >>= BITS;
119  its++;
120  }
121 
122  // Adjacency list arrays for computing insertion position
123  std::array<std::int32_t, bucket_size> counter;
124  std::array<std::int32_t, bucket_size + 1> offset;
125 
126  std::vector<std::int32_t> perm2(perm.size());
127  std::span<std::int32_t> current_perm = perm;
128  std::span<std::int32_t> next_perm = perm2;
129  for (int i = 0; i < its; i++)
130  {
131  // Zero counter
132  std::fill(counter.begin(), counter.end(), 0);
133 
134  // Count number of elements per bucket
135  for (auto cp : current_perm)
136  {
137  T value = array[cp] - *min;
138  std::int32_t bucket = (value & mask) >> mask_offset;
139  counter[bucket]++;
140  }
141 
142  // Prefix sum to get the inserting position
143  offset[0] = 0;
144  std::partial_sum(counter.begin(), counter.end(), std::next(offset.begin()));
145 
146  // Sort py permutation
147  for (auto cp : current_perm)
148  {
149  T value = array[cp] - *min;
150  std::int32_t bucket = (value & mask) >> mask_offset;
151  std::int32_t pos = offset[bucket + 1] - counter[bucket];
152  next_perm[pos] = cp;
153  counter[bucket]--;
154  }
155 
156  std::swap(current_perm, next_perm);
157 
158  mask = mask << BITS;
159  mask_offset += BITS;
160  }
161 
162  if (its % 2 == 1)
163  std::copy(perm2.begin(), perm2.end(), perm.begin());
164 }
165 
175 template <typename T, int BITS = 16>
176 std::vector<std::int32_t> sort_by_perm(const std::span<const T>& x,
177  std::size_t shape1)
178 {
179  static_assert(std::is_integral_v<T>, "Integral required.");
180  assert(shape1 > 0);
181  assert(x.size() % shape1 == 0);
182  const std::size_t shape0 = x.size() / shape1;
183  std::vector<std::int32_t> perm(shape0);
184  std::iota(perm.begin(), perm.end(), 0);
185 
186  // Sort by each column, right to left. Col 0 has the most signficant
187  // "digit".
188  std::vector<T> column(shape0);
189  for (std::size_t i = 0; i < shape1; ++i)
190  {
191  int col = shape1 - 1 - i;
192  for (std::size_t j = 0; j < shape0; ++j)
193  column[j] = x[j * shape1 + col];
194  argsort_radix<T, BITS>(column, perm);
195  }
196 
197  return perm;
198 }
199 
200 } // namespace dolfinx