diff --git a/python/kvikio/kvikio/cufile_driver.py b/python/kvikio/kvikio/cufile_driver.py index 166bb73304..8556a11fa2 100644 --- a/python/kvikio/kvikio/cufile_driver.py +++ b/python/kvikio/kvikio/cufile_driver.py @@ -1,6 +1,8 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. +import atexit + from kvikio._lib import cufile_driver # type: ignore # TODO: Wrap nicely, maybe as a dataclass? @@ -14,6 +16,9 @@ def driver_open() -> None: opens the driver, but every call should have a matching call to `driver_close`. + Normally, it is not required to open and close the cuFile driver since + it is done automatically. + Raises ------ RuntimeError @@ -35,3 +40,18 @@ def driver_close() -> None: If cuFile isn't available. """ return cufile_driver.driver_close() + + +def initialize() -> None: + """Open the cuFile driver and close it again at module exit + + Normally, it is not required to open and close the cuFile driver since + it is done automatically. + + Raises + ------ + RuntimeError + If cuFile isn't available. + """ + driver_open() + atexit.register(driver_close) diff --git a/python/kvikio/tests/test_cufile_driver.py b/python/kvikio/tests/test_cufile_driver.py index 42208dcc9f..0a64bf0952 100644 --- a/python/kvikio/tests/test_cufile_driver.py +++ b/python/kvikio/tests/test_cufile_driver.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. import pytest