From 1591e3d82e21b9484c9da702356ae9d0c33bc449 Mon Sep 17 00:00:00 2001 From: Christopher Harris Date: Thu, 2 Nov 2023 17:30:10 +0000 Subject: [PATCH] add asyncgenerator failure test to asyncio_runnable --- python/mrc/_pymrc/include/pymrc/coro.hpp | 15 ++++- .../_pymrc/tests/test_asyncio_runnable.cpp | 59 ++++++++++++++++++- 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/python/mrc/_pymrc/include/pymrc/coro.hpp b/python/mrc/_pymrc/include/pymrc/coro.hpp index 5c80398cc..fa9b865da 100644 --- a/python/mrc/_pymrc/include/pymrc/coro.hpp +++ b/python/mrc/_pymrc/include/pymrc/coro.hpp @@ -174,13 +174,22 @@ class PYBIND11_EXPORT PyTaskToCppAwaitable PyTaskToCppAwaitable(mrc::pymrc::PyObjectHolder&& task) : m_task(std::move(task)) { pybind11::gil_scoped_acquire acquire; - if (pybind11::module_::import("inspect").attr("iscoroutine")(m_task).cast()) + + auto asyncio = pybind11::module_::import("asyncio"); + auto inspect = pybind11::module_::import("inspect"); + + if (not asyncio.attr("isfuture")(m_task).cast()) { - m_task = pybind11::module_::import("asyncio").attr("create_task")(m_task); + if (not asyncio.attr("iscoroutine")(m_task).cast()) + { + throw std::runtime_error(MRC_CONCAT_STR("PyTaskToCppAwaitable expected task or coroutine but got " << pybind11::repr(m_task).cast())); + } + + m_task = asyncio.attr("create_task")(m_task); } } - static bool await_ready() noexcept // NOLINT(readability-convert-member-functions-to-static) + static bool await_ready() noexcept { // Always suspend return false; diff --git a/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp index 9b3d17d21..e71cc5243 100644 --- a/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp +++ b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp @@ -168,6 +168,62 @@ TEST_F(TestAsyncioRunnable, UseAsyncioTasks) EXPECT_EQ(counter, 60); } +TEST_F(TestAsyncioRunnable, UseAsyncioGeneratorThrows) +{ + // pybind11::module_::import("mrc.core.coro"); + + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + yield value + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. + options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + + ASSERT_THROW(exec.join(), std::runtime_error); +} + TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows) { // pybind11::module_::import("mrc.core.coro"); @@ -220,9 +276,8 @@ TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows) exec.register_pipeline(p); exec.start(); - exec.join(); - EXPECT_EQ(counter, 60); + ASSERT_THROW(exec.join(), std::runtime_error); } template