diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py
index 53ad10309f43e6..e0fed78b4e61d6 100644
--- a/torch/fx/passes/splitter_base.py
+++ b/torch/fx/passes/splitter_base.py
@@ -41,7 +41,8 @@ def __init__(
         self,
         min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
         skip_fusion=DEFAULT_SKIP_FUSION,
-        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR
+        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
+        max_acc_splits: int = -1,
     ):
         parser = argparse.ArgumentParser()
         parser.add_argument(
@@ -51,6 +52,13 @@ def __init__(
             type=int,
             help="Minimum size limit of an accelerator subgraph.",
         )
+        parser.add_argument(
+            "--max-acc-splits",
+            "--max_acc_splits",
+            required=False,
+            type=int,
+            help="Enforce a maximum number of split subgraphs.",
+        )
         parser.add_argument(
             "--skip-fusion",
             "--skip_fusion",
@@ -78,6 +86,7 @@ def __init__(
         self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
         self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
         self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
+        self.max_acc_splits: int = args.max_acc_splits if args.max_acc_splits else max_acc_splits
 
 
 @compatibility(is_backward_compatible=False)
@@ -876,5 +885,12 @@ def generate_split_results(self) -> SplitResult:
         submodule_names = []
         for name, mod in split_module.named_children():
             submodule_names.append(name)
+        # Reject the split when it has more child submodules (every named child is counted) than max_acc_splits allows; a value <= 0 disables the limit.
+        if (
+            self.settings.max_acc_splits > 0
+            and len(submodule_names) > self.settings.max_acc_splits
+        ):
+            raise ValueError("Cannot fulfill max_acc_splits limit. This may cause split fragmentation and result in performance issues.")
+
         submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
         return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)