Skip to content

Commit

Permalink
Fix bug where AsyncioRunnable hangs if process_one throws and the…
Browse files Browse the repository at this point in the history
… source is not emitting new values (#523)

* Fixes a bug first observed in [NVIDIA-AI-Blueprints/vulnerability-analysis](https://github.com/NVIDIA-AI-Blueprints/vulnerability-analysis) and reported in nv-morpheus/Morpheus#2086
* `AsyncioRunnable` will now call `on_state_update(state_t::Kill)` when an exception is caught
* Replace the blocking call to `await_read` with `await_read_until`, allowing `AsyncioRunnable` to periodically check `stop_source.stop_requested()`
* Define a new `await_read_until` method in `IEdgeReadable`. Unfortunately, this interface has numerous subclasses, all of which then needed their own `await_read_until` method, even though `EdgeChannelReader` is the only class that really needed it. Alternatives:
  - In `AsyncSink` perform a static cast of `this->get_readable_edge()` to `EdgeChannelReader`
  - Define an `await_read_until` method in `IEdgeReadable`, but give it a default implementation that throws a not-implemented exception (or asserts false)

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Will Killian (https://github.com/willkill07)

URL: #523
  • Loading branch information
dagardner-nv authored Jan 13, 2025
1 parent 7d5e48f commit aaf402a
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 10 deletions.
6 changes: 6 additions & 0 deletions cpp/mrc/include/mrc/edge/edge_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include "mrc/channel/types.hpp" // for time_point_t
#include "mrc/edge/edge_readable.hpp"
#include "mrc/edge/edge_writable.hpp"
#include "mrc/edge/forward.hpp"
Expand Down Expand Up @@ -45,6 +46,11 @@ class EdgeChannelReader : public IEdgeReadable<T>
return m_channel->await_read(t);
}

/**
 * @brief Timed read from the wrapped channel.
 *
 * Forwards both arguments unchanged to `m_channel->await_read_until` and hands back whatever status the channel
 * reports.
 *
 * @param t Destination for the value read from the channel
 * @param tp Time point passed through to the channel's `await_read_until`
 * @return The status returned by the channel
 */
virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp)
{
    channel::Status read_status = m_channel->await_read_until(t, tp);
    return read_status;
}

private:
EdgeChannelReader(std::shared_ptr<mrc::channel::Channel<T>> channel) : m_channel(std::move(channel)) {}

Expand Down
32 changes: 31 additions & 1 deletion cpp/mrc/include/mrc/edge/edge_readable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mrc/channel/channel.hpp"
#include "mrc/channel/egress.hpp"
#include "mrc/channel/ingress.hpp"
#include "mrc/channel/types.hpp" // for time_point_t
#include "mrc/edge/edge.hpp"
#include "mrc/exceptions/runtime_error.hpp"
#include "mrc/node/forward.hpp"
Expand Down Expand Up @@ -61,7 +62,8 @@ class IEdgeReadable : public virtual Edge<T>, public IEdgeReadableBase
return EdgeTypeInfo::create<T>();
}

virtual channel::Status await_read(T& t) = 0;
virtual channel::Status await_read(T& t) = 0;
virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) = 0;
};

template <typename InputT, typename OutputT = InputT>
Expand Down Expand Up @@ -110,6 +112,20 @@ class ConvertingEdgeReadable<InputT, OutputT, std::enable_if_t<std::is_convertib

return ret_val;
}

/**
 * @brief Timed read that pulls an `InputT` from upstream and converts it to `OutputT`.
 *
 * @param data Receives the converted value; only written when the upstream read reports success
 * @param tp Time point forwarded to the upstream edge's `await_read_until`
 * @return The status returned by the upstream read
 */
channel::Status await_read_until(OutputT& data, const mrc::channel::time_point_t& tp) override
{
    InputT upstream_value;

    const auto read_status = this->upstream().await_read_until(upstream_value, tp);
    if (read_status != channel::Status::success)
    {
        // Nothing usable was read; leave `data` untouched and report the status as-is
        return read_status;
    }

    // Convert to the sink type
    data = std::move(upstream_value);
    return read_status;
}
};

template <typename InputT, typename OutputT>
Expand Down Expand Up @@ -137,6 +153,20 @@ class LambdaConvertingEdgeReadable : public ConvertingEdgeReadableBase<InputT, O
return ret_val;
}

