Skip to content

Commit

Permalink
[coll] Add nccl. (#9726)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 28, 2023
1 parent 0c62109 commit 6755179
Show file tree
Hide file tree
Showing 19 changed files with 922 additions and 109 deletions.
45 changes: 28 additions & 17 deletions src/collective/allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,23 @@
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <numeric> // for partial_sum
#include <vector> // for vector

#include "broadcast.h"
#include "comm.h" // for Comm, Channel
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span

namespace xgboost::collective::cpu_impl {
namespace xgboost::collective {
namespace cpu_impl {
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch) {
auto world = comm.World();
auto rank = comm.Rank();
CHECK_LT(worker_off, world);
if (world == 1) {
return Success();
}

for (std::int32_t r = 0; r < world; ++r) {
auto send_rank = (rank + world - r + worker_off) % world;
Expand All @@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
return Success();
}

Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t> recv) {
std::size_t offset = 0;
for (std::int32_t r = 0; r < comm.World(); ++r) {
auto as_bytes = sizes[r];
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
}
} // namespace cpu_impl

namespace detail {
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int64_t const> offset,
common::Span<std::int8_t> erased_result) {
auto world = comm.World();
if (world == 1) {
return Success();
}
auto rank = comm.Rank();

auto prev = BootstrapPrev(rank, comm.World());
Expand All @@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);

// get worker offset
CHECK_EQ(world + 1, offset.size());
std::fill_n(offset.data(), offset.size(), 0);
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
CHECK_EQ(*offset.cbegin(), 0);

// copy data
auto current = erased_result.subspan(offset[rank], data.size_bytes());
auto erased_data = EraseType(data);
std::copy_n(erased_data.data(), erased_data.size(), current.data());

for (std::int32_t r = 0; r < world; ++r) {
auto send_rank = (rank + world - r) % world;
auto send_off = offset[send_rank];
Expand All @@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
}
return comm.Block();
}
} // namespace xgboost::collective::cpu_impl
} // namespace detail
} // namespace xgboost::collective
45 changes: 35 additions & 10 deletions src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,46 @@
#include <type_traits> // for remove_cv_t
#include <vector> // for vector

#include "../common/type.h" // for EraseType
#include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
#include "xgboost/linalg.h"
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
namespace cpu_impl {
/**
* @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off
* = 1, then it owns the third segment.
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
* worker_off = 1, then it owns the third segment.
*/
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off,
std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch);

/**
* @brief Implement allgather-v using broadcast.
*
* https://arxiv.org/abs/1812.05964
*/
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t> recv);
} // namespace cpu_impl

namespace detail {
inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> offset) {
// get worker offset
std::fill_n(offset.data(), offset.size(), 0);
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
CHECK_EQ(*offset.cbegin(), 0);
}

// An implementation that's used by both cpu and gpu
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int64_t const> offset,
common::Span<std::int8_t> erased_result);
} // namespace cpu_impl
} // namespace detail

template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
Expand Down Expand Up @@ -68,9 +87,15 @@ template <typename T>
auto h_result = common::Span{result.data(), result.size()};
auto erased_result = common::EraseType(h_result);
auto erased_data = common::EraseType(data);
std::vector<std::int64_t> offset(world + 1);
std::vector<std::int64_t> recv_segments(world + 1);
auto s_segments = common::Span{recv_segments.data(), recv_segments.size()};

// get worker offset
detail::AllgatherVOffset(sizes, s_segments);
// copy data
auto current = erased_result.subspan(recv_segments[rank], data.size_bytes());
std::copy_n(erased_data.data(), erased_data.size(), current.data());

return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
common::Span{offset.data(), offset.size()}, erased_result);
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
}
} // namespace xgboost::collective
54 changes: 38 additions & 16 deletions src/collective/coll.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus

#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
#include "xgboost/context.h" // for Context
#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm

namespace xgboost::collective {
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
Op op) {
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type, Op op) {
namespace coll = ::xgboost::collective;

auto redop_fn = [](auto lhs, auto out, auto elem_op) {
Expand Down Expand Up @@ -55,21 +53,45 @@ namespace xgboost::collective {
return comm.Block();
}

[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
common::Span<std::int8_t> data, std::int32_t root) {
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root) {
return cpu_impl::Broadcast(comm, data, root);
}

[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
common::Span<std::int8_t> data, std::size_t size) {
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) {
return RingAllgather(comm, data, size);
}

[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
common::Span<std::int8_t const> data,
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv) {
return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
// get worker offset
detail::AllgatherVOffset(sizes, recv_segments);

// copy data
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
if (current.data() != data.data()) {
std::copy_n(data.data(), data.size(), current.data());
}

switch (algo) {
case AllgatherVAlgo::kRing:
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
case AllgatherVAlgo::kBcast:
return cpu_impl::BroadcastAllgatherV(comm, sizes, recv);
default: {
return Fail("Unknown algorithm for allgather-v");
}
}
}

#if !defined(XGBOOST_USE_NCCL)
Coll* Coll::MakeCUDAVar() {
LOG(FATAL) << "NCCL is required for device communication.";
return nullptr;
}
#endif

} // namespace xgboost::collective
Loading

0 comments on commit 6755179

Please sign in to comment.