Skip to content

Commit

Permalink
add asyncgenerator failure test to asyncio_runnable
Browse files Browse the repository at this point in the history
  • Loading branch information
cwharris committed Nov 2, 2023
1 parent 4b1999d commit 1591e3d
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 5 deletions.
15 changes: 12 additions & 3 deletions python/mrc/_pymrc/include/pymrc/coro.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,22 @@ class PYBIND11_EXPORT PyTaskToCppAwaitable
PyTaskToCppAwaitable(mrc::pymrc::PyObjectHolder&& task) : m_task(std::move(task))
{
pybind11::gil_scoped_acquire acquire;
if (pybind11::module_::import("inspect").attr("iscoroutine")(m_task).cast<bool>())

auto asyncio = pybind11::module_::import("asyncio");
auto inspect = pybind11::module_::import("inspect");

if (not asyncio.attr("isfuture")(m_task).cast<bool>())
{
m_task = pybind11::module_::import("asyncio").attr("create_task")(m_task);
if (not asyncio.attr("iscoroutine")(m_task).cast<bool>())
{
throw std::runtime_error(MRC_CONCAT_STR("PyTaskToCppAwaitable expected task or coroutine but got " << pybind11::repr(m_task).cast<std::string>()));
}

m_task = asyncio.attr("create_task")(m_task);
}
}

static bool await_ready() noexcept // NOLINT(readability-convert-member-functions-to-static)
static bool await_ready() noexcept
{
// Always suspend
return false;
Expand Down
59 changes: 57 additions & 2 deletions python/mrc/_pymrc/tests/test_asyncio_runnable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,62 @@ TEST_F(TestAsyncioRunnable, UseAsyncioTasks)
EXPECT_EQ(counter, 60);
}

TEST_F(TestAsyncioRunnable, UseAsyncioGeneratorThrows)
{
// pybind11::module_::import("mrc.core.coro");

py::object globals = py::globals();
py::exec(
R"(
async def fn(value):
yield value
)",
globals);

pymrc::PyObjectHolder fn = static_cast<py::object>(globals["fn"]);

ASSERT_FALSE(fn.is_none());

std::atomic<unsigned int> counter = 0;
pymrc::Pipeline p;

auto init = [&counter, &fn](mrc::segment::IBuilder& seg) {
auto src = seg.make_source<int>("src", [](rxcpp::subscriber<int>& s) {
if (s.is_subscribed())
{
s.on_next(5);
s.on_next(10);
}

s.on_completed();
});

auto internal = seg.construct_object<PythonCallbackAsyncioRunnable>("internal", fn);

auto sink = seg.make_sink<int>("sink", [&counter](int x) {
counter.fetch_add(x, std::memory_order_relaxed);
});

seg.make_edge(src, internal);
seg.make_edge(internal, sink);
};

p.make_segment("seg1"s, init);
p.make_segment("seg2"s, init);

auto options = std::make_shared<mrc::Options>();
options->topology().user_cpuset("0");
// AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific.
options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread);

pymrc::Executor exec{options};
exec.register_pipeline(p);

exec.start();

ASSERT_THROW(exec.join(), std::runtime_error);
}

TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows)
{
// pybind11::module_::import("mrc.core.coro");
Expand Down Expand Up @@ -220,9 +276,8 @@ TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows)
exec.register_pipeline(p);

exec.start();
exec.join();

EXPECT_EQ(counter, 60);
ASSERT_THROW(exec.join(), std::runtime_error);
}

template <typename OperationT>
Expand Down

0 comments on commit 1591e3d

Please sign in to comment.