From 6123f5bce36eec7dc75b6b9298054b493d930bdc Mon Sep 17 00:00:00 2001 From: Michael Kuchnik Date: Sun, 13 Mar 2022 04:08:37 +0000 Subject: [PATCH] Add utility files and clean up configurations --- README.md | 30 +++++++++++++++++++ .../flax_examples/wmt/input_pipeline_naive.py | 3 +- .../gnmt-research-TF-tpu-v4-16/get_params.py | 20 +++++++++++++ .../official_runners/run.sh | 28 ++++++++--------- .../utils/iterator_utils.py | 3 +- .../benchmark_mlperf.py | 5 +++- .../dataloader.py | 6 ++-- .../get_params.py | 20 +++++++++++++ .../official_runners/run_48.sh | 17 +++++++---- .../official_runners/run_96.sh | 17 +++++++---- .../util/train_and_eval_runner.py | 2 +- .../mlperf_input_pipeline.py | 8 ++++- .../input_pipeline.py | 6 ++++ .../official_runners/run.sh | 10 ++----- .../official_runners/run_96.sh | 10 ++----- .../mlperf_input_pipeline.py | 5 ++++ .../official_runners/run.sh | 11 +++---- .../src/plumber_analysis/annotations.py | 7 ++++- .../plumber_analysis/pipeline_optimizer.py | 3 +- .../pipeline_optimizer_wrapper.py | 15 ++++++---- 20 files changed, 167 insertions(+), 59 deletions(-) create mode 100644 end_to_end/TPU/gnmt-research-TF-tpu-v4-16/get_params.py create mode 100644 end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/get_params.py diff --git a/README.md b/README.md index 950d84a..320b9e7 100644 --- a/README.md +++ b/README.md @@ -483,6 +483,19 @@ For some runs, we use Tensorflow to control the model, and therefore the TPU. In other runs, we use [JAX](https://github.com/google/jax) to do the same. In all cases, Tensorflow is in control of the data pipeline with tf.data. +**Constraints.** +TF2 with native ops is the main target of Plumber. TF1 workflows currently +cannot be resumed, and therefore require either benchmarking the input pipeline +in a TF2 context (i.e., a microbenchmark) +or tracing the pipeline in a TF1 context and resuming it in a +TF2 context. For the latter, we provide the relevant pipelines (i.e., RCNN/GNMT) with +a `get_params.py`, which resumes Plumber optimization in a TF2 context after tracing. +Similarly, pipelines with custom kernels (i.e., Flax WMT) cannot be +easily deserialized in either TF2 or TF1, so we provide a `show_params.sh` to +print out the currently recommended parameters after tracing. +The parameters can then be copied into the input pipeline. + + ##### ResNet This pipeline uses JAX (TPU) + Tensorflow. You should install the minimal requirements to avoid dependency clobber (e.g., @@ -558,6 +571,10 @@ Then to run with standard 96 threads/cores: bash official_runners/run_96.sh ``` +To obtain the Plumber parameters, run the naive pipeline with tracing enabled and +then run `get_params.py` from the directory that contains the emitted +`stats.pb`. + ##### SSD This pipeline uses JAX (TPU) + Tensorflow. Run the dependency install script. @@ -607,6 +624,15 @@ After installing this dependency, you can run: bash official_runners/run.sh ``` +Because this pipeline uses custom operators that cannot be easily +deserialized, we use the provided `show_params.sh` script to +print Plumber's recommendation when optimizing the `stats.pb` file emitted by +tracing the naive pipeline. +Note that obtaining the final parameters requires running the optimization at +least twice, to account for cache insertion. +Each pipeline configuration is provided as a separate file, which can be +copied over `input_pipeline.py`. + ##### GNMT Run the dependency install script.
```bash @@ -617,3 +643,7 @@ Then to run: ```bash bash official_runners/run.sh ``` + +To obtain the Plumber parameters, run the naive pipeline with tracing enabled and +then run `get_params.py` from the directory that contains the emitted +`stats.pb`. diff --git a/end_to_end/TPU/flax_examples/wmt/input_pipeline_naive.py b/end_to_end/TPU/flax_examples/wmt/input_pipeline_naive.py index 6483acb..cb4fe58 100644 --- a/end_to_end/TPU/flax_examples/wmt/input_pipeline_naive.py +++ b/end_to_end/TPU/flax_examples/wmt/input_pipeline_naive.py @@ -32,6 +32,7 @@ AUTOTUNE = tf.data.AUTOTUNE AUTOTUNE = 1 +PREFETCH_AUTOTUNE = 0 tf.data.AUTOTUNE = AUTOTUNE Features = Dict[str, tf.Tensor] @@ -287,7 +288,7 @@ def preprocess_wmt_data(dataset, max_length: int = 512, batch_size: int = 256, drop_remainder: bool = True, - prefetch_size: int = AUTOTUNE): + prefetch_size: int = PREFETCH_AUTOTUNE): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): diff --git a/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/get_params.py b/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/get_params.py new file mode 100644 index 0000000..fa65dd1 --- /dev/null +++ b/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/get_params.py @@ -0,0 +1,20 @@ +"""Get Plumber recommendation from a traced pipeline. + +Takes as input the 'stats.pb' in the current directory. +""" + +import plumber_analysis.pipeline_optimizer_wrapper +import plumber_analysis.config + +is_fast = False +optimizer = plumber_analysis.pipeline_optimizer_wrapper.step_par_2(is_fast=is_fast) + +experiment_params = optimizer.experiment_params() +performance_params = optimizer.get_performance_parameters() +print("Experimental params:\n{}".format(experiment_params)) +print("Plumber found parameters:\n{}".format(performance_params)) + +dataset = optimizer.instantiate_pipeline() +dataset = plumber_analysis.pipeline_optimizer_wrapper.apply_default_options(dataset, override_presets=True) +ret = plumber_analysis.pipeline_optimizer_wrapper._benchmark_dataset(dataset, time_limit_s=62) +print(ret) diff --git a/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/official_runners/run.sh b/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/official_runners/run.sh index 0ca47a7..5a942cb 100755 --- a/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/official_runners/run.sh +++ b/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/official_runners/run.sh @@ -37,8 +37,6 @@ function step_0 { --input_pipeline_map_0_parallelism=1 \ --input_pipeline_map_1_parallelism=1 \ --input_pipeline_default_prefetching=1 \ - --input_pipeline_cache=False \ - --use_synthetic_data=False \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -58,8 +56,6 @@ function step_0_no_prefetch { --input_pipeline_map_0_parallelism=1 \ --input_pipeline_map_1_parallelism=1 \ --input_pipeline_default_prefetching=0 \ - --input_pipeline_cache=False \ - --use_synthetic_data=False \ ${global_opt} | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -78,9 +74,7 @@ function step_0_benchmark { --input_pipeline_read_parallelism=1 \ --input_pipeline_map_0_parallelism=1 \ --input_pipeline_map_1_parallelism=1 \ - --input_pipeline_default_prefetching=1 \ - --input_pipeline_cache=False \ - --use_synthetic_data=False \ + --input_pipeline_default_prefetching=0 \ ${benchmark_global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -93,13 +87,13 @@ function step_plumber { pushd ${experiment_name} PLUMBER_NO_OPTIMIZE=True \ python3 $script_name \ --data_dir=${data_dir} \
--out_dir=${model_dir} \ --use_preprocessed_data=True \ - --input_pipeline_read_parallelism=34 \ - --input_pipeline_map_0_parallelism=19 \ + --input_pipeline_read_parallelism=47 \ + --input_pipeline_map_0_parallelism=20 \ --input_pipeline_map_1_parallelism=9 \ - --input_pipeline_default_prefetching=14 \ + --input_pipeline_default_prefetching=16 \ --input_pipeline_cache=True \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb @@ -143,6 +137,7 @@ function step_plumber_benchmark { --map_tfrecord_decode_parallelism=1 \ --map_image_postprocessing_parallelism=1 \ --map_image_transpose_postprocessing_parallelism=1 \ + --input_pipeline_default_prefetching=0 \ --shard_parallelism=1 \ ${benchmark_global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb @@ -163,7 +158,6 @@ function step_autotune { --input_pipeline_map_0_parallelism=-1 \ --input_pipeline_map_1_parallelism=-1 \ --input_pipeline_default_prefetching=-1 \ - --input_pipeline_cache=False \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -207,7 +201,6 @@ function step_heuristic { --input_pipeline_map_0_parallelism=96 \ --input_pipeline_map_1_parallelism=96 \ --input_pipeline_default_prefetching=100 \ - --input_pipeline_cache=False \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -233,11 +226,18 @@ function step_heuristic_benchmark { popd } + +# Get Plumber recommendation with: +# name="params" +# python3 get_params.py 2>&1 | tee ${name}_log.txt +# after tracing pipeline. +# Alternatively, get parameters by running Plumber benchmark + list_python_programs kill_python_programs kill_python_programs list_python_programs -step_0 +step_0_no_prefetch list_python_programs kill_python_programs kill_python_programs diff --git a/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/utils/iterator_utils.py b/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/utils/iterator_utils.py index 9227929..35b9945 100644 --- a/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/utils/iterator_utils.py +++ b/end_to_end/TPU/gnmt-research-TF-tpu-v4-16/utils/iterator_utils.py @@ -17,7 +17,7 @@ import tensorflow.compat.v1 as tf #tf.disable_eager_execution() -from plumber_analysis import gen_util +from plumber_analysis import gen_util, annotations __all__ = ["get_iterator", "get_infer_iterator"] @@ -226,6 +226,7 @@ def reduce_func(unused_key, windowed_data): # pylint: disable=g-long-lambda,line-too-long +@annotations.trace_pipeline() def get_preprocessed_iterator(dataset_file, batch_size, random_seed, diff --git a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/benchmark_mlperf.py b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/benchmark_mlperf.py index 28d8a4b..3de8de8 100644 --- a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/benchmark_mlperf.py +++ b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/benchmark_mlperf.py @@ -5,9 +5,12 @@ import tensorflow.compat.v1 as tf1 #tf1.disable_eager_execution() import dataloader -from plumber_analysis import gen_util +from plumber_analysis import gen_util, pipeline_optimizer_wrapper, pipeline_optimizer import mask_rcnn_params +pipeline_optimizer.DEFAULT_BENCHMARK_TIME = 62 +pipeline_optimizer_wrapper.BENCHMARK_TIME = 62 + FLAGS = flags.FLAGS flags.DEFINE_string('hparams', '', diff --git a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/dataloader.py b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/dataloader.py index d4eb116..bd79f34 100644 --- a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/dataloader.py +++ b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/dataloader.py @@ -542,7 +542,8 @@ def 
reduce_func(unused_key, dataset): num_parallel_calls=DEFAULT_PARALLELISM()) #dataset = dataset.prefetch(tf.data.AUTOTUNE) - dataset = dataset.prefetch(DEFAULT_PREFETCHING()) + if DEFAULT_PREFETCHING(): + dataset = dataset.prefetch(DEFAULT_PREFETCHING()) if (self._mode == tf.estimator.ModeKeys.TRAIN and num_examples > 0): @@ -563,7 +564,8 @@ def reduce_func(unused_key, dataset): options.experimental_deterministic = deterministic options.experimental_threading.max_intra_op_parallelism = 1 options.experimental_threading.private_threadpool_size = 96 - #options.experimental_optimization.autotune_stats_filename = "stats.pb" + # NOTE(mkuchnik): Plumber Tracing + options.experimental_optimization.autotune_stats_filename = "stats.pb" dataset = dataset.with_options(options) if self._mode == tf.estimator.ModeKeys.TRAIN: diff --git a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/get_params.py b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/get_params.py new file mode 100644 index 0000000..fa65dd1 --- /dev/null +++ b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/get_params.py @@ -0,0 +1,20 @@ +"""Get Plumber recommendation from a traced pipeline. + +Takes as input the 'stats.pb' in the current directory. +""" + +import plumber_analysis.pipeline_optimizer_wrapper +import plumber_analysis.config + +is_fast = False +optimizer = plumber_analysis.pipeline_optimizer_wrapper.step_par_2(is_fast=is_fast) + +experiment_params = optimizer.experiment_params() +performance_params = optimizer.get_performance_parameters() +print("Experimental params:\n{}".format(experiment_params)) +print("Plumber found parameters:\n{}".format(performance_params)) + +dataset = optimizer.instantiate_pipeline() +dataset = plumber_analysis.pipeline_optimizer_wrapper.apply_default_options(dataset, override_presets=True) +ret = plumber_analysis.pipeline_optimizer_wrapper._benchmark_dataset(dataset, time_limit_s=62) +print(ret) diff --git a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/official_runners/run_48.sh b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/official_runners/run_48.sh index 8e4d1c9..26c380c 100755 --- a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/official_runners/run_48.sh +++ b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/official_runners/run_48.sh @@ -83,7 +83,7 @@ function step_naive { --train_batch_size=32 \ --eval_batch_size=32 \ --input_pipeline_default_parallelism=1 \ - --input_pipeline_default_prefetching=1 \ + --input_pipeline_default_prefetching=0 \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -106,7 +106,7 @@ function step_plumber { --train_batch_size=32 \ --eval_batch_size=32 \ --input_pipeline_default_parallelism=1 \ - --input_pipeline_default_prefetching=1 \ + --input_pipeline_default_prefetching=0 \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -217,7 +217,6 @@ function step_autotune_benchmark { } function step_plumber_benchmark { - #TODO(mkuchnik): Error piping not working name=step_plumber_benchmark experiment_name=${experiment_prefix}_${name} mkdir -p ${experiment_name} @@ -231,8 +230,9 @@ function step_plumber_benchmark { --log_dir=$log_dir \ --train_batch_size=32 \ --eval_batch_size=32 \ - ${benchmark_global_opt} | tee -a ${name}_log.txt - #${benchmark_global_opt} 2>&1 | tee -a ${name}_log.txt + --input_pipeline_default_parallelism=1 \ + --input_pipeline_default_prefetching=0 \ + ${benchmark_global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd } @@ -277,7 +277,12 @@ function step_naive_benchmark { popd } -#step_naive_benchmark +# Get 
Plumber recommendation with: +# name="params" +# python3 get_params.py 2>&1 | tee ${name}_log.txt +# after tracing pipeline. +# Alternatively, get parameters by running Plumber benchmark + list_python_programs kill_python_programs kill_python_programs diff --git a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/official_runners/run_96.sh b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/official_runners/run_96.sh index a130e7a..3ce244c 100755 --- a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/official_runners/run_96.sh +++ b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/official_runners/run_96.sh @@ -83,7 +83,7 @@ function step_naive { --train_batch_size=32 \ --eval_batch_size=32 \ --input_pipeline_default_parallelism=1 \ - --input_pipeline_default_prefetching=1 \ + --input_pipeline_default_prefetching=0 \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -106,7 +106,7 @@ function step_plumber { --train_batch_size=32 \ --eval_batch_size=32 \ --input_pipeline_default_parallelism=1 \ - --input_pipeline_default_prefetching=1 \ + --input_pipeline_default_prefetching=0 \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd @@ -217,7 +217,6 @@ function step_autotune_benchmark { } function step_plumber_benchmark { - #TODO(mkuchnik): Error piping not working name=step_plumber_benchmark experiment_name=${experiment_prefix}_${name} mkdir -p ${experiment_name} @@ -231,8 +230,9 @@ function step_plumber_benchmark { --log_dir=$log_dir \ --train_batch_size=32 \ --eval_batch_size=32 \ - ${benchmark_global_opt} | tee -a ${name}_log.txt - #${benchmark_global_opt} 2>&1 | tee -a ${name}_log.txt + --input_pipeline_default_parallelism=1 \ + --input_pipeline_default_prefetching=0 \ + ${benchmark_global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd } @@ -277,7 +277,12 @@ function step_naive_benchmark { popd } -#step_naive_benchmark +# Get Plumber recommendation with: +# name="params" +# python3 get_params.py 2>&1 | tee ${name}_log.txt +# after tracing pipeline. 
+# Alternatively, get parameters by running Plumber benchmark + list_python_programs kill_python_programs kill_python_programs diff --git a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/util/train_and_eval_runner.py b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/util/train_and_eval_runner.py index 8ca7b1b..20171d7 100644 --- a/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/util/train_and_eval_runner.py +++ b/end_to_end/TPU/maskrcnn-preview-TF-tpu-v4-128/util/train_and_eval_runner.py @@ -201,7 +201,7 @@ def __init__(self, self.post_train_loop_callbacks = [] self.num_outfeeds = self.eval_steps self.config = tf.ConfigProto( - operation_timeout_in_ms=600 * 60 * 1000, + operation_timeout_in_ms=2 * 600 * 60 * 1000, allow_soft_placement=True, graph_options=tf.GraphOptions( rewrite_options=rewriter_config_pb2.RewriterConfig( diff --git a/end_to_end/TPU/resnet-research-JAX-Distributed-Shampoo-tpu-v4-256/mlperf_input_pipeline.py b/end_to_end/TPU/resnet-research-JAX-Distributed-Shampoo-tpu-v4-256/mlperf_input_pipeline.py index 493d92d..0c70601 100644 --- a/end_to_end/TPU/resnet-research-JAX-Distributed-Shampoo-tpu-v4-256/mlperf_input_pipeline.py +++ b/end_to_end/TPU/resnet-research-JAX-Distributed-Shampoo-tpu-v4-256/mlperf_input_pipeline.py @@ -32,6 +32,12 @@ import augmentations import logging +from plumber_analysis import pipeline_optimizer_wrapper as pipeline_optimizer +from plumber_analysis import pipeline_optimizer as _pipeline_optimizer + +_pipeline_optimizer.DEFAULT_BENCHMARK_TIME = 60 + 2 +pipeline_optimizer.BENCHMARK_TIME = 60 + 2 + config.enable_compat_logging() FLAGS = flags.FLAGS @@ -306,7 +311,8 @@ def set_shapes(images, labels): file_pattern = train_file_pattern if train else val_file_pattern dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False) dataset = dataset.shard(num_hosts, index) - concurrent_files = min(10, 1024 // num_hosts) + # concurrent_files = min(10, 1024 // num_hosts) + concurrent_files = DEFAULT_PARALLELISM() dataset = dataset.interleave( tf.data.TFRecordDataset, concurrent_files, 1, concurrent_files) diff --git a/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/input_pipeline.py b/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/input_pipeline.py index 7bad9c4..4d8e9c0 100644 --- a/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/input_pipeline.py +++ b/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/input_pipeline.py @@ -16,6 +16,12 @@ import ssd_constants import plumber_analysis.annotations +from plumber_analysis import pipeline_optimizer_wrapper as pipeline_optimizer +from plumber_analysis import pipeline_optimizer as _pipeline_optimizer +print("Default time: {}".format(_pipeline_optimizer.DEFAULT_BENCHMARK_TIME)) +print("Default time: {}".format(pipeline_optimizer.BENCHMARK_TIME)) +_pipeline_optimizer.DEFAULT_BENCHMARK_TIME = 60 + 2 +pipeline_optimizer.BENCHMARK_TIME = 60 + 2 FLAGS = flags.FLAGS diff --git a/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/official_runners/run.sh b/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/official_runners/run.sh index d789aa2..b7d9c8a 100755 --- a/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/official_runners/run.sh +++ b/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/official_runners/run.sh @@ -11,7 +11,7 @@ curr_dir="$(pwd)" script_name="$curr_dir/ssd_train.py" benchmark_script_name="$curr_dir/benchmark_mlperf.py" benchmark_global_opt="--time_limit_s=$time_limit_s --dataset_threadpool_size=48" -global_opt="--num_epochs=5" +global_opt="--num_epochs=5 --no_eval=True" experiment_dir="official_experiments/default_model_48_core_48_thread" experiment_prefix="${experiment_dir}/run_0/ssd_train" @@ -35,7 +35,6 @@ function step_0 {
--validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=1 \ --map_parse_parallelism=1 \ --map_tfrecord_decode_parallelism=1 \ @@ -77,7 +76,6 @@ function step_plumber { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=1 \ --map_parse_parallelism=1 \ --map_tfrecord_decode_parallelism=1 \ @@ -101,7 +99,6 @@ function step_plumber_fake { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=1 \ --map_parse_parallelism=1 \ --map_tfrecord_decode_parallelism=1 \ @@ -144,7 +141,6 @@ function step_autotune { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=-1 \ --map_parse_parallelism=-1 \ --map_tfrecord_decode_parallelism=-1 \ @@ -168,7 +164,6 @@ function step_autotune_benchmark { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=-1 \ --map_parse_parallelism=-1 \ --map_tfrecord_decode_parallelism=-1 \ @@ -192,7 +187,6 @@ function step_heuristic { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=48 \ --map_parse_parallelism=48 \ --map_tfrecord_decode_parallelism=48 \ @@ -201,6 +195,7 @@ function step_heuristic { --shard_parallelism=48 \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb + popd } function step_heuristic_benchmark { @@ -217,6 +212,7 @@ function step_heuristic_benchmark { --shard_parallelism=48 \ ${benchmark_global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb + popd } list_python_programs diff --git a/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/official_runners/run_96.sh b/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/official_runners/run_96.sh index 72886a3..b7b80d7 100755 --- a/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/official_runners/run_96.sh +++ b/end_to_end/TPU/ssd-research-JAX-tpu-v3-4096/official_runners/run_96.sh @@ -11,7 +11,7 @@ curr_dir="$(pwd)" script_name="$curr_dir/ssd_train.py" benchmark_script_name="$curr_dir/benchmark_mlperf.py" benchmark_global_opt="--time_limit_s=$time_limit_s --dataset_threadpool_size=96" -global_opt="--num_epochs=5" +global_opt="--num_epochs=5 --no_eval=True" # TODO(mkuchnik): Directory nesting not correct experiment_dir="official_experiments/default_model_96_core_96_thread" @@ -36,7 +36,6 @@ function step_0 { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=1 \ --map_parse_parallelism=1 \ --map_tfrecord_decode_parallelism=1 \ @@ -78,7 +77,6 @@ function step_plumber { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=1 \ --map_parse_parallelism=1 \ --map_tfrecord_decode_parallelism=1 \ @@ -102,7 +100,6 @@ function step_plumber_fake { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=1 \ --map_parse_parallelism=1 \ --map_tfrecord_decode_parallelism=1 \ @@ -145,7 +142,6 @@ function step_autotune { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=-1 \ --map_parse_parallelism=-1 \ --map_tfrecord_decode_parallelism=-1 \ @@ -169,7 
+165,6 @@ function step_autotune_benchmark { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=-1 \ --map_parse_parallelism=-1 \ --map_tfrecord_decode_parallelism=-1 \ @@ -193,7 +188,6 @@ function step_heuristic { --validation_file_pattern=$validation_path \ --detailed_time=True \ --precompile_eval=True \ - --no_eval=False \ --read_parallelism=96 \ --map_parse_parallelism=96 \ --map_tfrecord_decode_parallelism=96 \ @@ -202,6 +196,7 @@ function step_heuristic { --shard_parallelism=96 \ ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb + popd } function step_heuristic_benchmark { @@ -218,6 +213,7 @@ function step_heuristic_benchmark { --shard_parallelism=96 \ ${benchmark_global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb + popd } list_python_programs diff --git a/end_to_end/TPU/transformer-research-JAX-tpu-v3-8192/mlperf_input_pipeline.py b/end_to_end/TPU/transformer-research-JAX-tpu-v3-8192/mlperf_input_pipeline.py index 2643517..4063db3 100644 --- a/end_to_end/TPU/transformer-research-JAX-tpu-v3-8192/mlperf_input_pipeline.py +++ b/end_to_end/TPU/transformer-research-JAX-tpu-v3-8192/mlperf_input_pipeline.py @@ -19,9 +19,14 @@ import plumber_analysis.config import plumber_analysis.annotations +from plumber_analysis import pipeline_optimizer_wrapper as pipeline_optimizer +from plumber_analysis import pipeline_optimizer as _pipeline_optimizer plumber_analysis.config.enable_compat_logging() +_pipeline_optimizer.DEFAULT_BENCHMARK_TIME = 60 + 2 +pipeline_optimizer.BENCHMARK_TIME = 60 + 2 + # MLPerf Dataset Constants. # Packed WMT17 training data. MAX_TRAIN_LEN = 256 # multiple sequences are packed into this length. diff --git a/end_to_end/TPU/transformer-research-JAX-tpu-v3-8192/official_runners/run.sh b/end_to_end/TPU/transformer-research-JAX-tpu-v3-8192/official_runners/run.sh index 2687df3..8dd8b4e 100755 --- a/end_to_end/TPU/transformer-research-JAX-tpu-v3-8192/official_runners/run.sh +++ b/end_to_end/TPU/transformer-research-JAX-tpu-v3-8192/official_runners/run.sh @@ -80,9 +80,8 @@ function step_naive { --model_dir="my_model_dir" \ --precompile=${precompile} \ --input_pipeline_default_parallelism=1 \ - --input_pipeline_default_prefetching=1 \ - --num_epochs=6 \ - | tee ${name}_log.txt + --input_pipeline_default_prefetching=0 \ + ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd } @@ -119,7 +118,9 @@ function step_plumber { --vocab_path=${vocab_path} \ --model_dir="my_model_dir" \ --precompile=${precompile} \ - ${global_opt} | tee ${name}_log.txt + --input_pipeline_default_parallelism=1 \ + --input_pipeline_default_prefetching=0 \ + ${global_opt} 2>&1 | tee ${name}_log.txt cp stats.pb $name.pb popd } @@ -221,7 +222,7 @@ function step_heuristic { --eval_data_path=${eval_data_path} \ --vocab_path=${vocab_path} \ --model_dir="my_model_dir" \ - --precompile=True \ + --precompile=${precompile} \ --input_pipeline_default_parallelism=96 \ --input_pipeline_default_prefetching=1024 \ ${global_opt} 2>&1 | tee ${name}_log.txt diff --git a/plumber_analysis/src/plumber_analysis/annotations.py b/plumber_analysis/src/plumber_analysis/annotations.py index 297af6b..762dbad 100644 --- a/plumber_analysis/src/plumber_analysis/annotations.py +++ b/plumber_analysis/src/plumber_analysis/annotations.py @@ -196,7 +196,12 @@ def env_variable_optimize_check(): logging.info("Optimization not enabled! 
Passing through dataset.") return should_optimize def wrapped_kwargs_precondition_f(kwargs): - return env_variable_optimize_check() and kwargs_precondition_f(kwargs) + good_env = env_variable_optimize_check() + if kwargs_precondition_f: + good_precondition = kwargs_precondition_f(kwargs) + else: + good_precondition = True + return good_env and good_precondition return _optimize_pipeline( kwargs_precondition_f=wrapped_kwargs_precondition_f, diff --git a/plumber_analysis/src/plumber_analysis/pipeline_optimizer.py b/plumber_analysis/src/plumber_analysis/pipeline_optimizer.py index 322e4e0..840df7f 100644 --- a/plumber_analysis/src/plumber_analysis/pipeline_optimizer.py +++ b/plumber_analysis/src/plumber_analysis/pipeline_optimizer.py @@ -23,6 +23,7 @@ # TODO(mkuchnik): Move common constants to different file FRACTION_CACHEABLE_MEMORY = 0.9 +DEFAULT_BENCHMARK_TIME = 62 # 62 seconds def _instantiate_pipeline(graphdef, element_spec): placeholders = graphdef_util.find_placeholders(graphdef) @@ -836,7 +837,7 @@ def _update_plumber(self, benchmark_time_s, patch_caches=True): can be discarded if plumber file overwrites the graphdef. """ if benchmark_time_s is None: - benchmark_time_s = 22 + benchmark_time_s = DEFAULT_BENCHMARK_TIME def remove_file_if_exists(filename): if filename: diff --git a/plumber_analysis/src/plumber_analysis/pipeline_optimizer_wrapper.py b/plumber_analysis/src/plumber_analysis/pipeline_optimizer_wrapper.py index 912d170..cc2db0c 100644 --- a/plumber_analysis/src/plumber_analysis/pipeline_optimizer_wrapper.py +++ b/plumber_analysis/src/plumber_analysis/pipeline_optimizer_wrapper.py @@ -15,6 +15,7 @@ import pickle import pprint import time +import shutil import networkx as nx import numpy as np @@ -24,7 +25,10 @@ from plumber_analysis import pipeline_optimizer, gen_util, extensions, bandwidth_utilities import plumber_analysis.machine_info +DEFAULT_STATS_FILENAME = "stats.pb" +DEFAULT_BACKUP_STATS_FILENAME = "stats_original_backup.pb" PARAMS_FILENAME = "params.p" +BENCHMARK_TIME = 62 # 62 seconds def graph_wrapped_benchmark_dataset(dataset, time_limit_s): # NOTE(mkuchnik): Wrapping benchmark_dataset variants naively @@ -65,7 +69,8 @@ def create_optimizer(): # Instances can also be procured via # https://cloud.google.com/compute/docs/tutorials/python-guide - filename = "stats.pb" # TODO(mkuchnik): Don't hardcode + filename = DEFAULT_STATS_FILENAME # TODO(mkuchnik): Don't hardcode + shutil.copyfile(filename, DEFAULT_BACKUP_STATS_FILENAME) # Note(mkuchnik): we can instantiate a machine_info here like this: # machine_info = {'HOSTNAME': 'Localhost', 'CORES': 96, 'MEMORY': int(299e9), @@ -303,7 +308,7 @@ def get_optimized_pipeline(dataset, override_presets=True, logging.info("Running preliminary benchmark") dataset = apply_default_options(dataset, override_presets) - ret = _benchmark_dataset(dataset, time_limit_s=12) + ret = _benchmark_dataset(dataset, time_limit_s=BENCHMARK_TIME) logging.info("End preliminary benchmark") logging.info("benchmark {}".format(ret)) @@ -343,7 +348,7 @@ def get_optimized_pipeline(dataset, override_presets=True, def get_fake_pipeline(dataset, override_presets=True): dataset = apply_default_options(dataset, override_presets) - ret = _benchmark_dataset(dataset, time_limit_s=12) + ret = _benchmark_dataset(dataset, time_limit_s=BENCHMARK_TIME) logging.info("benchmark {}".format(ret)) optimizer = create_optimizer() dataset = optimizer.fake_dataset() @@ -351,7 +356,7 @@ def get_fake_pipeline(dataset, override_presets=True): def get_source_pipeline(dataset, 
override_presets=True): dataset = apply_default_options(dataset, override_presets) - ret = _benchmark_dataset(dataset, time_limit_s=12) + ret = _benchmark_dataset(dataset, time_limit_s=BENCHMARK_TIME) logging.info("benchmark {}".format(ret)) optimizer = create_optimizer() dataset = optimizer.source_dataset() @@ -419,7 +424,7 @@ def _benchmark_source_parallelisms(optimizer, override_presets=True, def benchmark_source_parallelisms(dataset, override_presets=True, sweep_range=None): """Sweeps a range of parallelism and report benchmark rates""" dataset = apply_default_options(dataset, override_presets, apply_tracing=False) - ret = _benchmark_dataset(dataset, time_limit_s=12) + ret = _benchmark_dataset(dataset, time_limit_s=BENCHMARK_TIME) logging.info("benchmark {}".format(ret)) optimizer = create_optimizer() rets, _ = _benchmark_source_parallelisms(
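
For reference, the end-to-end flow that the new `get_params.py` utilities assume is: trace the naive pipeline so that tf.data writes `stats.pb`, then resume optimization from that dump. Below is a minimal sketch of that flow, not the shipped code: it assumes the Plumber-patched TF build (which provides the `autotune_stats_filename` option set in `dataloader.py` above), and the toy range/map pipeline is a hypothetical stand-in for the real input pipelines, which configure these knobs through the flags shown in the runner scripts.

```python
import tensorflow as tf

from plumber_analysis import pipeline_optimizer_wrapper

# 1) Trace: attach the stats-dump option used throughout this patch
# (assumption: the Plumber-patched TF build; the toy pipeline below is
# a hypothetical stand-in for the real input pipelines).
dataset = tf.data.Dataset.range(10000).map(
    lambda x: x + 1, num_parallel_calls=1)
options = tf.data.Options()
options.experimental_optimization.autotune_stats_filename = "stats.pb"
dataset = dataset.with_options(options)
for _ in dataset:  # iterating the pipeline populates stats.pb
    pass

# 2) Resume: mirror get_params.py, which reads stats.pb from the
# current directory and prints Plumber's recommended parameters.
optimizer = pipeline_optimizer_wrapper.step_par_2(is_fast=False)
print("Plumber found parameters:\n{}".format(
    optimizer.get_performance_parameters()))
```

As the comments added to the runner scripts note, the same recommendation can be captured after tracing with `python3 get_params.py 2>&1 | tee params_log.txt`.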