Add utility files and cleanup configurations
mkuchnik committed Mar 13, 2022
1 parent 25ee6ce commit 6123f5b
Showing 20 changed files with 167 additions and 59 deletions.
30 changes: 30 additions & 0 deletions README.md
@@ -483,6 +483,19 @@ For some runs, we use Tensorflow to control the model, and therefore the TPU.
In other runs, we use [JAX](https://github.com/google/jax) to do the same.
In all cases, Tensorflow is in control of the data pipeline with tf.data.
**Constraints.**
TF2 with native ops is the main target of Plumber. TF1 workflows currently
cannot be resumed, and therefore require either benchmarking the input pipeline
in a TF2 context (i.e., a microbenchmark) or tracing the pipeline in a TF1
context and resuming it in a TF2 context. For the latter, we provide the
relevant pipelines (i.e., RCNN/GNMT) with a `get_params.py`, which starts
Plumber optimization in a TF2 context after tracing.
Similarly, pipelines with custom kernels (i.e., Flax WMT) cannot be easily
deserialized in either TF2 or TF1, so we provide a `show_params.sh` script
that prints the currently recommended parameters after tracing.
The parameters can then be copied over to the input pipeline.
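
For instance, a minimal sketch of the two workflows (assuming a `stats.pb`
trace has already been emitted into the current directory; the log file name
is arbitrary):

```bash
# Resumable pipelines (e.g., RCNN/GNMT): resume Plumber optimization in TF2.
python3 get_params.py 2>&1 | tee params_log.txt

# Custom-kernel pipelines (e.g., Flax WMT): only print the recommendation.
bash show_params.sh
```
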
##### ResNet
This pipeline uses JAX (TPU) + Tensorflow.
You should install the minimal requirements to avoid dependency clobber (e.g.,
@@ -558,6 +571,10 @@ Then to run with standard 96 threads/cores:
```bash
bash official_runners/run_96.sh
```

To obtain the Plumber parameters, run the naive pipeline with tracing enabled
and then run `get_params.py` from the directory containing the emitted
`stats.pb`, as sketched below.
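
A sketch (assuming the naive runner emits `stats.pb` into the working
directory via tracing):

```bash
# Trace the naive pipeline, then resume Plumber optimization from stats.pb.
bash official_runners/run_96.sh
python3 get_params.py 2>&1 | tee params_log.txt
```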

##### SSD
This pipeline uses JAX (TPU) + Tensorflow.
Run the dependency install script.
@@ -607,6 +624,15 @@ After installing this dependency, you can run:
```bash
bash official_runners/run.sh
```
Because this pipeline uses custom operators which cannot be easily
deserialized, we use the provided `show_params.sh` script to print Plumber's
recommendation when optimizing the `stats.pb` file emitted by tracing the
naive pipeline. Note that the optimization must be run at least twice to
obtain the final parameters, to account for cache insertion (see the sketch
below). Each pipeline configuration is provided as a separate file, which can
be copied over `input_pipeline.py`.
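
A sketch of the two-pass flow (assuming tracing the naive pipeline leaves
`stats.pb` next to `show_params.sh`):

```bash
# Pass 1: trace the naive pipeline and print the first recommendation.
bash official_runners/run.sh
bash show_params.sh

# Copy the recommended parameters into input_pipeline.py (this may insert a
# cache), then re-trace and print the final recommendation.
bash official_runners/run.sh
bash show_params.sh
```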

##### GNMT
Run the dependency install script.
@@ -617,3 +643,7 @@ Then to run:
```bash
bash official_runners/run.sh
```

To obtain the Plumber parameters, run the naive pipeline with tracing enabled
and then run `get_params.py` from the directory containing the emitted
`stats.pb`, as sketched below.
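
A sketch (assuming the naive runner emits `stats.pb` into the working
directory via tracing):

```bash
# Trace the naive pipeline, then resume Plumber optimization from stats.pb.
bash official_runners/run.sh
python3 get_params.py 2>&1 | tee params_log.txt
```
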
3 changes: 2 additions & 1 deletion end_to_end/TPU/flax_examples/wmt/input_pipeline_naive.py
@@ -32,6 +32,7 @@

AUTOTUNE = tf.data.AUTOTUNE
AUTOTUNE = 1  # naive pipeline: override AUTOTUNE so parallelism defaults to 1
PREFETCH_AUTOTUNE = 0  # naive setting: 0 disables prefetching
tf.data.AUTOTUNE = AUTOTUNE
Features = Dict[str, tf.Tensor]

@@ -287,7 +288,7 @@ def preprocess_wmt_data(dataset,
max_length: int = 512,
batch_size: int = 256,
drop_remainder: bool = True,
prefetch_size: int = AUTOTUNE):
prefetch_size: int = PREFETCH_AUTOTUNE):
"""Shuffle and batch/pack the given dataset."""

def length_filter(max_len):
20 changes: 20 additions & 0 deletions end_to_end/TPU/gnmt-research-TF-tpu-v4-16/get_params.py
@@ -0,0 +1,20 @@
"""Get Plumber recommendation from a traced pipeline.
Takes as input the 'stats.pb' in the current directory.
"""

import plumber_analysis.pipeline_optimizer_wrapper
import plumber_analysis.config

# Resume Plumber optimization from the stats.pb trace in the current directory.
is_fast = False
optimizer = plumber_analysis.pipeline_optimizer_wrapper.step_par_2(is_fast=is_fast)

experiment_params = optimizer.experiment_params()
performance_params = optimizer.get_performance_parameters()
print("Experimental params:\n{}".format(experiment_params))
print("Plumber found parameters:\n{}".format(performance_params))

# Rebuild the optimized pipeline and benchmark it for just over a minute.
dataset = optimizer.instantiate_pipeline()
dataset = plumber_analysis.pipeline_optimizer_wrapper.apply_default_options(dataset, override_presets=True)
ret = plumber_analysis.pipeline_optimizer_wrapper._benchmark_dataset(dataset, time_limit_s=62)
print(ret)
28 changes: 14 additions & 14 deletions end_to_end/TPU/gnmt-research-TF-tpu-v4-16/official_runners/run.sh
@@ -37,8 +37,6 @@ function step_0 {
--input_pipeline_map_0_parallelism=1 \
--input_pipeline_map_1_parallelism=1 \
--input_pipeline_default_prefetching=1 \
--input_pipeline_cache=False \
--use_synthetic_data=False \
${global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -58,8 +56,6 @@ function step_0_no_prefetch {
--input_pipeline_map_0_parallelism=1 \
--input_pipeline_map_1_parallelism=1 \
--input_pipeline_default_prefetching=0 \
--input_pipeline_cache=False \
--use_synthetic_data=False \
${global_opt} | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -78,9 +74,7 @@ function step_0_benchmark {
--input_pipeline_read_parallelism=1 \
--input_pipeline_map_0_parallelism=1 \
--input_pipeline_map_1_parallelism=1 \
--input_pipeline_default_prefetching=1 \
--input_pipeline_cache=False \
--use_synthetic_data=False \
--input_pipeline_default_prefetching=0 \
${benchmark_global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -93,13 +87,13 @@ function step_plumber {
pushd ${experiment_name}
PLUMBER_NO_OPTIMIZE=True \
python3 $script_name \
--data_dir=${data_dir} \
--data_dir=${data_dir} \
--out_dir=${model_dir} \
--use_preprocessed_data=True \
--input_pipeline_read_parallelism=34 \
--input_pipeline_map_0_parallelism=19 \
--input_pipeline_read_parallelism=47 \
--input_pipeline_map_0_parallelism=20 \
--input_pipeline_map_1_parallelism=9 \
--input_pipeline_default_prefetching=14 \
--input_pipeline_default_prefetching=16 \
--input_pipeline_cache=True \
${global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
@@ -143,6 +137,7 @@ function step_plumber_benchmark {
--map_tfrecord_decode_parallelism=1 \
--map_image_postprocessing_parallelism=1 \
--map_image_transpose_postprocessing_parallelism=1 \
--input_pipeline_default_prefetching=0 \
--shard_parallelism=1 \
${benchmark_global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
@@ -163,7 +158,6 @@ function step_autotune {
--input_pipeline_map_0_parallelism=-1 \
--input_pipeline_map_1_parallelism=-1 \
--input_pipeline_default_prefetching=-1 \
--input_pipeline_cache=False \
${global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -207,7 +201,6 @@ function step_heuristic {
--input_pipeline_map_0_parallelism=96 \
--input_pipeline_map_1_parallelism=96 \
--input_pipeline_default_prefetching=100 \
--input_pipeline_cache=False \
${global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -233,11 +226,18 @@ function step_heuristic_benchmark {
popd
}


# Get the Plumber recommendation after tracing the pipeline with:
#   name="params"
#   python3 get_params.py 2>&1 | tee ${name}_log.txt
# Alternatively, get the parameters by running the Plumber benchmark step.

list_python_programs
kill_python_programs
kill_python_programs
list_python_programs
step_0
step_0_no_prefetch
list_python_programs
kill_python_programs
kill_python_programs
@@ -17,7 +17,7 @@

import tensorflow.compat.v1 as tf
#tf.disable_eager_execution()
from plumber_analysis import gen_util
from plumber_analysis import gen_util, annotations

__all__ = ["get_iterator", "get_infer_iterator"]

@@ -226,6 +226,7 @@ def reduce_func(unused_key, windowed_data):


# pylint: disable=g-long-lambda,line-too-long
@annotations.trace_pipeline()
def get_preprocessed_iterator(dataset_file,
batch_size,
random_seed,
@@ -5,9 +5,12 @@
import tensorflow.compat.v1 as tf1
#tf1.disable_eager_execution()
import dataloader
from plumber_analysis import gen_util
from plumber_analysis import gen_util, pipeline_optimizer_wrapper, pipeline_optimizer
import mask_rcnn_params

# Benchmark for just over a minute (62 s), matching get_params.py.
pipeline_optimizer.DEFAULT_BENCHMARK_TIME = 62
pipeline_optimizer_wrapper.BENCHMARK_TIME = 62

FLAGS = flags.FLAGS

flags.DEFINE_string('hparams', '',
6 changes: 4 additions & 2 deletions end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/dataloader.py
@@ -542,7 +542,8 @@ def reduce_func(unused_key, dataset):
num_parallel_calls=DEFAULT_PARALLELISM())

#dataset = dataset.prefetch(tf.data.AUTOTUNE)
dataset = dataset.prefetch(DEFAULT_PREFETCHING())
if DEFAULT_PREFETCHING():  # a buffer size of 0 disables prefetching entirely
    dataset = dataset.prefetch(DEFAULT_PREFETCHING())

if (self._mode == tf.estimator.ModeKeys.TRAIN and
num_examples > 0):
@@ -563,7 +564,8 @@ def reduce_func(unused_key, dataset):
options.experimental_deterministic = deterministic
options.experimental_threading.max_intra_op_parallelism = 1
options.experimental_threading.private_threadpool_size = 96
#options.experimental_optimization.autotune_stats_filename = "stats.pb"
# NOTE(mkuchnik): Plumber Tracing
options.experimental_optimization.autotune_stats_filename = "stats.pb"
dataset = dataset.with_options(options)

if self._mode == tf.estimator.ModeKeys.TRAIN:
20 changes: 20 additions & 0 deletions end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/get_params.py
@@ -0,0 +1,20 @@
"""Get Plumber recommendation from a traced pipeline.
Takes as input the 'stats.pb' in the current directory.
"""

import plumber_analysis.pipeline_optimizer_wrapper
import plumber_analysis.config

is_fast = False
optimizer = plumber_analysis.pipeline_optimizer_wrapper.step_par_2(is_fast=is_fast)

experiment_params = optimizer.experiment_params()
performance_params = optimizer.get_performance_parameters()
print("Experimental params:\n{}".format(experiment_params))
print("Plumber found parameters:\n{}".format(performance_params))

dataset = optimizer.instantiate_pipeline()
dataset = plumber_analysis.pipeline_optimizer_wrapper.apply_default_options(dataset, override_presets=True)
ret = plumber_analysis.pipeline_optimizer_wrapper._benchmark_dataset(dataset, time_limit_s=62)
print(ret)
@@ -83,7 +83,7 @@ function step_naive {
--train_batch_size=32 \
--eval_batch_size=32 \
--input_pipeline_default_parallelism=1 \
--input_pipeline_default_prefetching=1 \
--input_pipeline_default_prefetching=0 \
${global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -106,7 +106,7 @@ function step_plumber {
--train_batch_size=32 \
--eval_batch_size=32 \
--input_pipeline_default_parallelism=1 \
--input_pipeline_default_prefetching=1 \
--input_pipeline_default_prefetching=0 \
${global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -217,7 +217,6 @@ function step_autotune_benchmark {
}

function step_plumber_benchmark {
#TODO(mkuchnik): Error piping not working
name=step_plumber_benchmark
experiment_name=${experiment_prefix}_${name}
mkdir -p ${experiment_name}
@@ -231,8 +230,9 @@ function step_plumber_benchmark {
--log_dir=$log_dir \
--train_batch_size=32 \
--eval_batch_size=32 \
${benchmark_global_opt} | tee -a ${name}_log.txt
#${benchmark_global_opt} 2>&1 | tee -a ${name}_log.txt
--input_pipeline_default_parallelism=1 \
--input_pipeline_default_prefetching=0 \
${benchmark_global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
}
@@ -277,7 +277,12 @@ function step_naive_benchmark {
popd
}

#step_naive_benchmark
# Get the Plumber recommendation after tracing the pipeline with:
#   name="params"
#   python3 get_params.py 2>&1 | tee ${name}_log.txt
# Alternatively, get the parameters by running the Plumber benchmark step.

list_python_programs
kill_python_programs
kill_python_programs
@@ -83,7 +83,7 @@ function step_naive {
--train_batch_size=32 \
--eval_batch_size=32 \
--input_pipeline_default_parallelism=1 \
--input_pipeline_default_prefetching=1 \
--input_pipeline_default_prefetching=0 \
${global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -106,7 +106,7 @@ function step_plumber {
--train_batch_size=32 \
--eval_batch_size=32 \
--input_pipeline_default_parallelism=1 \
--input_pipeline_default_prefetching=1 \
--input_pipeline_default_prefetching=0 \
${global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
@@ -217,7 +217,6 @@ function step_autotune_benchmark {
}

function step_plumber_benchmark {
#TODO(mkuchnik): Error piping not working
name=step_plumber_benchmark
experiment_name=${experiment_prefix}_${name}
mkdir -p ${experiment_name}
@@ -231,8 +230,9 @@ function step_plumber_benchmark {
--log_dir=$log_dir \
--train_batch_size=32 \
--eval_batch_size=32 \
${benchmark_global_opt} | tee -a ${name}_log.txt
#${benchmark_global_opt} 2>&1 | tee -a ${name}_log.txt
--input_pipeline_default_parallelism=1 \
--input_pipeline_default_prefetching=0 \
${benchmark_global_opt} 2>&1 | tee ${name}_log.txt
cp stats.pb $name.pb
popd
}
@@ -277,7 +277,12 @@ function step_naive_benchmark {
popd
}

#step_naive_benchmark
# Get the Plumber recommendation after tracing the pipeline with:
#   name="params"
#   python3 get_params.py 2>&1 | tee ${name}_log.txt
# Alternatively, get the parameters by running the Plumber benchmark step.

list_python_programs
kill_python_programs
kill_python_programs
@@ -201,7 +201,7 @@ def __init__(self,
self.post_train_loop_callbacks = []
self.num_outfeeds = self.eval_steps
self.config = tf.ConfigProto(
operation_timeout_in_ms=600 * 60 * 1000,
operation_timeout_in_ms=2 * 600 * 60 * 1000,  # doubled timeout: 20 hours
allow_soft_placement=True,
graph_options=tf.GraphOptions(
rewrite_options=rewriter_config_pb2.RewriterConfig(
@@ -32,6 +32,11 @@
import augmentations
import logging

from plumber_analysis import pipeline_optimizer as _pipeline_optimizer

_pipeline_optimizer.DEFAULT_BENCHMARK_TIME = 60 + 2
_pipeline_optimizer.BENCHMARK_TIME = 60 + 2

config.enable_compat_logging()

FLAGS = flags.FLAGS
@@ -306,7 +311,8 @@ def set_shapes(images, labels):
file_pattern = train_file_pattern if train else val_file_pattern
dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
dataset = dataset.shard(num_hosts, index)
concurrent_files = min(10, 1024 // num_hosts)
# concurrent_files = min(10, 1024 // num_hosts)
concurrent_files = DEFAULT_PARALLELISM()
# interleave(map_func, cycle_length, block_length, num_parallel_calls)
dataset = dataset.interleave(
    tf.data.TFRecordDataset, concurrent_files, 1, concurrent_files)
