From b771f584534906a8b6c3027c223096519498134a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 20 Oct 2023 16:20:48 +0800 Subject: [PATCH] [coll] Define interface for bridging. (#9695) * Define the basic interface that will be shared by nccl, federated and native. --- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win | 1 + src/collective/allgather.cc | 6 ++- src/collective/allgather.h | 5 +- src/collective/coll.cc | 75 ++++++++++++++++++++++++++ src/collective/coll.h | 66 +++++++++++++++++++++++ tests/cpp/collective/test_allreduce.cc | 23 ++++++++ 7 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 src/collective/coll.cc create mode 100644 src/collective/coll.h diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 37511ec62b70..8af5dbbf647a 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -102,6 +102,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/coll.o \ $(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/in_memory_communicator.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 611cff8742dc..60f754fef47e 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -102,6 +102,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/coll.o \ $(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/in_memory_communicator.o \ diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index 378a06911f20..a51b79fbc956 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -3,7 +3,7 @@ */ #include "allgather.h" -#include // for min, copy_n +#include // for min, copy_n, fill_n #include // for size_t #include 
// for int8_t, int32_t, int64_t #include // for shared_ptr @@ -45,6 +45,7 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size [[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, common::Span data, + common::Span offset, common::Span erased_result) { auto world = comm.World(); auto rank = comm.Rank(); @@ -56,7 +57,8 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size auto next_ch = comm.Chan(next); // get worker offset - std::vector offset(world + 1, 0); + CHECK_EQ(world + 1, offset.size()); + std::fill_n(offset.data(), offset.size(), 0); std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1); CHECK_EQ(*offset.cbegin(), 0); diff --git a/src/collective/allgather.h b/src/collective/allgather.h index cb5f5b8afe41..967187cebb17 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -26,6 +26,7 @@ namespace cpu_impl { [[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, common::Span data, + common::Span offset, common::Span erased_result); } // namespace cpu_impl @@ -66,7 +67,9 @@ template auto h_result = common::Span{result.data(), result.size()}; auto erased_result = EraseType(h_result); auto erased_data = EraseType(data); + std::vector offset(world + 1); - return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result); + return cpu_impl::RingAllgatherV(comm, sizes, erased_data, + common::Span{offset.data(), offset.size()}, erased_result); } } // namespace xgboost::collective diff --git a/src/collective/coll.cc b/src/collective/coll.cc new file mode 100644 index 000000000000..6682e57ffdae --- /dev/null +++ b/src/collective/coll.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "coll.h" + +#include // for min, max +#include // for size_t +#include // for int8_t, int64_t +#include // for bit_and, bit_or, bit_xor, plus + +#include "allgather.h" // for RingAllgatherV, RingAllgather +#include "allreduce.h" // for 
Allreduce +#include "broadcast.h" // for Broadcast +#include "comm.h" // for Comm +#include "xgboost/context.h" // for Context + +namespace xgboost::collective { +[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm, + common::Span data, ArrayInterfaceHandler::Type, + Op op) { + namespace coll = ::xgboost::collective; + + auto redop_fn = [](auto lhs, auto out, auto elem_op) { + auto p_lhs = lhs.data(); + auto p_out = out.data(); + for (std::size_t i = 0; i < lhs.size(); ++i) { + p_out[i] = elem_op(p_lhs[i], p_out[i]); + } + }; + auto fn = [&](auto elem_op) { + return coll::Allreduce( + comm, data, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); }); + }; + + switch (op) { + case Op::kMax: { + return fn([](auto l, auto r) { return std::max(l, r); }); + } + case Op::kMin: { + return fn([](auto l, auto r) { return std::min(l, r); }); + } + case Op::kSum: { + return fn(std::plus<>{}); + } + case Op::kBitwiseAND: { + return fn(std::bit_and<>{}); + } + case Op::kBitwiseOR: { + return fn(std::bit_or<>{}); + } + case Op::kBitwiseXOR: { + return fn(std::bit_xor<>{}); + } + } + return comm.Block(); +} + +[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm, + common::Span data, std::int32_t root) { + return cpu_impl::Broadcast(comm, data, root); +} + +[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm, + common::Span data, std::size_t size) { + return RingAllgather(comm, data, size); +} + +[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm, + common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv) { + return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv); +} +} // namespace xgboost::collective diff --git a/src/collective/coll.h b/src/collective/coll.h new file mode 100644 index 000000000000..9a318db8dcbd --- /dev/null +++ b/src/collective/coll.h @@ -0,0 +1,66 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once 
+#include // for size_t +#include // for int8_t, int64_t +#include // for enable_shared_from_this + +#include "../data/array_interface.h" // for ArrayInterfaceHandler +#include "comm.h" // for Comm +#include "xgboost/collective/result.h" // for Result +#include "xgboost/context.h" // for Context +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +/** + * @brief Interface and base implementation for collective. + */ +class Coll : public std::enable_shared_from_this { + public: + Coll() = default; + virtual ~Coll() noexcept(false) {} // NOLINT + + /** + * @brief Allreduce + * + * @param [in,out] data Data buffer for input and output. + * @param [in] type data type. + * @param [in] op Reduce operation. For custom operation, user needs to reach down to + * the CPU implementation. + */ + [[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm, + common::Span data, + ArrayInterfaceHandler::Type type, Op op); + /** + * @brief Broadcast + * + * @param [in,out] data Data buffer for input and output. + * @param [in] root Root rank for broadcast. + */ + [[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm, + common::Span data, std::int32_t root); + /** + * @brief Allgather + * + * @param [in,out] data Data buffer for input and output. + * @param [in] size Size of data for each worker. + */ + [[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm, + common::Span data, std::size_t size); + /** + * @brief Allgather with variable length. + * + * @param [in] data Input data for the current worker. + * @param [in] sizes Size of the input from each worker. + * @param [out] recv_segments pre-allocated offset for each worker in the output, size + * should be equal to (world + 1). + * @param [out] recv pre-allocated buffer for output. 
+ */ + [[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm, + common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv); +}; +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 62b87e411882..50b1722ae8e1 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -4,6 +4,7 @@ #include #include "../../../src/collective/allreduce.h" +#include "../../../src/collective/coll.h" // for Coll #include "../../../src/collective/tracker.h" #include "test_worker.h" // for WorkerForTest, TestDistributed @@ -47,6 +48,19 @@ class AllreduceWorker : public WorkerForTest { ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; } } + + void BitOr() { + Context ctx; + std::vector data(comm_.World(), 0); + data[comm_.Rank()] = ~std::uint32_t{0}; + auto pcoll = std::make_shared(); + auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}), + ArrayInterfaceHandler::kU4, Op::kBitwiseOR); + ASSERT_TRUE(rc.OK()) << rc.Report(); + for (auto v : data) { + ASSERT_EQ(v, ~std::uint32_t{0}); + } + } }; class AllreduceTest : public SocketTest {}; @@ -69,4 +83,13 @@ TEST_F(AllreduceTest, Sum) { worker.Acc(); }); } + +TEST_F(AllreduceTest, BitOr) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + AllreduceWorker worker{host, port, timeout, n_workers, r}; + worker.BitOr(); + }); +} } // namespace xgboost::collective