diff --git a/.typos.toml b/.typos.toml index fba9255c02..08d29ada65 100644 --- a/.typos.toml +++ b/.typos.toml @@ -32,6 +32,7 @@ prev = "prev" creat = "creat" ret = "ret" daa = "daa" +cll = "cll" [default] locale = "en-us" diff --git a/docs/book/stacks-and-components/component-guide/model-registries/mlflow.md b/docs/book/stacks-and-components/component-guide/model-registries/mlflow.md index 0ee2ae9135..71346066df 100644 --- a/docs/book/stacks-and-components/component-guide/model-registries/mlflow.md +++ b/docs/book/stacks-and-components/component-guide/model-registries/mlflow.md @@ -159,7 +159,7 @@ zenml model-registry models register-version Tensorflow-model \ ### Deploy a registered model -Afte you have registered a model in the MLflow model registry, you can also +After you have registered a model in the MLflow model registry, you can also easily deploy it as a prediction service. Checkout the [MLflow model deployer documentation](../model-deployers/mlflow.md#deploy-from-model-registry) for more information on how to do that. diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index 413a27bbff..77a3d2a6f2 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -37,6 +37,7 @@ from zenml.logger import get_logger from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model import Model +from zenml.models.v2.misc.scaler_models import ScalerModel from zenml.utils import deprecation_utils if TYPE_CHECKING: @@ -137,6 +138,7 @@ class StepConfigurationUpdate(StrictBaseModel): failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None model: Optional[Model] = None + scaler: Optional[ScalerModel] = None outputs: Mapping[str, PartialArtifactConfiguration] = {} diff --git a/src/zenml/enums.py b/src/zenml/enums.py index 7dcad19898..fd53b9665c 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -384,3 +384,13 @@ class PluginSubType(StrEnum): WEBHOOK = "webhook" # Action Subtypes PIPELINE_RUN = "pipeline_run" + + +class AggregateFunction(StrEnum): + """All possible aggregation functions.""" + + COUNT = "count" + SUM = "sum" + MEAN = "mean" + MIN = "min" + MAX = "max" diff --git a/src/zenml/integrations/accelerate/__init__.py b/src/zenml/integrations/accelerate/__init__.py new file mode 100644 index 0000000000..b0daab7c84 --- /dev/null +++ b/src/zenml/integrations/accelerate/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Accelerate integration for ZenML.""" + +from zenml.integrations.accelerate.scalers.accelerate_scaler import AccelerateScaler + +__all__ = [ + "AccelerateScaler", +] \ No newline at end of file diff --git a/src/zenml/integrations/accelerate/scalers/__init__.py b/src/zenml/integrations/accelerate/scalers/__init__.py new file mode 100644 index 0000000000..6139deb965 --- /dev/null +++ b/src/zenml/integrations/accelerate/scalers/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Accelerate scalers for ZenML.""" diff --git a/src/zenml/integrations/accelerate/scalers/accelerate_scaler.py b/src/zenml/integrations/accelerate/scalers/accelerate_scaler.py new file mode 100644 index 0000000000..2c2f91579e --- /dev/null +++ b/src/zenml/integrations/accelerate/scalers/accelerate_scaler.py @@ -0,0 +1,131 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utility function to run Accelerate jobs.""" + +import subprocess +from typing import Any, Callable, Optional, TypeVar + +import cloudpickle as pickle + +from zenml.logger import get_logger +from zenml.models.v2.misc.scaler_models import ScalerModel +from zenml.utils.function_utils import _cli_arg_name, create_cli_wrapped_script + +logger = get_logger(__name__) +F = TypeVar("F", bound=Callable[..., None]) + + +class AccelerateScaler(ScalerModel): + """Accelerate scaler model. + + Accelerate package: https://huggingface.co/docs/accelerate/en/index + + Example: + ```python + from zenml import step + from zenml.integrations.accelerate import AccelerateScaler + + @step(scaler=AccelerateScaler(num_processes=42)) + def training_step(some_param: int, ...): + # your training code is below + ... + ``` + + Args: + num_processes: The number of processes to use (shall be less or equal to GPUs count). + """ + + num_processes: Optional[int] = None + + def run(self, step_function: F, **function_kwargs: Any) -> Any: + """Run a function with accelerate. + + Accelerate package: https://huggingface.co/docs/accelerate/en/index + + Example: + ```python + from zenml import step + from zenml.integrations.accelerate import AccelerateScaler + + @step(scaler=AccelerateScaler(num_processes=42)) + def training_step(some_param: int, ...): + # your training code is below + ... + ``` + + Args: + step_function: The function to run. + **function_kwargs: The keyword arguments to pass to the function. + + Returns: + The return value of the function in the main process. + + Raises: + CalledProcessError: If the function fails. + """ + import torch + + logger.info("Starting accelerate job...") + + device_count = torch.cuda.device_count() + if self.num_processes is None: + num_processes = device_count + else: + if self.num_processes > device_count: + logger.warning( + f"Number of processes ({self.num_processes}) is greater than " + f"the number of available GPUs ({device_count}). Using all GPUs." + ) + num_processes = device_count + num_processes = self.num_processes + + with create_cli_wrapped_script( + step_function, flavour="accelerate" + ) as ( + script_path, + output_path, + ): + command = f"accelerate launch --num_processes {num_processes} " + command += str(script_path.absolute()) + " " + for k, v in function_kwargs.items(): + k = _cli_arg_name(k) + if isinstance(v, bool): + if v: + command += f"--{k} " + elif isinstance(v, str): + command += f'--{k} "{v}" ' + elif type(v) in (list, tuple, set): + for each in v: + command += f"--{k} {each} " + else: + command += f"--{k} {v} " + + logger.info(command) + + result = subprocess.run( + command, + shell=True, + stdout=subprocess.PIPE, + universal_newlines=True, + ) + for stdout_line in result.stdout.split("\n"): + logger.info(stdout_line) + if result.returncode == 0: + logger.info("Accelerate training job finished.") + return pickle.load(open(output_path, "rb")) + else: + logger.error( + f"Accelerate training job failed. With return code {result.returncode}." + ) + raise subprocess.CalledProcessError(result.returncode, command) diff --git a/src/zenml/models/v2/misc/scaler_models.py b/src/zenml/models/v2/misc/scaler_models.py new file mode 100644 index 0000000000..59ab0ae1fa --- /dev/null +++ b/src/zenml/models/v2/misc/scaler_models.py @@ -0,0 +1,84 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Model definitions for ZenML scalers.""" + +from typing import Any, Callable, ClassVar, Dict, Optional, Set, TypeVar + +from pydantic import BaseModel, root_validator + +F = TypeVar("F", bound=Callable[..., None]) + + +class ScalerModel(BaseModel): + """Domain model for scalers.""" + + scaler_flavor: Optional[str] = None + + ALLOWED_SCALER_FLAVORS: ClassVar[Set[str]] = { + "AggregateScaler", + "AccelerateScaler", + } + + @root_validator(pre=True) + def validate_scaler_flavor(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate the scaler flavor. + + Args: + values: The values to validate. + + Returns: + The validated values. + + Raises: + ValueError: If the scaler flavor is not supported. + """ + if values.get("scaler_flavor", None) is None: + values["scaler_flavor"] = cls.__name__ # type: ignore[attr-defined] + if values["scaler_flavor"] not in cls.ALLOWED_SCALER_FLAVORS: + raise ValueError( + f"Invalid scaler flavor {values['scaler_flavor']}. " + f"Allowed values are {cls.ALLOWED_SCALER_FLAVORS}" + ) + return values + + def run(self, step_function: F, **kwargs: Any) -> Any: + """Run the step using scaler. + + Args: + step_function: The step function to run. + **kwargs: Additional arguments to pass to the step function. + + Returns: + The result of the step function as per scaler config. + + Raises: + NotImplementedError: If the scaler flavor is not supported. + """ + if self.scaler_flavor == "AccelerateScaler": + from zenml.integrations.accelerate import AccelerateScaler + + runner = AccelerateScaler(**self.dict()) + elif self.scaler_flavor == "AggregateScaler": + from zenml.scalers import AggregateScaler + + runner = AggregateScaler(**self.dict()) # type: ignore[assignment] + else: + raise NotImplementedError + + return runner.run(step_function, **kwargs) + + class Config: + """Pydantic model configuration.""" + + extra = "allow" diff --git a/src/zenml/new/steps/step_decorator.py b/src/zenml/new/steps/step_decorator.py index ba306f3140..d11d03dac9 100644 --- a/src/zenml/new/steps/step_decorator.py +++ b/src/zenml/new/steps/step_decorator.py @@ -36,6 +36,7 @@ from zenml.config.source import Source from zenml.materializers.base_materializer import BaseMaterializer from zenml.model.model import Model + from zenml.models.v2.misc.scaler_models import ScalerModel from zenml.steps import BaseStep MaterializerClassOrSource = Union[str, Source, Type[BaseMaterializer]] @@ -73,6 +74,7 @@ def step( on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, model_version: Optional["Model"] = None, # TODO: deprecate me + scaler: Optional["ScalerModel"] = None, ) -> Callable[["F"], "BaseStep"]: ... @@ -93,6 +95,7 @@ def step( on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, model_version: Optional["Model"] = None, # TODO: deprecate me + scaler: Optional["ScalerModel"] = None, ) -> Union["BaseStep", Callable[["F"], "BaseStep"]]: """Decorator to create a ZenML step. @@ -124,6 +127,7 @@ def step( (e.g. `module.my_function`). model: configuration of the model in the Model Control Plane. model_version: DEPRECATED, please use `model` instead. + scaler: configuration of the scaler for this step. Returns: The step instance. @@ -162,6 +166,7 @@ def inner_decorator(func: "F") -> "BaseStep": on_failure=on_failure, on_success=on_success, model=model or model_version, + scaler=scaler, ) return step_instance diff --git a/src/zenml/scalers/__init__.py b/src/zenml/scalers/__init__.py new file mode 100644 index 0000000000..6f627ef686 --- /dev/null +++ b/src/zenml/scalers/__init__.py @@ -0,0 +1,5 @@ +from zenml.scalers.aggregate_scaler import AggregateScaler + +__all__ = [ + "AggregateScaler", +] \ No newline at end of file diff --git a/src/zenml/scalers/aggregate_scaler.py b/src/zenml/scalers/aggregate_scaler.py new file mode 100644 index 0000000000..67e58bb6f0 --- /dev/null +++ b/src/zenml/scalers/aggregate_scaler.py @@ -0,0 +1,127 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Aggregate local scaler.""" + +import inspect +from multiprocessing.pool import ThreadPool +from typing import Any, Callable, Dict, List, TypeVar + +from pydantic import validator + +from zenml.enums import AggregateFunction +from zenml.logger import get_logger +from zenml.models.v2.misc.scaler_models import ScalerModel +from zenml.scalers.utils import AGGREGATE_FUNCTIONS + +logger = get_logger(__name__) +F = TypeVar("F", bound=Callable[..., None]) + + +class AggregateScaler(ScalerModel): + """Aggregate scaler. + + Example: + ```python + from zenml import step, pipeline + from zenml.scalers import AggregateScaler + + @step(scaler=AggregateScaler(parameters={"a":[1,2,3],"b":[4,5,6]}, agg_function="sum")) + def training_step_with_sum_aggregation(a:int = None, b:int = None, c:int = 2)->int: + # your code is below + return a+b+c + + @pipeline + def pipeline_with_aggregate_scaler(): + training_step_with_sum_aggregation(c=3) + # actual step output would be (1+4+3)+(2+5+3)+(3+6+3) = 30, + # where last "+3" comes from constant `c` parameter + ``` + + Args: + parameters: The parameters to run on. + num_processes: The number of processes to use (shall be less or equal to GPUs count). + agg_function: The aggregation function to use. + """ + + parameters: Dict[str, List[Any]] + num_processes: int = 1 + agg_function: AggregateFunction = AggregateFunction.SUM + + @validator("parameters") + def validate_values( + cls, parameters: Dict[str, List[Any]] + ) -> Dict[str, List[Any]]: + """Validate the parameters. + + Args: + parameters: The parameters to run on. + + Returns: + The validated parameters. + + Raises: + ValueError: If the parameters are not of the same length. + """ + lengths = {} + first_length = None + lengths_are_different = False + for k, v in parameters.items(): + lengths[k] = len(v) + if first_length is None: + first_length = len(v) + elif len(v) != first_length: + lengths_are_different = True + if lengths_are_different: + raise ValueError( + f"Parameters are not of the same length: {lengths}" + ) + return parameters + + def run(self, step_function: F, **function_kwargs: Any) -> Any: + """Run a function with matrix strategy. + + Args: + step_function: The step function to run. + **function_kwargs: Additional arguments to pass to the step function. + + Returns: + The result of the step function. + + Raises: + ValueError: If the function arguments do not match the parameters. + """ + logger.info("Starting aggregate job...") + function_arg_names = inspect.getargs(step_function.__code__).args + given_arg_names = set(self.parameters.keys()).union( + set(function_kwargs.keys()) + ) + if set(function_arg_names) != given_arg_names: + raise ValueError( + f"Function arguments {function_arg_names} do not match parameters configured {given_arg_names}" + ) + params_to_pass: List[List[Any]] = [] + for i in range(len(self.parameters[function_arg_names[0]])): + params_to_pass.append([]) + for arg_name in function_arg_names: + if arg_name in self.parameters: + params_to_pass[i].append(self.parameters[arg_name][i]) + else: + params_to_pass[i].append(function_kwargs[arg_name]) + + result = ThreadPool(processes=self.num_processes).starmap( + step_function, + params_to_pass, + ) + + return AGGREGATE_FUNCTIONS[self.agg_function](result) diff --git a/src/zenml/scalers/utils.py b/src/zenml/scalers/utils.py new file mode 100644 index 0000000000..33b1eb6181 --- /dev/null +++ b/src/zenml/scalers/utils.py @@ -0,0 +1,24 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utility functions for the scalers.""" + +from typing import Any, Callable, Dict, List + +AGGREGATE_FUNCTIONS: Dict[str, Callable[[List[Any]], Any]] = { + "sum": sum, + "mean": lambda x: sum(x) / len(x), + "max": max, + "min": min, + "count": len, +} diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 1f464c1d21..03fa0c1c03 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -74,6 +74,7 @@ ) from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model import Model + from zenml.models.v2.misc.scaler_models import ScalerModel ParametersOrDict = Union["BaseParameters", Dict[str, Any]] MaterializerClassOrSource = Union[str, Source, Type["BaseMaterializer"]] @@ -140,6 +141,7 @@ def __init__( on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, + scaler: Optional["ScalerModel"] = None, **kwargs: Any, ) -> None: """Initializes a step. @@ -169,6 +171,7 @@ def __init__( be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). model: configuration of the model version in the Model Control Plane. + scaler: The scaler to use for this step. **kwargs: Keyword arguments passed to the step. """ from zenml.config.step_configurations import PartialStepConfiguration @@ -242,6 +245,7 @@ def __init__( on_failure=on_failure, on_success=on_success, model=model, + scaler=scaler, ) self._verify_and_apply_init_params(*args, **kwargs) @@ -647,8 +651,12 @@ def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any: ) except ValidationError as e: raise StepInterfaceError("Invalid entrypoint arguments.") from e - - return self.entrypoint(**validated_args) + if self.configuration.scaler: + return self.configuration.scaler.run( + self.entrypoint, **validated_args + ) + else: + return self.entrypoint(**validated_args) @property def name(self) -> str: @@ -695,6 +703,7 @@ def configure( on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, + scaler: Optional["ScalerModel"] = None, merge: bool = True, ) -> T: """Configures the step. @@ -733,6 +742,7 @@ def configure( be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). model: configuration of the model version in the Model Control Plane. + scaler: configuration of the scaler for this step. merge: If `True`, will merge the given dictionary configurations like `parameters` and `settings` with existing configurations. If `False` the given configurations will @@ -807,6 +817,7 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, "model": model, + "scaler": scaler, } ) config = StepConfigurationUpdate(**values) @@ -830,6 +841,7 @@ def with_options( on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, + scaler: Optional["ScalerModel"] = None, merge: bool = True, ) -> "BaseStep": """Copies the step and applies the given configurations. @@ -857,6 +869,7 @@ def with_options( be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). model: configuration of the model version in the Model Control Plane. + scaler: configuration of the scaler for this step. merge: If `True`, will merge the given dictionary configurations like `parameters` and `settings` with existing configurations. If `False` the given configurations will @@ -882,6 +895,7 @@ def with_options( on_success=on_success, model=model, merge=merge, + scaler=scaler, ) return step_copy diff --git a/src/zenml/utils/cuda_utils.py b/src/zenml/utils/cuda_utils.py new file mode 100644 index 0000000000..bd7f6abb89 --- /dev/null +++ b/src/zenml/utils/cuda_utils.py @@ -0,0 +1,54 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions to improve CUDA experience.""" + +import gc + +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +def cleanup_gpu_memory(force: bool = False) -> None: + """Clean up GPU memory. + + This will clean up all GPU memory on current physical machine. + This action is considered to be dangerous by default, since + it might affect other processes running in the same environment. + If this is intended, please, explicitly pass `force=True`. + + Args: + force: whether to force cleanup or not + """ + if not force: + logger.warning( + "This will clean up all GPU memory on current physical machine. " + "This action is considered to be dangerous by default, since " + "it might affect other processes running in the same environment. " + "If this is intended, please, explicitly pass `force=True`." + ) + else: + try: + import torch + except ModuleNotFoundError: + logger.warning( + "No PyTorch installed. Skipping GPU memory cleanup." + ) + return + + logger.info("Cleaning up GPU memory...") + while gc.collect(): + torch.cuda.empty_cache() diff --git a/src/zenml/utils/function_utils.py b/src/zenml/utils/function_utils.py new file mode 100644 index 0000000000..2cd0146fb5 --- /dev/null +++ b/src/zenml/utils/function_utils.py @@ -0,0 +1,221 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utility functions for python functions.""" + +import inspect +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, Iterator, List, Tuple, TypeVar, Union + +import click + +from zenml.logger import get_logger +from zenml.utils.string_utils import random_str + +F = TypeVar("F", bound=Callable[..., None]) + +logger = get_logger(__name__) + +_CLI_WRAPPED_SCRIPT_TEMPLATE_HEADER = """ +from zenml.utils.function_utils import _cli_wrapped_function + +import sys +sys.path.append("{func_path}") +from {func_module} import {func_name} as func_to_wrap + +if entrypoint:=getattr(func_to_wrap, "entrypoint", None): + func = _cli_wrapped_function(entrypoint) +else: + func = _cli_wrapped_function(func_to_wrap) + +""" +_CLI_WRAPPED_ACCELERATE_MAIN = """ +if __name__=="__main__": + from accelerate import Accelerator + import cloudpickle as pickle + + accelerator = Accelerator() + + ret = func(standalone_mode=False) + + if accelerator.is_main_process: + pickle.dump(ret, open("{output_file}", "wb")) +""" +_ALLOWED_TYPES = (str, int, float, bool, Path) +_ALLOWED_COLLECTIONS = (tuple,) +_TYPES_MAPPER = { + str: click.STRING, + int: click.INT, + float: click.FLOAT, + bool: click.BOOL, + Path: click.STRING, + None: click.STRING, +} + + +def _cli_arg_name(arg_name: str) -> str: + return arg_name.replace("_", "-") + + +def _is_valid_collection_arg(arg_type: Any) -> bool: + if getattr(arg_type, "__origin__", None) in _ALLOWED_COLLECTIONS: + if arg_type.__args__[0] not in _ALLOWED_TYPES: + return False + return True + return False + + +def _is_valid_optional_arg(arg_type: Any) -> bool: + if ( + getattr(arg_type, "_name", None) == "Optional" + and getattr(arg_type, "__origin__", None) == Union + ): + if args := getattr(arg_type, "__args__", None): + if len(args) != 2: + return False + if ( + args[0] not in _ALLOWED_TYPES + and not _is_valid_collection_arg(args[0]) + ) or args[1] != type(None): + return False + return True + return False + + +def _cli_wrapped_function(func: F) -> F: + """Create a decorator to generate the CLI-wrapped function. + + Args: + func: The function to decorate. + + Returns: + The inner decorator. + + Raises: + ValueError: If the function arguments are not valid. + """ + options: List[Any] = [] + fullargspec = inspect.getfullargspec(func) + if fullargspec.defaults is not None: + defaults = [None] * ( + len(fullargspec.args) - len(fullargspec.defaults) + ) + list(fullargspec.defaults) + else: + defaults = [None] * len(fullargspec.args) + input_args_dict = ( + ( + arg_name, + fullargspec.annotations.get(arg_name, None), + defaults[i], + ) + for i, arg_name in enumerate(fullargspec.args) + ) + invalid_types = {} + for arg_name, arg_type, arg_default in input_args_dict: + if _is_valid_optional_arg(arg_type): + arg_type = arg_type.__args__[0] + arg_name = _cli_arg_name(arg_name) + if arg_type == bool: + options.append( + click.option( + f"--{arg_name}", + type=click.BOOL, + is_flag=True, + default=False, + required=False, + ) + ) + elif _is_valid_collection_arg(arg_type): + member_type = arg_type.__args__[0] + options.append( + click.option( + f"--{arg_name}", + type=member_type, + default=arg_default, + required=False, + multiple=True, + ) + ) + elif arg_type in _ALLOWED_TYPES: + options.append( + click.option( + f"--{arg_name}", + type=_TYPES_MAPPER[arg_type], + default=arg_default, + required=False if arg_default is not None else True, + ) + ) + else: + invalid_types[arg_name] = arg_type + if invalid_types: + raise ValueError( + f"Invalid argument types: {invalid_types}. CLI functions only " + f"supports: {_ALLOWED_TYPES} types (including Optional) and " + f"{_ALLOWED_COLLECTIONS} collections." + ) + options.append( + click.command( + help="Technical wrapper to pass into the `accelerate launch` command." + ) + ) + + def wrapper(function: F) -> F: + for option in reversed(options): + function = option(function) + return function + + func.__doc__ = ( + f"{func.__doc__}\n\nThis is ZenML-generated " "CLI wrapper function." + ) + + return wrapper(func) + + +@contextmanager +def create_cli_wrapped_script( + func: F, flavour: str = "accelerate" +) -> Iterator[Tuple[Path, Path]]: + """Create a script with the CLI-wrapped function. + + Args: + func: The function to use. + flavour: The flavour to use. + + Yields: + The paths of the script and the output. + """ + try: + func_path = str(Path(inspect.getabsfile(func)).parent) + random_name = random_str(20) + script_path = Path(random_name + ".py") + output_path = Path(random_name + ".out") + + with open(script_path, "w") as f: + script = _CLI_WRAPPED_SCRIPT_TEMPLATE_HEADER.format( + func_path=func_path, + func_module=func.__module__, + func_name=func.__name__, + ) + if flavour == "accelerate": + script += _CLI_WRAPPED_ACCELERATE_MAIN.format( + output_file=str(output_path.absolute()) + ) + f.write(script) + + logger.info(f"Created script:\n\n{script}") + + yield script_path, output_path + finally: + script_path.unlink() + output_path.unlink()