Skip to content

Commit

Permalink
update load artifact and metadata functions
Browse files Browse the repository at this point in the history
  • Loading branch information
safoinme committed Nov 28, 2023
1 parent 532435c commit 4f04b29
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 13 deletions.
7 changes: 2 additions & 5 deletions template/steps/deploying/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


from zenml import get_step_context, step
from zenml.client import Client
from zenml.logger import get_logger

# Initialize logger
Expand All @@ -25,17 +24,15 @@ def save_model_to_deploy():
"""
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
pipeline_extra = get_step_context().pipeline_run.config.extra
zenml_client = Client()

logger.info(
f" Loading latest version of the model for stage {pipeline_extra['target_env']}..."
)
# Get latest saved model version in target environment
latest_version = get_step_context().model_version._get_model_version()

# Load model and tokenizer from Model Control Plane
model = latest_version.get_model_object(name="model").load()
tokenizer = latest_version.get_model_object(name="tokenizer").load()
model = latest_version.load_artifact(name="model")
tokenizer = latest_version.load_artifact(name="tokenizer")
# Save the model and tokenizer locally
model_path = "./gradio/model" # replace with the actual path
tokenizer_path = "./gradio/tokenizer" # replace with the actual path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def promote_get_metrics(
# Get current model version metric in current run
model_version = get_step_context().model_version
current_version = model_version._get_model_version()
current_metrics = current_version.get_model_object(name="model").metadata["metrics"].value
current_metrics = current_version.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Current model version metrics are {current_metrics}")

# Get latest saved model version metric in target environment
Expand All @@ -49,7 +49,7 @@ def promote_get_metrics(
except KeyError:
latest_version = None
if latest_version:
latest_metrics = current_version.get_model_object(name="model").metadata["metrics"].value
latest_metrics = current_version.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Current model version metrics are {latest_metrics}")
else:
logger.info("No currently promoted model version found.")
Expand Down
10 changes: 6 additions & 4 deletions template/steps/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
TrainingArguments,
AutoModelForSequenceClassification,
)
from zenml import log_artifact_metadata, step
from zenml import ArtifactConfig, log_artifact_metadata, step
from zenml.client import Client
from zenml.integrations.mlflow.experiment_trackers import MLFlowExperimentTracker
from zenml.logger import get_logger
from zenml.model import ModelArtifactConfig
from utils.misc import compute_metrics

# Initialize logger
Expand Down Expand Up @@ -47,7 +46,7 @@ def model_trainer(
eval_batch_size: Optional[int] = 16,
weight_decay: Optional[float] = 0.01,
mlflow_model_name: Optional[str] = "sentiment_analysis",
) -> Tuple[Annotated[PreTrainedModel, "model", ModelArtifactConfig(overwrite=True)], Annotated[PreTrainedTokenizerBase, "tokenizer", ModelArtifactConfig(overwrite=True)]]:
) -> Tuple[Annotated[PreTrainedModel, ArtifactConfig(name="model", is_model_artifact=True)], Annotated[PreTrainedTokenizerBase, ArtifactConfig(name="tokenizer", is_model_artifact=True)]]:
"""
Configure and train a model on the training dataset.
Expand Down Expand Up @@ -136,7 +135,10 @@ def model_trainer(
eval_results = trainer.evaluate(metric_key_prefix="")

# Log the evaluation results in model control plane
log_artifact_metadata(output_name="model", metrics=eval_results)
log_artifact_metadata(
metadata={"metrics": eval_results},
artifact_name="model",
)
### YOUR CODE ENDS HERE ###

return model, tokenizer
2 changes: 1 addition & 1 deletion template/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# {% include 'template/license_header' %}

from typing import Dict, Tuple, List
from typing import Dict, List, Tuple

import numpy as np
from datasets import load_metric
Expand Down
2 changes: 1 addition & 1 deletion tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def generate_and_run_project(
cloud_of_choice: str = "gcp",
dataset: str = "airline_reviews",
zenml_model_name: str = "sentiment_analysis",

):
"""Generate and run the starter project with different options."""

Expand Down Expand Up @@ -162,6 +161,7 @@ def test_latest_promotion(
tmp_path_factory=tmp_path_factory, metric_compare_promotion=False
)


def test_production_environment(
clean_zenml_client,
tmp_path_factory: pytest.TempPathFactory,
Expand Down

0 comments on commit 4f04b29

Please sign in to comment.