/**
 * @brief Timed read that pulls an `input_t` from upstream and converts it with the stored lambda.
 *
 * @param data Receives `m_lambda_fn(value)`; only written when the upstream read reports success
 * @param tp Time point forwarded to the upstream edge's `await_read_until`
 * @return The status returned by the upstream read
 */
channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
    input_t upstream_value;

    const auto read_status = this->upstream().await_read_until(upstream_value, tp);
    if (read_status != channel::Status::success)
    {
        // The conversion lambda is never invoked for a failed read
        return read_status;
    }

    // Convert to the sink type
    data = m_lambda_fn(std::move(upstream_value));
    return read_status;
}

private:
lambda_fn_t m_lambda_fn{};
};
Expand Down
5 changes: 4 additions & 1 deletion cpp/mrc/include/mrc/node/sink_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ class NullReadableEdge : public edge::IEdgeReadable<T>
// Reading from a null edge is always an error: this unconditionally throws std::runtime_error with a message
// asking the user to make sure every sink has an edge. It never returns a status, and `t` is not touched.
channel::Status await_read(T& t) override
{
throw std::runtime_error("Attempting to read from a null edge. Ensure an edge was established for all sinks.");
}

return channel::Status::error;
// Timed variant of the read on a null edge. Unconditionally throws std::runtime_error with the "null edge" message;
// neither `t` nor the time point `tp` is used, and no status is ever returned.
channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override
{
throw std::runtime_error("Attempting to read from a null edge. Ensure an edge was established for all sinks.");
}
};

Expand Down
5 changes: 5 additions & 0 deletions cpp/mrc/include/mrc/node/source_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ class ForwardingReadableProvider : public ReadableProvider<T>
return m_parent.get_next(t);
}

// Timed reads are not supported by this provider: the method exists only to satisfy the interface and always throws
// std::runtime_error("Not implemented"). Neither `t` nor `tp` is used.
// NOTE(review): the untimed read forwards to m_parent.get_next(t); no deadline-aware equivalent on the parent is
// visible from here -- confirm before relying on this path.
channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override
{
throw std::runtime_error("Not implemented");
}

private:
ForwardingReadableProvider<T>& m_parent;
};
Expand Down
5 changes: 5 additions & 0 deletions cpp/mrc/tests/node/test_nodes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ class EdgeReadableLambda : public edge::IEdgeReadable<T>
return m_on_await_read(t);
}

// Test helper stub: the stored read callback (m_on_await_read) only accepts a `T&`, so there is no way to honor a
// deadline here. Always throws std::runtime_error("Not implemented"); `t` and `tp` are unused.
channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override
{
throw std::runtime_error("Not implemented");
}

