Skip to content

Commit

Permalink
Inspect the type-hint of the first parameter to determine if the sour…
Browse files Browse the repository at this point in the history
…ce expects a Subscription object
  • Loading branch information
dagardner-nv committed Sep 16, 2024
1 parent dda4739 commit 831925a
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions python/mrc/_pymrc/src/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,30 @@ std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source(mrc::s
// Determine if the gen_factory is expecting to receive a subscription object
auto inspect_mod = py::module::import("inspect");
auto signature = inspect_mod.attr("signature")(gen_factory);
auto num_params = py::len(signature.attr("parameters"));
auto params = signature.attr("parameters");
auto num_params = py::len(params);
bool expects_subscription = false;

if (num_params == 1)
if (num_params > 0)
{
return self.construct_object<SubscriberFuncWrapper>(name, std::move(gen_factory));
// We know there is at least one parameter. Check if the first parameter is a subscription object
// Note, when we receive a function that has been bound with `functools.partial(fn, arg1=some_value)`, the
// parameter is still visible in the signature of the partial object.
auto mrc_mod = py::module::import("mrc");
auto param_values = params.attr("values")();
auto first_param = py::iter(param_values);
auto type_hint = py::object((*first_param).attr("annotation"));
expects_subscription = (type_hint.is(mrc_mod.attr("Subscription")) ||
type_hint.equal(py::str("mrc.Subscription")) ||
type_hint.equal(py::str("Subscription")));
}

if (num_params == 0)
if (expects_subscription)
{
return build_source(self, name, PyIteratorWrapper(std::move(gen_factory)));
return self.construct_object<SubscriberFuncWrapper>(name, std::move(gen_factory));
}

return build_source(self, name, PyIteratorWrapper(std::move(gen_factory)));
}

std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source_component(mrc::segment::IBuilder& self,
Expand Down

0 comments on commit 831925a

Please sign in to comment.