315 std::array<std::int64_t, 2> shape,
316 std::int64_t rank_offset)
318 assert(rank_offset >= 0 or x.empty());
319 using T =
typename std::remove_reference_t<typename U::value_type>;
323 assert(x.size() % shape[1] == 0);
324 const std::int32_t shape0_local = x.size() / shape[1];
326 spdlog::debug(
"Sending data to post offices (distribute_to_postoffice)");
329 std::vector<int> row_to_dest(shape0_local);
330 for (std::int32_t i = 0; i < shape0_local; ++i)
333 row_to_dest[i] = dest;
338 std::vector<std::array<std::int32_t, 2>> dest_to_index;
339 dest_to_index.reserve(shape0_local);
340 for (std::int32_t i = 0; i < shape0_local; ++i)
342 std::size_t idx = i + rank_offset;
344 dest_to_index.push_back({dest, i});
346 std::ranges::sort(dest_to_index);
350 std::vector<int> dest;
351 std::vector<std::int32_t> num_items_per_dest,
352 pos_to_neigh_rank(shape0_local, -1);
354 auto it = dest_to_index.begin();
355 while (it != dest_to_index.end())
357 const int neigh_rank = dest.size();
360 dest.push_back((*it)[0]);
364 = std::find_if(it, dest_to_index.end(),
365 [r = dest.back()](
auto& idx) { return idx[0] != r; });
368 num_items_per_dest.push_back(std::distance(it, it1));
371 for (
auto e = it; e != it1; ++e)
372 pos_to_neigh_rank[(*e)[1]] = neigh_rank;
382 "Number of neighbourhood source ranks in distribute_to_postoffice: {}",
383 static_cast<int>(src.size()));
387 int err = MPI_Dist_graph_create_adjacent(
388 comm, src.size(), src.data(), MPI_UNWEIGHTED, dest.size(), dest.data(),
389 MPI_UNWEIGHTED, MPI_INFO_NULL,
false, &neigh_comm);
393 std::vector<std::int32_t> send_disp = {0};
394 std::partial_sum(num_items_per_dest.begin(), num_items_per_dest.end(),
395 std::back_inserter(send_disp));
398 std::vector<T> send_buffer_data(shape[1] * send_disp.back());
399 std::vector<std::int64_t> send_buffer_index(send_disp.back());
401 std::vector<std::int32_t> send_offsets = send_disp;
402 for (std::int32_t i = 0; i < shape0_local; ++i)
404 if (
int neigh_dest = pos_to_neigh_rank[i]; neigh_dest != -1)
406 std::size_t pos = send_offsets[neigh_dest];
407 send_buffer_index[pos] = i + rank_offset;
408 std::copy_n(std::next(x.begin(), i * shape[1]), shape[1],
409 std::next(send_buffer_data.begin(), shape[1] * pos));
410 ++send_offsets[neigh_dest];
417 std::vector<int> num_items_recv(src.size());
418 num_items_per_dest.reserve(1);
419 num_items_recv.reserve(1);
420 err = MPI_Neighbor_alltoall(num_items_per_dest.data(), 1, MPI_INT,
421 num_items_recv.data(), 1, MPI_INT, neigh_comm);
425 std::vector<std::int32_t> recv_disp(num_items_recv.size() + 1, 0);
426 std::partial_sum(num_items_recv.begin(), num_items_recv.end(),
427 std::next(recv_disp.begin()));
430 std::vector<std::int64_t> recv_buffer_index(recv_disp.back());
431 err = MPI_Neighbor_alltoallv(
432 send_buffer_index.data(), num_items_per_dest.data(), send_disp.data(),
433 MPI_INT64_T, recv_buffer_index.data(), num_items_recv.data(),
434 recv_disp.data(), MPI_INT64_T, neigh_comm);
438 MPI_Datatype compound_type;
440 MPI_Type_commit(&compound_type);
441 std::vector<T> recv_buffer_data(shape[1] * recv_disp.back());
442 err = MPI_Neighbor_alltoallv(
443 send_buffer_data.data(), num_items_per_dest.data(), send_disp.data(),
444 compound_type, recv_buffer_data.data(), num_items_recv.data(),
445 recv_disp.data(), compound_type, neigh_comm);
447 err = MPI_Type_free(&compound_type);
449 err = MPI_Comm_free(&neigh_comm);
452 spdlog::debug(
"Completed send data to post offices.");
456 std::vector<std::int32_t> index_local(recv_buffer_index.size());
457 std::ranges::transform(recv_buffer_index, index_local.begin(),
458 [r0](
auto idx) { return idx - r0; });
460 return {index_local, recv_buffer_data};
466 const U& x, std::array<std::int64_t, 2> shape,
467 std::int64_t rank_offset)
469 assert(rank_offset >= 0 or x.empty());
470 using T =
typename std::remove_reference_t<typename U::value_type>;
473 assert(shape[1] > 0);
477 assert(x.size() % shape[1] == 0);
478 const std::int64_t shape0_local = x.size() / shape[1];
485 comm, x, {shape[0], shape[1]}, rank_offset);
486 assert(post_indices.size() == post_x.size() / shape[1]);
492 std::vector<std::tuple<int, std::int64_t, std::int32_t>> src_to_index;
493 for (std::size_t i = 0; i < indices.size(); ++i)
495 std::size_t idx = indices[i];
497 src_to_index.push_back({src, idx, i});
499 std::ranges::sort(src_to_index);
503 std::vector<std::int32_t> num_items_per_src;
504 std::vector<int> src;
506 auto it = src_to_index.begin();
507 while (it != src_to_index.end())
509 src.push_back(std::get<0>(*it));
511 = std::find_if(it, src_to_index.end(), [r = src.back()](
auto& idx)
512 { return std::get<0>(idx) != r; });
513 num_items_per_src.push_back(std::distance(it, it1));
520 const std::vector<int> dest
523 "Neighbourhood destination ranks from post office in "
524 "distribute_data (rank, num dests, num dests/mpi_size): {}, {}, {}",
525 rank,
static_cast<int>(dest.size()),
526 static_cast<double>(dest.size()) /
size);
530 MPI_Comm neigh_comm0;
531 int err = MPI_Dist_graph_create_adjacent(
532 comm, dest.size(), dest.data(), MPI_UNWEIGHTED, src.size(), src.data(),
533 MPI_UNWEIGHTED, MPI_INFO_NULL,
false, &neigh_comm0);
537 std::vector<int> num_items_recv(dest.size());
538 num_items_per_src.reserve(1);
539 num_items_recv.reserve(1);
540 err = MPI_Neighbor_alltoall(num_items_per_src.data(), 1, MPI_INT,
541 num_items_recv.data(), 1, MPI_INT, neigh_comm0);
545 std::vector<std::int32_t> send_disp = {0};
546 std::partial_sum(num_items_per_src.begin(), num_items_per_src.end(),
547 std::back_inserter(send_disp));
548 std::vector<std::int32_t> recv_disp = {0};
549 std::partial_sum(num_items_recv.begin(), num_items_recv.end(),
550 std::back_inserter(recv_disp));
554 assert(send_disp.back() == (
int)src_to_index.size());
555 std::vector<std::int64_t> send_buffer_index(src_to_index.size());
556 std::ranges::transform(src_to_index, send_buffer_index.begin(),
557 [](
auto x) { return std::get<1>(x); });
560 std::vector<std::int64_t> recv_buffer_index(recv_disp.back());
561 err = MPI_Neighbor_alltoallv(
562 send_buffer_index.data(), num_items_per_src.data(), send_disp.data(),
563 MPI_INT64_T, recv_buffer_index.data(), num_items_recv.data(),
564 recv_disp.data(), MPI_INT64_T, neigh_comm0);
567 err = MPI_Comm_free(&neigh_comm0);
576 const std::array<std::int64_t, 2> postoffice_range
578 std::vector<std::int32_t> post_indices_map(
579 postoffice_range[1] - postoffice_range[0], -1);
580 for (std::size_t i = 0; i < post_indices.size(); ++i)
582 assert(post_indices[i] < (
int)post_indices_map.size());
583 post_indices_map[post_indices[i]] = i;
587 std::vector<T> send_buffer_data(shape[1] * recv_disp.back());
588 for (std::size_t p = 0; p < recv_disp.size() - 1; ++p)
590 int offset = recv_disp[p];
591 for (std::int32_t i = recv_disp[p]; i < recv_disp[p + 1]; ++i)
593 std::int64_t index = recv_buffer_index[i];
594 if (index >= rank_offset and index < (rank_offset + shape0_local))
597 std::int32_t local_index = index - rank_offset;
598 std::copy_n(std::next(x.begin(), shape[1] * local_index), shape[1],
599 std::next(send_buffer_data.begin(), shape[1] * offset));
604 auto local_index = index - postoffice_range[0];
605 std::int32_t pos = post_indices_map[local_index];
607 std::copy_n(std::next(post_x.begin(), shape[1] * pos), shape[1],
608 std::next(send_buffer_data.begin(), shape[1] * offset));
615 err = MPI_Dist_graph_create_adjacent(
616 comm, src.size(), src.data(), MPI_UNWEIGHTED, dest.size(), dest.data(),
617 MPI_UNWEIGHTED, MPI_INFO_NULL,
false, &neigh_comm0);
620 MPI_Datatype compound_type0;
622 MPI_Type_commit(&compound_type0);
624 std::vector<T> recv_buffer_data(shape[1] * send_disp.back());
625 err = MPI_Neighbor_alltoallv(
626 send_buffer_data.data(), num_items_recv.data(), recv_disp.data(),
627 compound_type0, recv_buffer_data.data(), num_items_per_src.data(),
628 send_disp.data(), compound_type0, neigh_comm0);
631 err = MPI_Type_free(&compound_type0);
633 err = MPI_Comm_free(&neigh_comm0);
636 std::vector<std::int32_t> index_pos_to_buffer(indices.size(), -1);
637 for (std::size_t i = 0; i < src_to_index.size(); ++i)
638 index_pos_to_buffer[std::get<2>(src_to_index[i])] = i;
641 std::vector<T> x_new(shape[1] * indices.size());
642 for (std::size_t i = 0; i < indices.size(); ++i)
644 const std::int64_t index = indices[i];
645 if (index >= rank_offset and index < (rank_offset + shape0_local))
648 auto local_index = index - rank_offset;
649 std::copy_n(std::next(x.begin(), shape[1] * local_index), shape[1],
650 std::next(x_new.begin(), shape[1] * i));
658 auto local_index = index - postoffice_range[0];
659 std::int32_t pos = post_indices_map[local_index];
661 std::copy_n(std::next(post_x.begin(), shape[1] * pos), shape[1],
662 std::next(x_new.begin(), shape[1] * i));
667 std::int32_t pos = index_pos_to_buffer[i];
669 std::copy_n(std::next(recv_buffer_data.begin(), shape[1] * pos),
670 shape[1], std::next(x_new.begin(), shape[1] * i));