-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[coll] Define interface for bridging. (#9695)
* Define the basic interface that will be shared by nccl, federated and native.
- Loading branch information
1 parent
6fbe624
commit b771f58
Showing
7 changed files
with
174 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#include "coll.h" | ||
|
||
#include <algorithm> // for min, max | ||
#include <cstddef> // for size_t | ||
#include <cstdint> // for int8_t, int64_t | ||
#include <functional> // for bit_and, bit_or, bit_xor, plus | ||
|
||
#include "allgather.h" // for RingAllgatherV, RingAllgather | ||
#include "allreduce.h" // for Allreduce | ||
#include "broadcast.h" // for Broadcast | ||
#include "comm.h" // for Comm | ||
#include "xgboost/context.h" // for Context | ||
|
||
namespace xgboost::collective { | ||
/**
 * @brief Base allreduce: combines `data` in place across workers according to `op`.
 *
 * The context and the element-type arguments are unnamed and therefore unused here.
 * NOTE(review): because the type is ignored, the element-wise operation presumably runs
 * over the raw std::int8_t elements of `data` -- confirm against coll::Allreduce.
 *
 * @param comm Communicator forwarded to the free-function allreduce.
 * @param data Buffer used for both input and output.
 * @param op   Selects the element-wise reduce operation.
 *
 * @return The Result produced by the underlying allreduce, or `comm.Block()` when `op`
 *         matches none of the handled operations.
 */
[[nodiscard]] Result Coll::Allreduce(Context const* /*ctx*/, Comm const& comm,
                                     common::Span<std::int8_t> data,
                                     ArrayInterfaceHandler::Type /*type*/, Op op) {
  namespace coll = ::xgboost::collective;

  // Launch the collective with a reducer that folds the first buffer into the second
  // one, storing `binary_op(in[i], acc[i])` back into `acc[i]`.
  auto run = [&](auto binary_op) {
    return coll::Allreduce(comm, data, [binary_op](auto in, auto acc) {
      auto src = in.data();
      auto dst = acc.data();
      auto const n_elems = in.size();
      for (std::size_t idx = 0; idx < n_elems; ++idx) {
        dst[idx] = binary_op(src[idx], dst[idx]);
      }
    });
  };

  switch (op) {
    case Op::kMax:
      return run([](auto a, auto b) { return std::max(a, b); });
    case Op::kMin:
      return run([](auto a, auto b) { return std::min(a, b); });
    case Op::kSum:
      return run(std::plus<>{});
    case Op::kBitwiseAND:
      return run(std::bit_and<>{});
    case Op::kBitwiseOR:
      return run(std::bit_or<>{});
    case Op::kBitwiseXOR:
      return run(std::bit_xor<>{});
  }
  // Only reached when `op` matches none of the cases above.
  return comm.Block();
}
|
||
/**
 * @brief Base broadcast: forwards directly to the CPU implementation.
 *
 * The context argument is unnamed and therefore unused here.
 *
 * @param comm Communicator handed to cpu_impl::Broadcast.
 * @param data Buffer used for both input and output.
 * @param root Root rank for the broadcast.
 *
 * @return Whatever cpu_impl::Broadcast returns.
 */
[[nodiscard]] Result Coll::Broadcast(Context const* /*ctx*/, Comm const& comm,
                                     common::Span<std::int8_t> data, std::int32_t root) {
  return cpu_impl::Broadcast(comm, data, root);
}
|
||
/**
 * @brief Base allgather: forwards directly to RingAllgather.
 *
 * The context argument is unnamed and therefore unused here.
 *
 * @param comm Communicator handed to RingAllgather.
 * @param data Buffer used for both input and output.
 * @param size Size of data for each worker.
 *
 * @return Whatever RingAllgather returns.
 */
[[nodiscard]] Result Coll::Allgather(Context const* /*ctx*/, Comm const& comm,
                                     common::Span<std::int8_t> data, std::size_t size) {
  return RingAllgather(comm, data, size);
}
|
||
/**
 * @brief Base variable-length allgather: forwards directly to the CPU ring implementation.
 *
 * The context argument is unnamed and therefore unused here.
 *
 * @param comm          Communicator handed to cpu_impl::RingAllgatherV.
 * @param data          Input data for the current worker.
 * @param sizes         Size of the input from each worker.
 * @param recv_segments Pre-allocated offsets for each worker in the output.
 * @param recv          Pre-allocated buffer for the output.
 *
 * @return Whatever cpu_impl::RingAllgatherV returns.
 */
[[nodiscard]] Result Coll::AllgatherV(Context const* /*ctx*/, Comm const& comm,
                                      common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
                                      common::Span<std::int8_t> recv) {
  // The helper takes `sizes` ahead of `data`, the reverse of this method's parameter order.
  return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
}
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#pragma once | ||
#include <cstddef> // for size_t | ||
#include <cstdint> // for int8_t, int64_t | ||
#include <memory> // for enable_shared_from_this | ||
|
||
#include "../data/array_interface.h" // for ArrayInterfaceHandler | ||
#include "comm.h" // for Comm | ||
#include "xgboost/collective/result.h" // for Result | ||
#include "xgboost/context.h" // for Context | ||
#include "xgboost/span.h" // for Span | ||
|
||
namespace xgboost::collective { | ||
/** | ||
* @brief Interface and base implementation for collective. | ||
*/ | ||
class Coll : public std::enable_shared_from_this<Coll> { | ||
public: | ||
Coll() = default; | ||
virtual ~Coll() noexcept(false) {} // NOLINT | ||
|
||
/** | ||
* @brief Allreduce | ||
* | ||
* @param [in,out] data Data buffer for input and output. | ||
* @param [in] type data type. | ||
* @param [in] op Reduce operation. For custom operation, user needs to reach down to | ||
* the CPU implementation. | ||
*/ | ||
[[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm, | ||
common::Span<std::int8_t> data, | ||
ArrayInterfaceHandler::Type type, Op op); | ||
/** | ||
* @brief Broadcast | ||
* | ||
* @param [in,out] data Data buffer for input and output. | ||
* @param [in] root Root rank for broadcast. | ||
*/ | ||
[[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm, | ||
common::Span<std::int8_t> data, std::int32_t root); | ||
/** | ||
* @brief Allgather | ||
* | ||
* @param [in,out] data Data buffer for input and output. | ||
* @param [in] size Size of data for each worker. | ||
*/ | ||
[[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm, | ||
common::Span<std::int8_t> data, std::size_t size); | ||
/** | ||
* @brief Allgather with variable length. | ||
* | ||
* @param [in] data Input data for the current worker. | ||
* @param [in] sizes Size of the input from each worker. | ||
* @param [out] recv_segments pre-allocated offset for each worker in the output, size | ||
* should be equal to (world + 1). | ||
* @param [out] recv pre-allocated buffer for output. | ||
*/ | ||
[[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm, | ||
common::Span<std::int8_t const> data, | ||
common::Span<std::int64_t const> sizes, | ||
common::Span<std::int64_t> recv_segments, | ||
common::Span<std::int8_t> recv); | ||
}; | ||
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters