Skip to content

Commit

Permalink
0.55.2 update
Browse files Browse the repository at this point in the history
  • Loading branch information
avishniakov committed Feb 8, 2024
1 parent 39d7da7 commit f51345c
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 12 deletions.
8 changes: 5 additions & 3 deletions stack-showcase/pipelines/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
inference_preprocessor,
)

from zenml import get_pipeline_context, pipeline, ExternalArtifact
from zenml import get_pipeline_context, pipeline
from zenml.logger import get_logger

logger = get_logger(__name__)
Expand All @@ -41,10 +41,12 @@ def inference(random_state: str, target: str):
target: Name of target column in dataset.
"""
# Get the production model artifact
model = ExternalArtifact(name="sklearn_classifier")
model = get_pipeline_context().model.get_artifact("sklearn_classifier")

# Get the preprocess pipeline artifact associated with this version
preprocess_pipeline = ExternalArtifact(name="preprocess_pipeline")
preprocess_pipeline = get_pipeline_context().model.get_artifact(
"preprocess_pipeline"
)

# Link all the steps together by calling them and passing the output
# of one step as the input of the next step.
Expand Down
2 changes: 1 addition & 1 deletion stack-showcase/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
zenml[server]>=0.50.0
zenml[server]>=0.55.2
notebook
scikit-learn<1.3
s3fs>2022.3.0,<=2023.4.0
Expand Down
40 changes: 34 additions & 6 deletions stack-showcase/run.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"source": [
"# pre-demo recommendations\n",
"! zenml connect --url https://1cf18d95-zenml.cloudinfra.zenml.io \n",
"! zenml delete model breast_cancer_classifier -y"
"! zenml model delete breast_cancer_classifier -y"
]
},
{
Expand Down Expand Up @@ -549,8 +549,6 @@
"metadata": {},
"outputs": [],
"source": [
"from zenml import ExternalArtifact\n",
"\n",
"@pipeline\n",
"def training(\n",
" train_dataset_id: Optional[UUID] = None,\n",
Expand All @@ -565,8 +563,8 @@
" dataset_trn, dataset_tst = feature_engineering()\n",
" else:\n",
" # Load the datasets from an older pipeline\n",
" dataset_trn = ExternalArtifact(id=train_dataset_id)\n",
" dataset_tst = ExternalArtifact(id=test_dataset_id)\n",
" dataset_trn = client.get_artifact_version(name_id_or_prefix=train_dataset_id)\n",
" dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id)\n",
"\n",
" trained_model = model_trainer(\n",
" dataset_trn=dataset_trn,\n",
Expand Down Expand Up @@ -919,7 +917,7 @@
" df_inference = inference_preprocessor(\n",
" dataset_inf=df_inference,\n",
" # We use the preprocess pipeline from the feature engineering pipeline\n",
" preprocess_pipeline=ExternalArtifact(id=preprocess_pipeline_id),\n",
" preprocess_pipeline=client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id),\n",
" target=target,\n",
" )\n",
" inference_predict(\n",
Expand Down Expand Up @@ -1213,6 +1211,36 @@
"print(f'SGD version: {sgd_model_version.run_metadata[\"wandb_url\"].value}')\n",
"print(f'RF version: {rf_model_version.run_metadata[\"wandb_url\"].value}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With Model Control Plane we can also easily track lineage of artifacts and pipeline runs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for artifact_name, versions in sgd_model_version.data_artifacts.items():\n",
" if versions:\n",
" print(f\"Existing version of `{artifact_name}`:\")\n",
" for version_name, artifact_ in versions.items():\n",
" print(version_name, artifact_.data_type.attribute)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for run_name, run_ in sgd_model_version.pipeline_runs.items():\n",
" print(run_name, run_.id)"
]
}
],
"metadata": {
Expand Down
11 changes: 9 additions & 2 deletions stack-showcase/steps/model_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
from sklearn.base import ClassifierMixin
from sklearn.metrics import confusion_matrix

from zenml import log_artifact_metadata, step, log_model_metadata
from zenml import log_artifact_metadata, step, log_model_metadata, get_step_context
from zenml.logger import get_logger
import wandb
from zenml.client import Client
from zenml.exceptions import StepContextError


logger = get_logger(__name__)
Expand Down Expand Up @@ -109,7 +110,13 @@ def model_evaluator(
.ravel()
.tolist(),
}
log_model_metadata(metadata={"wandb_url": wandb.run.url})
try:
if get_step_context().model:
log_model_metadata(metadata={"wandb_url": wandb.run.url})
except StepContextError:
# if model not configured not able to log metadata
pass

log_artifact_metadata(
metadata=metadata,
artifact_name="sklearn_classifier",
Expand Down

0 comments on commit f51345c

Please sign in to comment.