Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add max_acc_splits (facebookincubator#1017)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#1017

X-link: pytorch/pytorch#133041

Model owners can set the lower_settings with max_acc_splits=2, and lowering will fail during model iteration, to alert them of possible performance degradation from increased fragmentation.

Reviewed By: frank-wei

Differential Revision: D60133589
qxy11 authored and facebook-github-bot committed Aug 13, 2024
1 parent 2aef297 commit 96d2f06
Showing 4 changed files with 61 additions and 0 deletions.
2 changes: 2 additions & 0 deletions fx2ait/fx2ait/ait_splitter.py
Original file line number Diff line number Diff line change
@@ -115,12 +115,14 @@ def __init__(
min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
allow_int_inputs=False,
debug_operator_range=None,
max_acc_splits=-1,
):
super().__init__()
self.min_acc_module_size = min_acc_module_size
self.exclude_support_node_name: set = set()
self.allow_int_inputs: bool = allow_int_inputs
self.debug_operator_range = debug_operator_range
self.max_acc_splits = max_acc_splits


class SelectedOperatorSupport(ops.OperatorSupportBase):
1 change: 1 addition & 0 deletions fx2ait/fx2ait/lower/lower.py
Original file line number Diff line number Diff line change
@@ -100,6 +100,7 @@ def default_split_function(
settings = AITSplitterSettings(
min_acc_module_size=lower_settings.min_acc_module_size,
allow_int_inputs=lower_settings.allow_int_inputs,
max_acc_splits=lower_settings.max_acc_splits,
)
splitter = AITSplitter(model, inputs, settings=settings)
splitter.node_support_preview()
3 changes: 3 additions & 0 deletions fx2ait/fx2ait/lower/lower_settings.py
Original file line number Diff line number Diff line change
@@ -61,6 +61,9 @@ class LowerSettings:

max_batch_size: int = 2048
min_acc_module_size: int = 10
# Maximum number of splits for lowered module
# (eg. if lowered module is split into _run_on_gpu_0(unlowered submodule) and _run_on_acc_1(lowered submodule) it has 2 splits)
max_acc_splits: int = -1
workdir: str = ""
name: str = ""
dll_name: str = "ait_engine.so"
55 changes: 55 additions & 0 deletions fx2ait/fx2ait/test/test_ait_splitter.py
Original file line number Diff line number Diff line change
@@ -250,3 +250,58 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
dict(split_results_relu_allowed.split_module.named_children()).keys(),
{"_run_on_acc_0"},
)

def test_fail_if_exceed_max_acc_split_limit(self):
class TestModule(torch.nn.Module):
def forward(self, a):
b = torch.sin(a)
c = torch.relu(b)
d = torch.cos(c)
e = torch.sigmoid(d)
f = torch.tanh(e)
return f

# Support all ops
_support_dict = {
"acc_ops.sin": None,
"acc_ops.cos": None,
"acc_ops.relu": None,
"acc_ops.sigmoid": None,
"acc_ops.tanh": None,
}
custom_op_support = op_support.OperatorSupport(_support_dict)

# With no ops excluded, the entire module should be lowered
# into one acc graph
mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)])
settings = AITSplitterSettings(min_acc_module_size=0, max_acc_splits=1)
splitter = AITSplitter(
mod,
(torch.randn(2, 3),),
custom_op_support,
settings,
)

res_all_nodes_supported = splitter.generate_split_results()
split_named_mods = dict(res_all_nodes_supported.split_module.named_children())
self.assertEqual(len(split_named_mods), 1)
self.assertIn("_run_on_acc_0", split_named_mods)

# Add "relu" to exclude_support_node_name
# The graph should be split into 3 parts now(_run_on_acc_0, _run_on_gpu_1, _run_on_acc_2)
mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)])
for node in mod.graph.nodes:
if node.target == acc_ops.relu:
settings.exclude_support_node_name.add(node.name)
splitter = AITSplitter(
mod,
(torch.randn(2, 3),),
custom_op_support,
settings,
)
# Split should fail now
with self.assertRaisesRegex(
ValueError,
"Cannot fulfill max_acc_splits limit. This may cause split fragmentation and result in performance issues.",
):
splitter.generate_split_results()

0 comments on commit 96d2f06

Please sign in to comment.