Merge pull request #75 from aak7912/main
pass datasource object to compare_training_runs() to filter by datasource
htahir1 authored Apr 22, 2021
2 parents 7c70d0d + dd3697a commit 9657a0e
Showing 3 changed files with 29 additions and 47 deletions.
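In practice the change surfaces as one extra keyword argument on the repository's compare helper. A minimal usage sketch, assuming an initialized ZenML repository; the `get_datasources()` lookup is illustrative only, and any existing `BaseDatasource` instance works (omit `datasource` to keep the old behaviour of comparing all training runs):

    from zenml.repo import Repository

    repo = Repository.get_instance()
    # Assumed lookup for illustration: pick a datasource already registered in the repo.
    ds = repo.get_datasources()[0]
    # Compare only the training pipelines that ran against `ds`.
    repo.compare_training_runs(port=8080, datasource=ds)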
4 changes: 2 additions & 2 deletions zenml/repo/repo.py
@@ -425,11 +425,11 @@ def load_pipeline_config(self, file_name: Text) -> Dict[Text, Any]:
         pipelines_dir = self.zenml_config.get_pipelines_dir()
         return yaml_utils.read_yaml(os.path.join(pipelines_dir, file_name))

-    def compare_training_runs(self, port: int = 0):
+    def compare_training_runs(self, port: int = 0, datasource=None):
         """Launch the compare app for all training pipelines in repo"""
         from zenml.utils.post_training.post_training_utils import \
             launch_compare_tool
-        launch_compare_tool(port)
+        launch_compare_tool(port, datasource)

     def clean(self):
         """Deletes associated metadata store, pipelines dir and artifacts"""
25 changes: 15 additions & 10 deletions zenml/utils/post_training/compare.py
@@ -7,9 +7,9 @@
 import plotly.graph_objects as go
 import tensorflow_model_analysis as tfma

+from zenml.enums import PipelineStatusTypes, GDPComponent
 from zenml.pipelines import TrainingPipeline
 from zenml.repo import Repository
-from zenml.enums import PipelineStatusTypes, GDPComponent

 pn.extension('plotly')

@@ -31,17 +31,26 @@ class Application(param.Parameterized):
     slicing_metric_selector = param.ObjectSelector(default='', objects=[''])
     performance_metric_selector = param.ObjectSelector(objects=[])

-    def __init__(self, **params):
+    def __init__(self, datasource=None, **params):
         super(Application, self).__init__(**params)

         # lists
         result_list = []
         hparam_list = []
         repo: Repository = Repository.get_instance()
+        self.datasource = datasource

         # get all pipelines in this workspace
-        all_pipelines: List[TrainingPipeline] = repo.get_pipelines_by_type([
-            TrainingPipeline.PIPELINE_TYPE])
+        if datasource:
+            # filter pipeline by datasource, and then the training ones
+            all_pipelines: List[TrainingPipeline] = \
+                repo.get_pipelines_by_datasource(datasource)
+            all_pipelines = [p for p in all_pipelines if
+                             p.PIPELINE_TYPE == TrainingPipeline.PIPELINE_TYPE]
+        else:
+            all_pipelines: List[TrainingPipeline] = repo.get_pipelines_by_type(
+                [
+                    TrainingPipeline.PIPELINE_TYPE])

         # get a dataframe of all results + all hyperparameter combinations
         for p in all_pipelines:
@@ -177,8 +186,8 @@ def parameter_graph(self):
         return fig


-def generate_interface():
-    app = Application()
+def generate_interface(datasource=None):
+    app = Application(datasource=datasource)
     handlers = pn.Param(app.param)

     # Analysis Page
@@ -195,7 +204,3 @@ def generate_interface():
         ('Analysis Page', analysis_page),
     )
     return interface
-
-
-platform = generate_interface()
-platform.servable()
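With the module-level `platform = generate_interface()` and `platform.servable()` lines removed, compare.py no longer serves itself as an injected notebook cell; the interface is now built on demand by its caller. A small sketch of previewing it directly, illustrative rather than part of the commit:

    from zenml.utils.post_training.compare import generate_interface

    # Build the comparison dashboard; pass datasource=<BaseDatasource> to filter it.
    app = generate_interface(datasource=None)
    # Panel objects expose .show(), which opens the app in a browser on a free port.
    app.show()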
47 changes: 12 additions & 35 deletions zenml/utils/post_training/post_training_utils.py
@@ -21,6 +21,7 @@
 import click
 import nbformat as nbf
 import pandas as pd
+import panel
 import panel as pn
 import tensorflow as tf
 import tensorflow_data_validation as tfdv
@@ -29,8 +30,7 @@
 from tensorflow_transform.tf_metadata import schema_utils
 from tfx.utils import io_utils

-from zenml.constants import APP_NAME, EVALUATION_NOTEBOOK, \
-    COMPARISON_NOTEBOOK
+from zenml.constants import APP_NAME, EVALUATION_NOTEBOOK
 from zenml.enums import GDPComponent
 from zenml.logger import get_logger
 from zenml.utils.path_utils import read_file_contents
@@ -336,36 +336,13 @@ def evaluate_single_pipeline(
     os.system(f'jupyter notebook {final_out_path} --port {port}')


-def launch_compare_tool(port: int = 0):
-    """Launches `compare` tool for comparing multiple training pipelines."""
-    # assumes compare.py in the same folder
-    template = \
-        os.path.join(os.path.abspath(os.path.dirname(__file__)), 'compare.py')
-    compare_cell = read_file_contents(template)
-
-    # generate notebook
-    nb = nbf.v4.new_notebook()
-    nb['cells'] = [
-        nbf.v4.new_code_cell(compare_cell),
-    ]
-
-    # TODO: [LOW] Check if we can centralize this along with the one used in
-    #  evaluate_single_pipeline()
-    config_folder = click.get_app_dir(APP_NAME)
-    if not (os.path.exists(config_folder) and os.path.isdir(
-            config_folder)):
-        os.makedirs(config_folder)
-
-    final_out_path = os.path.join(config_folder, COMPARISON_NOTEBOOK)
-    s = nbf.writes(nb)
-    if isinstance(s, bytes):
-        s = s.decode('utf8')
-
-    with open(final_out_path, 'w') as f:
-        f.write(s)
-
-    # serve notebook
-    if port == 0:
-        os.system('panel serve "{}" --show'.format(final_out_path))
-    else:
-        os.system(f'panel serve "{final_out_path}" --port {port} --show')
+def launch_compare_tool(port: int = 0, datasource=None):
+    """Launches `compare` tool for comparing multiple training pipelines.
+
+    Args:
+        port: Port to launch application on.
+        datasource (BaseDatasource): object of type BaseDatasource, to
+            filter only pipelines using that particular datasource.
+    """
+    from zenml.utils.post_training.compare import generate_interface
+    panel.serve(generate_interface(datasource), port=port)
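A note on the new serving path: with the default `port=0`, Panel should pick an available port itself, which mirrors the intent of the removed `if port == 0` branch. A hedged sketch of invoking the tool directly (`my_datasource` is a placeholder for an existing `BaseDatasource` instance):

    from zenml.utils.post_training.post_training_utils import launch_compare_tool

    # Compare every training pipeline in the repository on an auto-selected port.
    launch_compare_tool()

    # Pin the port and restrict the comparison to a single datasource.
    # launch_compare_tool(port=8082, datasource=my_datasource)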
