diff --git a/cpp/mrc/CMakeLists.txt b/cpp/mrc/CMakeLists.txt index 46cb31787..f2f1e63cc 100644 --- a/cpp/mrc/CMakeLists.txt +++ b/cpp/mrc/CMakeLists.txt @@ -115,7 +115,6 @@ add_library(libmrc src/public/core/logging.cpp src/public/core/thread.cpp src/public/coroutines/event.cpp - src/public/coroutines/scheduler.cpp src/public/coroutines/sync_wait.cpp src/public/coroutines/task_container.cpp src/public/coroutines/thread_local_context.cpp diff --git a/cpp/mrc/include/mrc/coroutines/scheduler.hpp b/cpp/mrc/include/mrc/coroutines/scheduler.hpp index 0dc7660ea..f6960cb63 100644 --- a/cpp/mrc/include/mrc/coroutines/scheduler.hpp +++ b/cpp/mrc/include/mrc/coroutines/scheduler.hpp @@ -17,6 +17,8 @@ #pragma once +#include "mrc/coroutines/task.hpp" + #include #include #include @@ -27,87 +29,16 @@ namespace mrc::coroutines { /** * @brief Scheduler base class - * - * Allows all schedulers to be discovered via the mrc::this_thread::current_scheduler() */ class Scheduler : public std::enable_shared_from_this { public: - struct Operation - { - Operation(Scheduler& scheduler); - - constexpr static auto await_ready() noexcept -> bool - { - return false; - } - - std::coroutine_handle<> await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept; - - constexpr static auto await_resume() noexcept -> void {} - - Scheduler& m_scheduler; - std::coroutine_handle<> m_awaiting_coroutine; - Operation* m_next{nullptr}; - }; - - Scheduler(); virtual ~Scheduler() = default; - /** - * @brief Description of Scheduler - */ - virtual std::string description() const = 0; - - /** - * Schedules the currently executing coroutine to be run on this thread pool. This must be - * called from within the coroutines function body to schedule the coroutine on the thread pool. - * @throw std::runtime_error If the thread pool is `shutdown()` scheduling new tasks is not permitted. - * @return The operation to switch from the calling scheduling thread to the executor thread - * pool thread. - */ - [[nodiscard]] virtual auto schedule() -> Operation; - - /** - * Schedules any coroutine handle that is ready to be resumed. - * @param handle The coroutine handle to schedule. - */ - virtual auto resume(std::coroutine_handle<> coroutine) -> void = 0; - - /** - * Yields the current task to the end of the queue of waiting tasks. - */ - [[nodiscard]] auto yield() -> Operation; - - /** - * If the calling thread controlled by a Scheduler, return a pointer to the Scheduler - */ - static auto from_current_thread() noexcept -> Scheduler*; - - /** - * If the calling thread is owned by a thread_pool, return the thread index (rank) of the current thread with - * respect the threads in the pool; otherwise, return the std::hash of std::this_thread::get_id - */ - static auto get_thread_id() noexcept -> std::size_t; - - protected: - virtual auto on_thread_start(std::size_t) -> void; - - private: - /** - * @brief When co_await schedule() is called, this function will be executed by the awaiter. Each scheduler - * implementation should determine how and when to execute the operation. - * - * @param operation The schedule() awaitable pointer - * @return std::coroutine_handle<> Return a coroutine handle to which will be - * used as the return value for await_suspend(). - */ - virtual std::coroutine_handle<> schedule_operation(Operation* operation) = 0; - - mutable std::mutex m_mutex; + virtual void resume(std::coroutine_handle<> handle) noexcept = 0; - thread_local static Scheduler* m_thread_local_scheduler; - thread_local static std::size_t m_thread_id; + [[nodiscard]] virtual Task<> schedule() = 0; + [[nodiscard]] virtual Task<> yield() = 0; }; } // namespace mrc::coroutines diff --git a/cpp/mrc/src/public/coroutines/scheduler.cpp b/cpp/mrc/src/public/coroutines/scheduler.cpp deleted file mode 100644 index da3b7b35d..000000000 --- a/cpp/mrc/src/public/coroutines/scheduler.cpp +++ /dev/null @@ -1,71 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mrc/coroutines/scheduler.hpp" - -#include - -#include -#include - -namespace mrc::coroutines { - -thread_local Scheduler* Scheduler::m_thread_local_scheduler{nullptr}; -thread_local std::size_t Scheduler::m_thread_id{0}; - -Scheduler::Operation::Operation(Scheduler& scheduler) : m_scheduler(scheduler) {} - -std::coroutine_handle<> Scheduler::Operation::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -{ - m_awaiting_coroutine = awaiting_coroutine; - return m_scheduler.schedule_operation(this); -} - -Scheduler::Scheduler() = default; - -auto Scheduler::schedule() -> Operation -{ - return Operation{*this}; -} - -auto Scheduler::yield() -> Operation -{ - return schedule(); -} - -auto Scheduler::from_current_thread() noexcept -> Scheduler* -{ - return m_thread_local_scheduler; -} - -auto Scheduler::get_thread_id() noexcept -> std::size_t -{ - if (m_thread_local_scheduler == nullptr) - { - return std::hash()(std::this_thread::get_id()); - } - return m_thread_id; -} - -auto Scheduler::on_thread_start(std::size_t thread_id) -> void -{ - DVLOG(10) << "scheduler: " << description() << " initializing"; - m_thread_id = thread_id; - m_thread_local_scheduler = this; -} - -} // namespace mrc::coroutines diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp index b0e486c83..03e7f115f 100644 --- a/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp +++ b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp @@ -19,6 +19,7 @@ #include "pymrc/coro.hpp" #include "pymrc/utilities/acquire_gil.hpp" +#include "pymrc/utilities/object_wrappers.hpp" #include #include @@ -28,52 +29,67 @@ #include #include +#include namespace py = pybind11; namespace mrc::pymrc { + +/** + * @brief A MRC Scheduler which allows resuming C++20 coroutines on an Asyncio event loop. + */ class AsyncioScheduler : public mrc::coroutines::Scheduler { - public: - AsyncioScheduler(PyObjectHolder loop) : m_loop(std::move(loop)) {} - - std::string description() const override + private: + class ContinueOnLoopOperation { - return "AsyncioScheduler"; - } + public: + ContinueOnLoopOperation(PyObjectHolder loop) : m_loop(std::move(loop)) {} - void resume(std::coroutine_handle<> coroutine) override - { - if (coroutine.done()) + static bool await_ready() noexcept { - LOG(WARNING) << "AsyncioScheduler::resume() > Attempted to resume a completed coroutine"; - return; + return false; } - py::gil_scoped_acquire gil; + void await_suspend(std::coroutine_handle<> handle) noexcept + { + AsyncioScheduler::resume(m_loop, handle); + } - // TODO(MDD): Check whether or not we need thread safe version - m_loop.attr("call_soon_threadsafe")(py::cpp_function([this, handle = std::move(coroutine)]() { - if (handle.done()) - { - LOG(WARNING) << "AsyncioScheduler::resume() > Attempted to resume a completed coroutine"; - return; - } + static void await_resume() noexcept {} - py::gil_scoped_release nogil; + private: + PyObjectHolder m_loop; + }; + static void resume(PyObjectHolder loop, std::coroutine_handle<> handle) noexcept + { + pybind11::gil_scoped_acquire acquire; + loop.attr("call_soon_threadsafe")(pybind11::cpp_function([handle]() { // + pybind11::gil_scoped_release release; handle.resume(); })); } - private: - std::coroutine_handle<> schedule_operation(Operation* operation) override + public: + AsyncioScheduler(PyObjectHolder loop) : m_loop(std::move(loop)) {} + + void resume(std::coroutine_handle<> handle) noexcept override + { + AsyncioScheduler::resume(m_loop, handle); + } + + [[nodiscard]] coroutines::Task<> schedule() override { - this->resume(std::move(operation->m_awaiting_coroutine)); + co_await ContinueOnLoopOperation(m_loop); + } - return std::noop_coroutine(); + [[nodiscard]] coroutines::Task<> yield() override + { + co_await ContinueOnLoopOperation(m_loop); } + private: mrc::pymrc::PyHolder m_loop; };