From 831925a601f8b46a13308ec2ffd0de8d852730ee Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 16 Sep 2024 11:40:55 -0700 Subject: [PATCH] Inspect the type-hint of the first parameter to determine if the source expects a Subscription object --- python/mrc/_pymrc/src/segment.cpp | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/mrc/_pymrc/src/segment.cpp b/python/mrc/_pymrc/src/segment.cpp index 98d16c2d0..00dfacd9f 100644 --- a/python/mrc/_pymrc/src/segment.cpp +++ b/python/mrc/_pymrc/src/segment.cpp @@ -364,18 +364,30 @@ std::shared_ptr BuilderProxy::make_source(mrc::s // Determine if the gen_factory is expecting to receive a subscription object auto inspect_mod = py::module::import("inspect"); auto signature = inspect_mod.attr("signature")(gen_factory); - auto num_params = py::len(signature.attr("parameters")); + auto params = signature.attr("parameters"); + auto num_params = py::len(params); bool expects_subscription = false; - if (num_params == 1) + if (num_params > 0) { - return self.construct_object(name, std::move(gen_factory)); + // We know there is at least one parameter. Check if the first parameter is a subscription object + // Note, when we receive a function that has been bound with `functools.partial(fn, arg1=some_value)`, the + // parameter is still visible in the signature of the partial object. + auto mrc_mod = py::module::import("mrc"); + auto param_values = params.attr("values")(); + auto first_param = py::iter(param_values); + auto type_hint = py::object((*first_param).attr("annotation")); + expects_subscription = (type_hint.is(mrc_mod.attr("Subscription")) || + type_hint.equal(py::str("mrc.Subscription")) || + type_hint.equal(py::str("Subscription"))); } - if (num_params == 0) + if (expects_subscription) { - return build_source(self, name, PyIteratorWrapper(std::move(gen_factory))); + return self.construct_object(name, std::move(gen_factory)); } + + return build_source(self, name, PyIteratorWrapper(std::move(gen_factory))); } std::shared_ptr BuilderProxy::make_source_component(mrc::segment::IBuilder& self,