Skip to content

Commit

Permalink
Define a Python source which receives a reference to a subscriber (#496)
Browse files Browse the repository at this point in the history
* Allows a Python generator source to check if the subscriber is still subscribed.
* Define a class `SubscriberFuncWrapper` for Python sources rather than just a lambda. The reason is that Python objects captured by the lambda must be destroyed while the GIL is held, which causes a problem if the lambda is destroyed unexpectedly.
* Update `conftest.py` to set the loglevel to `DEBUG` if the `GLOG_v` environment variable is defined.

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #496
  • Loading branch information
dagardner-nv authored Sep 11, 2024
1 parent ca8a73f commit 8489b45
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 8 deletions.
9 changes: 4 additions & 5 deletions python/mrc/_pymrc/include/pymrc/segment.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -143,10 +143,9 @@ class BuilderProxy
const std::string& name,
pybind11::function gen_factory);

static std::shared_ptr<mrc::segment::ObjectProperties> make_source(
mrc::segment::IBuilder& self,
const std::string& name,
const std::function<void(pymrc::PyObjectSubscriber& sub)>& f);
/**
 * @brief Creates a Python source whose generator factory is called with the subscriber, so the Python code can
 * check whether the subscriber is still subscribed.
 * @param self Segment builder used to construct the source object
 * @param name Name given to the constructed source object
 * @param gen_factory Python callable that is invoked with the subscriber; its result is iterated for values
 * @return The constructed source object
 */
static std::shared_ptr<mrc::segment::ObjectProperties> make_source_subscriber(mrc::segment::IBuilder& self,
const std::string& name,
pybind11::function gen_factory);

static std::shared_ptr<mrc::segment::ObjectProperties> make_source_component(mrc::segment::IBuilder& self,
const std::string& name,
Expand Down
62 changes: 62 additions & 0 deletions python/mrc/_pymrc/src/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -257,6 +258,60 @@ std::shared_ptr<mrc::segment::ObjectProperties> build_source(mrc::segment::IBuil
return self.construct_object<PythonSource<PyHolder>>(name, wrapper);
}

/**
 * @brief Python source whose generator factory is called with the subscriber it emits to.
 *
 * The Python callable is stored as a member (`m_gen_factory`) rather than being captured by the subscriber
 * function. The subscriber function only keeps a raw pointer back to this object and reads the member when the
 * source runs, so the object must stay at the same address for as long as the function may be invoked.
 */
class SubscriberFuncWrapper : public mrc::pymrc::PythonSource<PyHolder>
{
  public:
    using base_t = mrc::pymrc::PythonSource<PyHolder>;
    using typename base_t::source_type_t;
    using typename base_t::subscriber_fn_t;

    /**
     * @brief Constructs the source from a Python generator factory.
     * @param gen_factory Python callable invoked with the subscriber (cast to a Python object); the returned
     * iterator supplies the values that are emitted.
     */
    explicit SubscriberFuncWrapper(py::function gen_factory) :
      // `build` is static and only stores the pointer: calling a non-static member function here, before the
      // base class has been initialized, would be undefined behavior.
      PythonSource(build(this)),
      m_gen_factory{std::move(gen_factory)}
    {}

  private:
    /**
     * @brief Creates the subscriber function handed to the base class.
     * @param self Object under construction; only dereferenced when the returned function is invoked.
     * @return Function that drives the Python iterator and forwards each value to the subscriber.
     */
    static subscriber_fn_t build(SubscriberFuncWrapper* self)
    {
        return [self](rxcpp::subscriber<source_type_t> subscriber) {
            auto& ctx = runnable::Context::get_runtime_context();

            try
            {
                DVLOG(10) << ctx.info() << " Starting source";

                // The GIL guard is declared before every Python object in this scope, so those objects are
                // destroyed (in reverse declaration order) before the guard is.
                py::gil_scoped_acquire gil;
                py::object py_sub = py::cast(subscriber);
                auto py_iter      = self->m_gen_factory.operator()<py::iterator>(std::move(py_sub));
                PyIteratorWrapper iter_wrapper{std::move(py_iter)};

                for (auto next_val : iter_wrapper)
                {
                    // Only send if its subscribed. Very important to ensure the object has been moved!
                    if (subscriber.is_subscribed())
                    {
                        // Drop the GIL while the value is pushed downstream
                        py::gil_scoped_release no_gil;
                        subscriber.on_next(std::move(next_val));
                    }
                    else
                    {
                        DVLOG(10) << ctx.info() << " Source unsubscribed. Stopping";
                        break;
                    }
                }

            } catch (const std::exception& e)
            {
                LOG(ERROR) << ctx.info() << "Error occurred in source. Error msg: " << e.what();

                // The GIL guard from the try block is already out of scope here
                subscriber.on_error(std::current_exception());
                return;
            }

            subscriber.on_completed();

            DVLOG(10) << ctx.info() << " Source complete";
        };
    }

    // Python callable that produces the iterator; invoked once each time the subscriber function runs
    PyFuncWrapper m_gen_factory{};
};

std::shared_ptr<mrc::segment::ObjectProperties> build_source_component(mrc::segment::IBuilder& self,
const std::string& name,
PyIteratorWrapper iter_wrapper)
Expand Down Expand Up @@ -308,6 +363,13 @@ std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source(mrc::s
return build_source(self, name, PyIteratorWrapper(std::move(gen_factory)));
}

/**
 * @brief Builds a `SubscriberFuncWrapper` source named `name` from the supplied Python generator factory.
 * @param self Segment builder that constructs the object
 * @param name Name given to the source
 * @param gen_factory Python callable forwarded to the `SubscriberFuncWrapper` constructor
 * @return The constructed source object
 */
std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source_subscriber(mrc::segment::IBuilder& self,
                                                                                     const std::string& name,
                                                                                     py::function gen_factory)
{
    std::shared_ptr<mrc::segment::ObjectProperties> source_obj =
        self.construct_object<SubscriberFuncWrapper>(name, std::move(gen_factory));

    return source_obj;
}

std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source_component(mrc::segment::IBuilder& self,
const std::string& name,
pybind11::iterator source_iterator)
Expand Down
8 changes: 7 additions & 1 deletion python/mrc/core/segment.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -134,6 +134,12 @@ PYBIND11_MODULE(segment, py_mod)
const std::string&,
py::function)>(&BuilderProxy::make_source));