private:
std::function<channel::Status(T&)> m_on_await_read;
std::function<void()> m_on_complete;
Expand Down
26 changes: 19 additions & 7 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <mrc/coroutines/closable_ring_buffer.hpp>
#include <mrc/coroutines/task.hpp>
#include <mrc/coroutines/task_container.hpp>
#include <mrc/edge/edge_channel.hpp> // for EdgeChannelReader
#include <mrc/exceptions/exception_catcher.hpp>
#include <mrc/node/sink_properties.hpp>
#include <mrc/runnable/forward.hpp>
Expand Down Expand Up @@ -118,8 +119,17 @@ class AsyncSink : public mrc::node::WritableProvider<T>,
{
protected:
AsyncSink() :
m_read_async([this](T& value) {
return this->get_readable_edge()->await_read(value);
m_read_async([this](T& value, std::stop_source& stop_source) {
using namespace std::chrono_literals;
auto edge = this->get_readable_edge();
channel::Status status = channel::Status::timeout;
while ((status == channel::Status::timeout || status == channel::Status::empty) &&
not stop_source.stop_requested())
{
status = edge->await_read_until(value, std::chrono::system_clock::now() + 10ms);
}

return status;
})
{
// Set the default channel
Expand All @@ -129,13 +139,13 @@ class AsyncSink : public mrc::node::WritableProvider<T>,
/**
* @brief Asynchronously reads a value from the sink's channel
*/
coroutines::Task<mrc::channel::Status> read_async(T& value)
coroutines::Task<mrc::channel::Status> read_async(T& value, std::stop_source& stop_source)
{
co_return co_await m_read_async(std::ref(value));
co_return co_await m_read_async(std::ref(value), std::ref(stop_source));
}

private:
BoostFutureAwaitableOperation<mrc::channel::Status(T&)> m_read_async;
BoostFutureAwaitableOperation<mrc::channel::Status(T&, std::stop_source&)> m_read_async;
};

/**
Expand Down Expand Up @@ -297,8 +307,8 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m
{
InputT data;

auto read_status = co_await this->read_async(data);

mrc::channel::Status read_status = mrc::channel::Status::success;
read_status = co_await this->read_async(data, m_stop_source);
if (read_status != mrc::channel::Status::success)
{
break;
Expand All @@ -309,6 +319,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m

co_await outstanding_tasks.garbage_collect_and_yield_until_empty();

// this is a no-op if there are no exceptions
catcher.rethrow_next_exception();
}

Expand Down Expand Up @@ -339,6 +350,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value,
} catch (...)
{
catcher.push_exception(std::current_exception());
on_state_update(state_t::Kill);
}
}

Expand Down
60 changes: 60 additions & 0 deletions python/mrc/_pymrc/include/pymrc/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ class ConvertingEdgeReadable<

return ret_val;
}

// Timed read: pulls an `input_t` from upstream and, only on success, converts it to the Python-side `output_t`
// with pybind11::cast. Returns the upstream status unchanged; `data` is left untouched on any non-success status.
channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
input_t source_data;
// The upstream read happens before the GIL is acquired below
auto status = this->upstream().await_read_until(source_data, tp);

if (status == channel::Status::success)
{
// We need to hold the GIL here, because casting from C++ -> pybind11::object allocates memory with
// Py_Malloc.
// It's also important to note that you do not want to hold the GIL when calling m_output->await_write, as
// that can trigger a deadlock with another fiber reading from the end of the channel.
// NOTE(review): this method reads from upstream and never calls m_output->await_write; the sentence above
// looks carried over from the writable counterpart. Here the GIL is scoped to this block only, i.e. it is
// taken after the upstream read has returned and released before the status is returned.
pymrc::AcquireGIL gil;

data = pybind11::cast(std::move(source_data));
}

return status;
}
};

template <typename OutputT>
Expand Down Expand Up @@ -224,6 +243,21 @@ struct ConvertingEdgeReadable<
return ret_val;
}

// Timed read: pulls an `input_t` from upstream and, only on success, wraps it in a pybind11::object and casts that
// object to `output_t`. Returns the upstream status unchanged; `data` is left untouched on any non-success status.
channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
input_t source_data;
// The upstream read happens before the GIL is acquired below
auto status = this->upstream().await_read_until(source_data, tp);

if (status == channel::Status::success)
{
// Hold the GIL only for the duration of the pybind11 conversion; it is released at the end of this block
pymrc::AcquireGIL gil;

data = pybind11::cast<output_t>(pybind11::object(std::move(source_data)));
}

return status;
}

static void register_converter()
{
EdgeConnector<input_t, output_t>::register_converter();
Expand All @@ -249,6 +283,19 @@ struct ConvertingEdgeReadable<pymrc::PyObjectHolder, pybind11::object, void>

return ret_val;
}

/**
 * @brief Timed read that pulls an `input_t` from upstream and move-assigns it into `data`.
 *
 * No GIL is acquired explicitly in this method.
 * NOTE(review): presumably the `input_t` -> `output_t` move conversion takes care of any GIL requirements --
 * confirm against the holder type.
 *
 * @param data Receives the value; only written when the upstream read reports success
 * @param tp Time point forwarded to the upstream edge's `await_read_until`
 * @return The status returned by the upstream read
 */
channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
    input_t holder;

    const auto read_status = this->upstream().await_read_until(holder, tp);
    if (read_status != channel::Status::success)
    {
        // Leave `data` untouched when nothing was read
        return read_status;
    }

    data = std::move(holder);
    return read_status;
}
};

template <>
Expand All @@ -271,6 +318,19 @@ struct ConvertingEdgeReadable<pybind11::object, pymrc::PyObjectHolder, void>
return ret_val;
}

