diff --git a/docs/book/user-guide/advanced-guide/pipelining-features/configure-steps-pipelines.md b/docs/book/user-guide/advanced-guide/pipelining-features/configure-steps-pipelines.md index 439ce16ab4..375b599cc2 100644 --- a/docs/book/user-guide/advanced-guide/pipelining-features/configure-steps-pipelines.md +++ b/docs/book/user-guide/advanced-guide/pipelining-features/configure-steps-pipelines.md @@ -132,6 +132,7 @@ python run.py An example of a generated YAML configuration template ```yaml +stack: Optional[str] build: Union[PipelineBuildBase, UUID, NoneType] enable_artifact_metadata: Optional[bool] enable_artifact_visualization: Optional[bool] @@ -330,6 +331,12 @@ These are boolean flags for various configurations: * `enable_cache`: Utilize [caching](../../starter-guide/cache-previous-executions.md) or not. * `enable_step_logs`: Enable tracking [step logs](managing-steps.md#enable-or-disable-logs-storing). +### `active_stack` name or ID + +The name or the UUID of the `active stack` to use for this +pipeline. If specified, the active stack is set for the duration of the pipeline execution and restored upon +completion. If not specified, the current active stack is used. + ### `build` ID The UUID of the [`build`](../infrastructure-management/containerize-your-pipeline.md) to use for this pipeline. If specified, Docker image building is skipped for remote orchestrators, and the Docker image specified in this build is used. diff --git a/src/zenml/cli/pipeline.py b/src/zenml/cli/pipeline.py index f667e8c4f0..4e9e34b887 100644 --- a/src/zenml/cli/pipeline.py +++ b/src/zenml/cli/pipeline.py @@ -34,6 +34,7 @@ ScheduleFilter, ) from zenml.new.pipelines.pipeline import Pipeline +from zenml.stack.utils import temporary_active_stack from zenml.utils import source_utils, uuid_utils from zenml.utils.yaml_utils import write_yaml @@ -184,7 +185,7 @@ def build_pipeline( name_id_or_prefix=pipeline_name_or_id, version=version ) - with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + with temporary_active_stack(stack_name_or_id=stack_name_or_id): pipeline_instance = Pipeline.from_model(pipeline_model) build = pipeline_instance.build(config_path=config_path) @@ -286,7 +287,7 @@ def run_pipeline( "or file path." ) - with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + with temporary_active_stack(stack_name_or_id=stack_name_or_id): pipeline_instance = Pipeline.from_model(pipeline_model) pipeline_instance = pipeline_instance.with_options( config_path=config_path, diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 29fab25bd0..bcefd8182d 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Utility functions for the CLI.""" -import contextlib import datetime import json import os @@ -26,7 +25,6 @@ Any, Callable, Dict, - Iterator, List, NoReturn, Optional, @@ -78,8 +76,6 @@ from zenml.zen_server.deploy import ServerDeployment if TYPE_CHECKING: - from uuid import UUID - from rich.text import Text from zenml.enums import ExecutionStatus @@ -2489,33 +2485,6 @@ def wrapper(function: F) -> F: return inner_decorator -@contextlib.contextmanager -def temporary_active_stack( - stack_name_or_id: Union["UUID", str, None] = None, -) -> Iterator["Stack"]: - """Contextmanager to temporarily activate a stack. - - Args: - stack_name_or_id: The name or ID of the stack to activate. If not given, - this contextmanager will not do anything. - - Yields: - The active stack. - """ - from zenml.client import Client - - try: - if stack_name_or_id: - old_stack_id = Client().active_stack_model.id - Client().activate_stack(stack_name_or_id) - else: - old_stack_id = None - yield Client().active_stack - finally: - if old_stack_id: - Client().activate_stack(old_stack_id) - - def get_package_information( package_names: Optional[List[str]] = None, ) -> Dict[str, str]: diff --git a/src/zenml/config/pipeline_run_configuration.py b/src/zenml/config/pipeline_run_configuration.py index 7eca7dead7..fa929bdec7 100644 --- a/src/zenml/config/pipeline_run_configuration.py +++ b/src/zenml/config/pipeline_run_configuration.py @@ -30,6 +30,7 @@ class PipelineRunConfiguration( ): """Class for pipeline run configurations.""" + stack: Optional[str] = None run_name: Optional[str] = None enable_cache: Optional[bool] = None enable_artifact_metadata: Optional[bool] = None diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 5b0358f93d..7ff767327f 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -76,6 +76,7 @@ prepare_model_versions, ) from zenml.stack import Stack +from zenml.stack.utils import temporary_active_stack from zenml.steps import BaseStep from zenml.steps.entrypoint_function_utils import ( StepArtifact, @@ -537,7 +538,9 @@ def build( Returns: The build output. """ - with track_handler(event=AnalyticsEvent.BUILD_PIPELINE): + with track_handler( + event=AnalyticsEvent.BUILD_PIPELINE + ), temporary_active_stack(): self._prepare_if_possible() deployment, pipeline_spec, _, _ = self._compile( config_path=config_path, @@ -620,7 +623,9 @@ def _run( logger.info(f"Initiating a new run for the pipeline: `{self.name}`.") - with track_handler(AnalyticsEvent.RUN_PIPELINE) as analytics_handler: + with track_handler( + AnalyticsEvent.RUN_PIPELINE + ) as analytics_handler, temporary_active_stack(): deployment, pipeline_spec, schedule, build = self._compile( config_path=config_path, run_name=run_name, @@ -1015,6 +1020,8 @@ def _compile( # Update with the values in code so they take precedence run_config = pydantic_utils.update_model(run_config, update=update) + self._update_stack_from_config(run_config) + deployment, pipeline_spec = Compiler().compile( pipeline=self, stack=Client().active_stack, @@ -1438,3 +1445,14 @@ def _prepare_if_possible(self) -> None: ) else: self.prepare() + + def _update_stack_from_config( + self, run_configuration: PipelineRunConfiguration + ) -> None: + """Activate the stack from the pipeline run configuration if one is given. + + Args: + run_configuration: The run configuration for this pipeline. + """ + if run_configuration.stack is not None: + Client().activate_stack(run_configuration.stack) diff --git a/src/zenml/stack/utils.py b/src/zenml/stack/utils.py index fa1cc7da77..6f41a64aca 100644 --- a/src/zenml/stack/utils.py +++ b/src/zenml/stack/utils.py @@ -13,13 +13,16 @@ # permissions and limitations under the License. """Util functions for handling stacks, components, and flavors.""" -from typing import Any, Dict, Optional +import contextlib +from typing import Any, Dict, Generator, Optional, Union +from uuid import UUID from zenml.client import Client from zenml.enums import StackComponentType, StoreType from zenml.logger import get_logger from zenml.models import FlavorFilter, FlavorResponse from zenml.stack.flavor import Flavor +from zenml.stack.stack import Stack from zenml.stack.stack_component import StackComponentConfig from zenml.zen_stores.base_zen_store import BaseZenStore @@ -139,3 +142,28 @@ def get_flavor_by_name_and_type_from_zen_store( f"'{component_type}' exists." ) return flavors[0] + + +@contextlib.contextmanager +def temporary_active_stack( + stack_name_or_id: Union[UUID, str, None] = None, +) -> Generator[Stack, Any, Any]: + """Contextmanager to temporarily activate a stack. + + Args: + stack_name_or_id: The name or ID of the stack to activate. If not given, + this contextmanager will not do anything. + + Yields: + The active stack. + """ + try: + if stack_name_or_id: + old_stack_id = Client().active_stack_model.id + Client().activate_stack(stack_name_or_id) + else: + old_stack_id = None + yield Client().active_stack + finally: + if old_stack_id: + Client().activate_stack(old_stack_id) diff --git a/tests/integration/functional/cli/test_pipeline.py b/tests/integration/functional/cli/test_pipeline.py index 76aab5720c..ad507284ac 100644 --- a/tests/integration/functional/cli/test_pipeline.py +++ b/tests/integration/functional/cli/test_pipeline.py @@ -368,6 +368,39 @@ def test_pipeline_run_with_config_file(clean_client: "Client", tmp_path): assert runs[0].name == "custom_run_name" +def test_pipeline_run_with_different_stack_in_config_file( + clean_client: "Client", tmp_path +): + """Tests that the run command works with a run config file with an active stack defined.""" + runner = CliRunner() + run_command = cli.commands["pipeline"].commands["run"] + + pipeline_id = pipeline_instance.register().id + + components = { + key: components[0].id + for key, components in Client().active_stack_model.components.items() + } + new_stack = Client().create_stack(name="new", components=components) + + config_path = tmp_path / "config.yaml" + run_config = PipelineRunConfiguration( + run_name="custom_run_name", stack=str(new_stack.id) + ) + config_path.write_text(run_config.yaml()) + + result = runner.invoke( + run_command, [pipeline_instance.name, "--config", str(config_path)] + ) + assert result.exit_code == 0 + + runs = Client().list_pipeline_runs(pipeline_id=pipeline_id) + assert len(runs) == 1 + assert runs[0].name == "custom_run_name" + assert runs[0].stack.id == new_stack.id + assert Client().active_stack.id != new_stack.id + + def test_pipeline_run_with_different_stack(clean_client: "Client"): """Tests that the run command works with a different stack.""" runner = CliRunner() diff --git a/tests/integration/functional/cli/utils.py b/tests/integration/functional/cli/utils.py index 228f6149b6..d4634ce395 100644 --- a/tests/integration/functional/cli/utils.py +++ b/tests/integration/functional/cli/utils.py @@ -16,10 +16,7 @@ from tests.harness.harness import TestHarness from zenml.cli import cli -from zenml.cli.utils import ( - parse_name_and_extra_arguments, - temporary_active_stack, -) +from zenml.cli.utils import parse_name_and_extra_arguments from zenml.client import Client from zenml.models import ( TagFilter, @@ -27,6 +24,7 @@ UserResponse, WorkspaceResponse, ) +from zenml.stack.utils import temporary_active_stack from zenml.utils.string_utils import random_str SAMPLE_CUSTOM_ARGUMENTS = [