Skip to content

Commit

Permalink
[coll] Define interface for bridging. (#9695)
Browse files Browse the repository at this point in the history
* Define the basic interface that will shared by nccl, federated and native.
  • Loading branch information
trivialfis authored Oct 20, 2023
1 parent 6fbe624 commit b771f58
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 3 deletions.
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
Expand Down
6 changes: 4 additions & 2 deletions src/collective/allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*/
#include "allgather.h"

#include <algorithm> // for min, copy_n
#include <algorithm> // for min, copy_n, fill_n
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr
Expand Down Expand Up @@ -45,6 +45,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size

[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int8_t> erased_result) {
auto world = comm.World();
auto rank = comm.Rank();
Expand All @@ -56,7 +57,8 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
auto next_ch = comm.Chan(next);

// get worker offset
std::vector<std::int64_t> offset(world + 1, 0);
CHECK_EQ(world + 1, offset.size());
std::fill_n(offset.data(), offset.size(), 0);
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
CHECK_EQ(*offset.cbegin(), 0);

Expand Down
5 changes: 4 additions & 1 deletion src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace cpu_impl {

[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int8_t> erased_result);
} // namespace cpu_impl

Expand Down Expand Up @@ -66,7 +67,9 @@ template <typename T>
auto h_result = common::Span{result.data(), result.size()};
auto erased_result = EraseType(h_result);
auto erased_data = EraseType(data);
std::vector<std::int64_t> offset(world + 1);

return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
common::Span{offset.data(), offset.size()}, erased_result);
}
} // namespace xgboost::collective
75 changes: 75 additions & 0 deletions src/collective/coll.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "coll.h"

#include <algorithm> // for min, max
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus

#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
#include "xgboost/context.h" // for Context

namespace xgboost::collective {
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
Op op) {
namespace coll = ::xgboost::collective;

auto redop_fn = [](auto lhs, auto out, auto elem_op) {
auto p_lhs = lhs.data();
auto p_out = out.data();
for (std::size_t i = 0; i < lhs.size(); ++i) {
p_out[i] = elem_op(p_lhs[i], p_out[i]);
}
};
auto fn = [&](auto elem_op) {
return coll::Allreduce(
comm, data, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); });
};

switch (op) {
case Op::kMax: {
return fn([](auto l, auto r) { return std::max(l, r); });
}
case Op::kMin: {
return fn([](auto l, auto r) { return std::min(l, r); });
}
case Op::kSum: {
return fn(std::plus<>{});
}
case Op::kBitwiseAND: {
return fn(std::bit_and<>{});
}
case Op::kBitwiseOR: {
return fn(std::bit_or<>{});
}
case Op::kBitwiseXOR: {
return fn(std::bit_xor<>{});
}
}
return comm.Block();
}

[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
common::Span<std::int8_t> data, std::int32_t root) {
return cpu_impl::Broadcast(comm, data, root);
}

[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
common::Span<std::int8_t> data, std::size_t size) {
return RingAllgather(comm, data, size);
}

[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv) {
return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
}
} // namespace xgboost::collective
66 changes: 66 additions & 0 deletions src/collective/coll.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <memory> // for enable_shared_from_this

#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
/**
* @brief Interface and base implementation for collective.
*/
class Coll : public std::enable_shared_from_this<Coll> {
public:
Coll() = default;
virtual ~Coll() noexcept(false) {} // NOLINT

/**
* @brief Allreduce
*
* @param [in,out] data Data buffer for input and output.
* @param [in] type data type.
* @param [in] op Reduce operation. For custom operation, user needs to reach down to
* the CPU implementation.
*/
[[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op);
/**
* @brief Broadcast
*
* @param [in,out] data Data buffer for input and output.
* @param [in] root Root rank for broadcast.
*/
[[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
common::Span<std::int8_t> data, std::int32_t root);
/**
* @brief Allgather
*
* @param [in,out] data Data buffer for input and output.
* @param [in] size Size of data for each worker.
*/
[[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
common::Span<std::int8_t> data, std::size_t size);
/**
* @brief Allgather with variable length.
*
* @param [in] data Input data for the current worker.
* @param [in] sizes Size of the input from each worker.
* @param [out] recv_segments pre-allocated offset for each worker in the output, size
* should be equal to (world + 1).
* @param [out] recv pre-allocated buffer for output.
*/
[[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv);
};
} // namespace xgboost::collective
23 changes: 23 additions & 0 deletions tests/cpp/collective/test_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <gtest/gtest.h>

#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for WorkerForTest, TestDistributed

Expand Down Expand Up @@ -47,6 +48,19 @@ class AllreduceWorker : public WorkerForTest {
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
}
}

void BitOr() {
Context ctx;
std::vector<std::uint32_t> data(comm_.World(), 0);
data[comm_.Rank()] = ~std::uint32_t{0};
auto pcoll = std::make_shared<Coll>();
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (auto v : data) {
ASSERT_EQ(v, ~std::uint32_t{0});
}
}
};

class AllreduceTest : public SocketTest {};
Expand All @@ -69,4 +83,13 @@ TEST_F(AllreduceTest, Sum) {
worker.Acc();
});
}

TEST_F(AllreduceTest, BitOr) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
AllreduceWorker worker{host, port, timeout, n_workers, r};
worker.BitOr();
});
}
} // namespace xgboost::collective

0 comments on commit b771f58

Please sign in to comment.