// Expose `make_source_subscriber` to Python: takes a builder, a name and a Python callable. The explicit
// function-pointer cast pins the exact signature being bound.
Builder.def("make_source_subscriber",
static_cast<std::shared_ptr<mrc::segment::ObjectProperties> (*)(mrc::segment::IBuilder&,
const std::string&,
py::function)>(
&BuilderProxy::make_source_subscriber));

Builder.def("make_source_component",
static_cast<std::shared_ptr<mrc::segment::ObjectProperties> (*)(mrc::segment::IBuilder&,
const std::string&,
Expand Down
6 changes: 5 additions & 1 deletion python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import typing

import pytest
Expand Down Expand Up @@ -50,6 +51,9 @@ def configure_tests_logging(is_debugger_attached: bool):
if (is_debugger_attached):
log_level = logging.INFO

if (os.environ.get('GLOG_v') is not None):
log_level = logging.DEBUG

mrc_logging.init_logging("mrc_testing", py_level=log_level)


Expand Down
98 changes: 97 additions & 1 deletion python/tests/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,6 +14,8 @@
# limitations under the License.

import asyncio
import os
import time
import typing

import pytest
Expand All @@ -30,6 +32,53 @@ def pairwise(t):
node_fn_type = typing.Callable[[mrc.Builder], mrc.SegmentObject]


@pytest.fixture
def source():
    """Return a builder callback that creates a finite source emitting 1, 2 and 3."""

    def build(builder: mrc.Builder):

        def gen_data():
            # Finite stream: three values, then the generator is exhausted
            yield from (1, 2, 3)

        return builder.make_source("source", gen_data)

    return build


@pytest.fixture
def endless_source():
    """Return a builder callback that creates a source which never finishes on its own."""

    def build(builder: mrc.Builder):

        def gen_data():
            # Counts upward forever, pausing 0.1s each time the generator is resumed after a value
            counter = 0
            while True:
                yield counter
                counter += 1
                time.sleep(0.1)

        # The generator object itself (not the function) is handed to the builder
        data_iter = gen_data()
        return builder.make_source("endless_source", data_iter)

    return build


@pytest.fixture
def blocking_source():
    """Return a builder callback that creates a source which emits once, then idles until unsubscribed."""

    def build(builder: mrc.Builder):

        def gen_data(subscriber: mrc.Subscriber):
            yield 1

            # Emit nothing further; poll the subscription every 0.1s and finish once it is gone
            while True:
                if not subscriber.is_subscribed():
                    return
                time.sleep(0.1)

        return builder.make_source_subscriber("blocking_source", gen_data)

    return build


@pytest.fixture
def source_pyexception():

Expand Down Expand Up @@ -64,6 +113,21 @@ def gen_data_and_raise():
return build


@pytest.fixture
def node_exception():
    """Return a builder callback that creates a map node which raises on every value it receives."""

    def build(builder: mrc.Builder):

        def on_next(data):
            # Wait a second before failing, then report the value that was received
            time.sleep(1)
            message = "Received value: {}".format(data)
            print(message, flush=True)
            raise RuntimeError("unittest")

        failing_map = mrc.core.operators.map(on_next)
        return builder.make_node("node", failing_map)

    return build


@pytest.fixture
def sink():

Expand Down Expand Up @@ -112,6 +176,8 @@ def build_executor():
def inner(pipe: mrc.Pipeline):
options = mrc.Options()

options.topology.user_cpuset = f"0-{os.cpu_count() - 1}"
options.engine_factories.default_engine_type = mrc.core.options.EngineType.Thread
executor = mrc.Executor(options)
executor.register_pipeline(pipe)

Expand Down Expand Up @@ -183,5 +249,35 @@ async def run_pipeline():
asyncio.run(run_pipeline())


@pytest.mark.parametrize("source_name", ["source", "endless_source", "blocking_source"])
def test_pyexception_in_node(source: node_fn_type,
                             endless_source: node_fn_type,
                             blocking_source: node_fn_type,
                             node_exception: node_fn_type,
                             build_pipeline: build_pipeline_type,
                             build_executor: build_executor_type,
                             source_name: str):
    """
    Test to reproduce Morpheus issue #1838 where an exception raised in a node doesn't always shut down the executor
    when the source is intended to run indefinitely.

    The test is run once per source fixture; in every case joining the executor is expected to raise the
    `RuntimeError` thrown by the node.
    """

    # Map each parametrized name onto its fixture; an unknown name fails loudly instead of silently using `source`
    sources_by_name = {
        "source": source,
        "endless_source": endless_source,
        "blocking_source": blocking_source,
    }
    source_fn = sources_by_name[source_name]

    pipe = build_pipeline(source_fn, node_exception)

    executor: mrc.Executor = build_executor(pipe)

    with pytest.raises(RuntimeError):
        executor.join()


# Allow running this file directly as a script
if __name__ == "__main__":
    test_pyexception_in_source()

0 comments on commit 8489b45

Please sign in to comment.