From 9528a6f37231a2c901b29be359b8032ad69091b1 Mon Sep 17 00:00:00 2001 From: shengtsui <124718322+shengtsui@users.noreply.github.com> Date: Wed, 8 Jan 2025 16:28:03 -0800 Subject: [PATCH] VTEN-18-Add-pinned-memory-pool-to-VTensor (#24) --- docs/source/api/core/mempool.rst | 3 ++ lib/core/rmm_utils.cc | 43 +++++++++++++++++++++++++--- lib/core/rmm_utils.hpp | 49 +++++++++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 5 deletions(-) diff --git a/docs/source/api/core/mempool.rst b/docs/source/api/core/mempool.rst index 0fb9e31..efd6c2e 100644 --- a/docs/source/api/core/mempool.rst +++ b/docs/source/api/core/mempool.rst @@ -10,6 +10,9 @@ vt::Mempool .. doxygenclass:: vt::Mempool :members: +.. doxygenclass:: vt::PinnedMempool + :members: + .. doxygenclass:: vt::GlobalMempool :members: diff --git a/lib/core/rmm_utils.cc b/lib/core/rmm_utils.cc index 10d457d..9efbaf4 100644 --- a/lib/core/rmm_utils.cc +++ b/lib/core/rmm_utils.cc @@ -2,8 +2,8 @@ #include #include #include -#include - +#include "rmm/mr/device/pool_memory_resource.hpp" +#include "rmm/mr/host/pinned_memory_resource.hpp" namespace vt { @@ -18,7 +18,7 @@ namespace vt { } /** - * @brief Private implementation pointer. This could reduce compilation time for RMM headers. The implementation is defined in rmm_utils.cc. + * @brief Private implementation pointer. This could reduce compilation time for RMM headers. * */ class Mempool::MempoolImpl { @@ -39,10 +39,45 @@ namespace vt { rmm::mr::logging_resource_adaptor log_mr; rmm::mr::pool_memory_resource> pool_mr; }; + + Mempool::~Mempool() = default; - Mempool::Mempool(unsigned long pool_size, const std::string& log_filepath) + Mempool::Mempool(size_t pool_size, const std::string& log_filepath) : pimpl(std::make_unique(pool_size, log_filepath)) {} + + /** + * @brief Private implementation pointer for PinnedMempool. This could reduce compilation time for RMM headers. + * + */ + class PinnedMempool::PinnedMempoolImpl { + public: + /** + * @brief Construct a new Pinned Mempool Impl object + * + * @param initial_pinned_pool_size: The initial size of the pinned memory pool. + * @param pinned_pool_size: The size of the pinned memory pool. + */ + PinnedMempoolImpl(size_t initial_pinned_pool_size, size_t pinned_pool_size) + : pool_mr(pinned_mr, initial_pinned_pool_size, pinned_pool_size) {} + + rmm::mr::pinned_memory_resource pinned_mr; + rmm::mr::pool_memory_resource pool_mr; + }; + + PinnedMempool::~PinnedMempool() = default; + + void PinnedMempool::deallocate(void* ptr, size_t size) { + pimpl->pool_mr.deallocate(ptr, size); + } + + void* PinnedMempool::allocate(size_t size) { + return pimpl->pool_mr.allocate(size); + } + + PinnedMempool::PinnedMempool(size_t initial_pinned_pool_size, size_t pinned_pool_size) + : pimpl(std::make_unique(initial_pinned_pool_size, pinned_pool_size)) {} + GlobalMempool& GlobalMempool::get_instance(size_t pool_size) { static GlobalMempool instance(pool_size); return instance; diff --git a/lib/core/rmm_utils.hpp b/lib/core/rmm_utils.hpp index ae6475a..5ab8e7f 100644 --- a/lib/core/rmm_utils.hpp +++ b/lib/core/rmm_utils.hpp @@ -35,13 +35,60 @@ class Mempool { * @param pool_size: The size of the memory pool. * @param log_filepath: The log file path for memory usage logging. */ - Mempool(unsigned long pool_size, const std::string& log_filepath); + Mempool(size_t pool_size, const std::string& log_filepath); + + /** + * @brief Destroy the Mempool object + * + */ + ~Mempool(); private: class MempoolImpl; std::unique_ptr pimpl; }; +/** + * @brief A class that manages a pinned memory pool. + * + */ +class PinnedMempool { + public: + /** + * @brief Construct a PinnedMempool object. + * + * @param initial_pinned_pool_size: The initial size of the pinned memory pool. + * @param pinned_pool_size: The size of the pinned memory pool. + */ + PinnedMempool(size_t initial_pinned_pool_size, size_t pinned_pool_size); + + /** + * @brief Destroy the PinnedMempool object + * + */ + ~PinnedMempool(); + + /** + * @brief Deallocate the memory for the given pointer and size. + * + * @param ptr: The pointer to the memory. + * @param size: The size of the memory. + */ + void deallocate(void* ptr, size_t size); + + /** + * @brief Allocate the memory for the given size. + * + * @param size: The size of the memory. + * @return void*: The pointer to the allocated memory. + */ + void* allocate(size_t size); + + private: + class PinnedMempoolImpl; + std::unique_ptr pimpl; +}; + /** * @brief Singleton class for global memory pool management *