Skip to content

Commit

Permalink
update mlflow_model name
Browse files Browse the repository at this point in the history
  • Loading branch information
safoinme committed Nov 15, 2023
1 parent 54b3e92 commit cda3816
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion template/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ settings:
- zenml[server]

extra:
mlflow_model_name: nlp_use_case_model
mlflow_model_name: sentiment_analysis
{%- if target_environment == 'production' %}
target_env: production
{%- else %}
Expand Down
2 changes: 1 addition & 1 deletion template/pipelines/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def {{product_name}}_training_pipeline(
register_model(
model=model,
tokenizer=tokenizer,
mlflow_model_name="{{product_name}}_model",
mlflow_model_name="sentiment_analysis",
)

notify_on_success(after=["register_model"])
Expand Down
2 changes: 1 addition & 1 deletion template/steps/registrer/model_log_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
def register_model(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
mlflow_model_name: Optional[str] = "model",
mlflow_model_name: Optional[str] = "sentiment_analysis",
):
"""
Register model to MLFlow.
Expand Down
2 changes: 1 addition & 1 deletion template/steps/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def model_trainer(
load_best_model_at_end: Optional[bool] = True,
eval_batch_size: Optional[int] = 16,
weight_decay: Optional[float] = 0.01,
mlflow_model_name: Optional[str] = "model",
mlflow_model_name: Optional[str] = "sentiment_analysis",
) -> Tuple[Annotated[PreTrainedModel, "model", ModelArtifactConfig(overwrite=True)], Annotated[PreTrainedTokenizerBase, "tokenizer", ModelArtifactConfig(overwrite=True)]]:
"""
Configure and train a model on the training dataset.
Expand Down
5 changes: 3 additions & 2 deletions tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def generate_and_run_project(
deploy_to_skypilot: bool = False,
cloud_of_choice: str = "gcp",
dataset: str = "airline_reviews",
zenml_model_name: str = "sentiment_analysis",

):
"""Generate and run the starter project with different options."""
Expand Down Expand Up @@ -118,8 +119,8 @@ def generate_and_run_project(

# clean up
Client().delete_pipeline(product_name + pipeline_suffix)
Client().delete_model(product_name)
Client().active_stack.model_registry.delete_model(product_name)
Client().delete_model(zenml_model_name)
Client().active_stack.model_registry.delete_model(zenml_model_name)

os.chdir(current_dir)
shutil.rmtree(dst_path)
Expand Down

0 comments on commit cda3816

Please sign in to comment.