Skip to content

Commit

Permalink
Add await_read_until method to IEdgeReadable
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Dec 19, 2024
1 parent 5c2d482 commit bf9bcb5
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 5 deletions.
7 changes: 2 additions & 5 deletions cpp/mrc/include/mrc/edge/edge_readable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ class IEdgeReadable : public virtual Edge<T>, public IEdgeReadableBase
return EdgeTypeInfo::create<T>();
}

virtual channel::Status await_read(T& t) = 0;
virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp)
{
throw std::runtime_error("Not implemented");
};
virtual channel::Status await_read(T& t) = 0;
virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) = 0;
};

template <typename InputT, typename OutputT = InputT>
Expand Down
5 changes: 5 additions & 0 deletions cpp/mrc/include/mrc/node/source_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ class ForwardingReadableProvider : public ReadableProvider<T>
return m_parent.get_next(t);
}

channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override
{
throw std::runtime_error("Not implemented");
}

private:
ForwardingReadableProvider<T>& m_parent;
};
Expand Down
6 changes: 6 additions & 0 deletions cpp/mrc/tests/node/test_nodes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ class EdgeReadableLambda : public edge::IEdgeReadable<T>
return m_on_await_read(t);
}

channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override
{
throw std::runtime_error("Not implemented");
return channel::Status::error;
}

private:
std::function<channel::Status(T&)> m_on_await_read;
std::function<void()> m_on_complete;
Expand Down
60 changes: 60 additions & 0 deletions python/mrc/_pymrc/include/pymrc/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ class ConvertingEdgeReadable<

return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

if (status == channel::Status::success)
{
// We need to hold the GIL here, because casting from c++ -> pybind11::object allocates memory with
// Py_Malloc.
// Its also important to note that you do not want to hold the GIL when calling m_output->await_write, as
// that can trigger a deadlock with another fiber reading from the end of the channel
pymrc::AcquireGIL gil;

data = pybind11::cast(std::move(source_data));
}

return status;
}
};

template <typename OutputT>
Expand Down Expand Up @@ -224,6 +243,21 @@ struct ConvertingEdgeReadable<
return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

if (status == channel::Status::success)
{
pymrc::AcquireGIL gil;

data = pybind11::cast<output_t>(pybind11::object(std::move(source_data)));
}

return status;
}

static void register_converter()
{
EdgeConnector<input_t, output_t>::register_converter();
Expand All @@ -249,6 +283,19 @@ struct ConvertingEdgeReadable<pymrc::PyObjectHolder, pybind11::object, void>

return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

if (status == channel::Status::success)
{
data = std::move(source_data);
}

return status;
}
};

template <>
Expand All @@ -271,6 +318,19 @@ struct ConvertingEdgeReadable<pybind11::object, pymrc::PyObjectHolder, void>
return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

if (status == channel::Status::success)
{
data = pymrc::PyObjectHolder(std::move(source_data));
}

return status;
}

static void register_converter()
{
EdgeConnector<input_t, output_t>::register_converter();
Expand Down

0 comments on commit bf9bcb5

Please sign in to comment.