/**
 * @brief Timed read that pulls an `input_t` from upstream and wraps it in a `pymrc::PyObjectHolder`.
 *
 * @param data Receives the newly constructed holder; only written when the upstream read reports success
 * @param tp Time point forwarded to the upstream edge's `await_read_until`
 * @return The status returned by the upstream read
 */
channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
    input_t py_value;

    const auto read_status = this->upstream().await_read_until(py_value, tp);
    if (read_status != channel::Status::success)
    {
        // Leave `data` untouched when nothing was read
        return read_status;
    }

    data = pymrc::PyObjectHolder(std::move(py_value));
    return read_status;
}

static void register_converter()
{
EdgeConnector<input_t, output_t>::register_converter();
Expand Down
66 changes: 65 additions & 1 deletion python/mrc/_pymrc/tests/test_asyncio_runnable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <atomic>
#include <chrono>
#include <coroutine>
#include <cstddef> // for size_t
#include <functional>
#include <memory>
#include <stdexcept>
Expand All @@ -60,6 +61,7 @@ class Scheduler;

namespace py = pybind11;
namespace pymrc = mrc::pymrc;
using namespace std::chrono_literals;
using namespace std::string_literals;
using namespace py::literals;

Expand Down Expand Up @@ -102,6 +104,11 @@ class __attribute__((visibility("default"))) PythonCallbackAsyncioRunnable : pub
result = co_await pymrc::coro::PyTaskToCppAwaitable(std::move(coroutine));
}

if (result.is_none())
{
co_return;
}

auto result_casted = py::cast<int>(result);

py::gil_scoped_release release;
Expand Down Expand Up @@ -316,7 +323,6 @@ auto run_operation(OperationT& operation) -> mrc::coroutines::Task<int>
TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanReturn)
{
auto operation = mrc::pymrc::BoostFutureAwaitableOperation<int()>([]() {
using namespace std::chrono_literals;
boost::this_fiber::sleep_for(10ms);
return 5;
});
Expand All @@ -333,3 +339,61 @@ TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanThrow)

ASSERT_THROW(mrc::coroutines::sync_wait(run_operation(operation)), std::runtime_error);
}

// Regression test: a sink whose async Python callback raises must bring the pipeline down and surface the error
// from Executor::join(), even though the source neither emits further values nor completes on its own.
TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows2086)
{
// Reproduces Morpheus issue #2086 where an exception is thrown in Async Python code, and the source does not emit
// any additional values. When the source emits an additional value or calls on_completed, the pipeline completes
// and the exception is thrown to the caller.
pymrc::Pipeline p;

// Define the Python coroutine used by the sink: it raises RuntimeError("oops") when value == 1 and otherwise
// falls off the end without returning a value.
py::object globals = py::globals();
py::exec(
R"(
async def fn(value):
print(f"Sink received value={value}")
if value == 1:
print("Sink raising exception", flush=True)
raise RuntimeError("oops")
)",
globals);

pymrc::PyObjectHolder fn = static_cast<py::object>(globals["fn"]);

auto init = [&fn](mrc::segment::IBuilder& seg) {
auto src = seg.make_source<int>("src", [](rxcpp::subscriber<int>& s) {
std::size_t i = 0;
// Keep running for as long as the subscriber is subscribed; the loop itself never ends the stream
while (s.is_subscribed())
{
// Emit only the first two values (0 and 1); after that the source stays alive but silent
if (i < 2)
{
s.on_next(i);
}

// Sleep the current fiber for 10ms between iterations
boost::this_fiber::sleep_for(10ms);

++i;
}

// Reached only once the subscriber is no longer subscribed
s.on_completed();
});

auto sink = seg.construct_object<PythonCallbackAsyncioRunnable>("sink", fn);

seg.make_edge(src, sink);
};

p.make_segment("seg1"s, init);

auto options = std::make_shared<mrc::Options>();

// AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific.
options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread);

pymrc::Executor exec{options};
exec.register_pipeline(p);

exec.start();

// join() is expected to throw rather than block on the idle source.
// NOTE(review): presumably the Python RuntimeError reaches the caller as a std::runtime_error-derived exception
// via pybind11 -- confirm for the pybind11 version in use.
ASSERT_THROW(exec.join(), std::runtime_error);
}

0 comments on commit aaf402a

Please sign in to comment.