Skip to content

Commit

Permalink
add sample_rate to accelrate training
Browse files Browse the repository at this point in the history
  • Loading branch information
safoinme committed Nov 14, 2023
1 parent 0ec8ae0 commit 44f28d8
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ The template can be configured using the following parameters:
| Deploy to HuggingFace | Whether to deploy to HuggingFace Hub | False |
| Deploy to SkyPilot | Whether to deploy to SkyPilot | False |
| Dataset | The dataset to use from HuggingFace Datasets | airline_reviews |
| Model | The model to use from HuggingFace Models | roberta-base |
| Model | The model to use from HuggingFace Models | distilbert-base-uncased |
| Cloud Provider | The cloud provider to use (AWS or GCP) | aws |
| Metric-Based Promotion | Whether to promote models based on metrics | True |
| Notifications on Failure | Whether to notify about pipeline failures | True |
Expand Down
8 changes: 6 additions & 2 deletions copier.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ accelerator:
- gpu
- cpu
default: gpu
sample_rate:
type: bool
help: "Whether to use a sample of the dataset for quick iteration"
default: False
deploy_locally:
type: bool
help: "Whether to deploy locally"
Expand All @@ -91,8 +95,8 @@ model:
choices:
- bert-base-uncased
- roberta-base
- distilbert-base-cased
default: roberta-base
- distilbert-base-uncased
default: distilbert-base-uncased
cloud_of_choice:
type: str
help: "Whether to use AWS cloud provider or GCP"
Expand Down
16 changes: 16 additions & 0 deletions template/steps/dataset_loader/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from datasets import load_dataset, DatasetDict
from zenml import step
from zenml.logger import get_logger
{%- if sample_rate %}
import numpy as np
{%- endif %}

logger = get_logger(__name__)

Expand Down Expand Up @@ -41,6 +44,19 @@ def data_loader(
dataset = dataset.remove_columns(["airline_sentiment_confidence","negativereason_confidence"])
{%- endif %}

{%- if sample_rate %}
# Sample 20% of the data randomly for the demo
def sample_dataset(dataset, sample_rate=0.2):
sampled_dataset = DatasetDict()
for split in dataset.keys():
split_size = len(dataset[split])
indices = np.random.choice(split_size, int(split_size * sample_rate), replace=False)
sampled_dataset[split] = dataset[split].select(indices)
return sampled_dataset

dataset = sample_dataset(dataset)
{%- endif %}

# Log the dataset and sample examples
logger.info(dataset)
logger.info(f"Sample Example 1 : {dataset['train'][0]['text']} with label {dataset['train'][0]['label']}")
Expand Down
2 changes: 1 addition & 1 deletion template/steps/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def model_trainer(
evaluation_strategy='steps',
save_strategy='steps',
save_steps=1000,
eval_steps=200,
eval_steps=100,
logging_steps=logging_steps,
save_total_limit=5,
report_to="mlflow",
Expand Down

0 comments on commit 44f28d8

Please sign in to comment.