From 48dad3d86685d61c56a145fc322435107b642344 Mon Sep 17 00:00:00 2001 From: Christopher Harris Date: Wed, 1 Nov 2023 04:07:02 +0000 Subject: [PATCH] simplify asyncio_runnable --- .../_pymrc/include/pymrc/asyncio_runnable.hpp | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp index fdb01e05d..401e7df01 100644 --- a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp +++ b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp @@ -143,7 +143,10 @@ class CoroutineRunnableSink : public mrc::node::WritableProvider, public mrc::node::SinkChannelOwner { protected: - CoroutineRunnableSink() + CoroutineRunnableSink() : + m_reader([this](T& value) { + return this->get_readable_edge()->await_read(value); + }) { // Set the default channel this->set_channel(std::make_unique>()); @@ -151,16 +154,12 @@ class CoroutineRunnableSink : public mrc::node::WritableProvider, auto build_readable_generator(std::stop_token stop_token) -> mrc::coroutines::AsyncGenerator { - auto read_awaiter = BoostFutureReader([this](T& value) { - return this->get_readable_edge()->await_read(value); - }); - while (!stop_token.stop_requested()) { T value; // Pull a message off of the upstream channel - auto status = co_await read_awaiter.async_read(std::ref(value)); + auto status = co_await m_reader.async_read(std::ref(value)); if (status != mrc::channel::Status::success) { @@ -172,6 +171,9 @@ class CoroutineRunnableSink : public mrc::node::WritableProvider, co_return; } + + private: + BoostFutureReader m_reader; }; template @@ -184,25 +186,19 @@ class CoroutineRunnableSource : public mrc::node::WritableAcceptor, { // Set the default channel this->set_channel(std::make_unique>()); - } - - // auto build_readable_generator(std::stop_token stop_token) - // -> mrc::coroutines::AsyncGenerator - // { - // while (!stop_token.stop_requested()) - // { - // co_yield mrc::coroutines::detail::VoidValue{}; - // } - - // co_return; - // } - auto build_writable_receiver() -> std::shared_ptr> - { - return std::make_shared>([this](T&& value) { + m_writer = std::make_shared>([this](T&& value) { return this->get_writable_edge()->await_write(std::move(value)); }); } + + auto get_writable_receiver() -> std::shared_ptr> + { + return m_writer; + } + + private: + std::shared_ptr> m_writer; }; template @@ -257,7 +253,7 @@ coroutines::Task<> AsyncioRunnable::main_task(std::shared_ptr::build_readable_generator(m_stop_source.get_token()); - auto output_receiver = CoroutineRunnableSource::build_writable_receiver(); + auto output_receiver = CoroutineRunnableSource::get_writable_receiver(); // Create the task buffer to limit the number of running tasks task_buffer_t task_buffer{{.capacity = m_concurrency}};