From 0694918aeb115bae98c6ab19af69d7d259fed09e Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Sat, 24 Aug 2024 05:50:16 +0000
Subject: [PATCH] [export] Temporarily bypass torch_fn in partitioner (#134292)

Summary: "torch_fn" is not correct for the add_ node decomposed from batch norm. This is a temporary workaround that bypasses "torch_fn" whenever "source_fn_stack" is available.

For example, take the graph below (the test_qat_conv2d_unary graph):

```
graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %bn_weight : [num_users=1] = get_attr[target=bn.weight]
    %bn_bias : [num_users=1] = get_attr[target=bn.bias]
    %bn_running_mean : [num_users=1] = get_attr[target=bn.running_mean]
    %bn_running_var : [num_users=1] = get_attr[target=bn.running_var]
    %bn_num_batches_tracked : [num_users=1] = get_attr[target=bn.num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %conv_weight, None, [1, 1], [1, 1]), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%bn_num_batches_tracked, 1), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %bn_weight, %bn_bias, %bn_running_mean, %bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%batch_norm,), kwargs = {})
    %max_pool2d : [num_users=1] = call_function[target=torch.ops.aten.max_pool2d.default](args = (%relu, [3, 3], [3, 3]), kwargs = {})
    return (max_pool2d,)
```

The add_ node has `'torch_fn': ('add__1', 'method_descriptor.add_')` in its meta. If we run the call below in `_annotate_qat_conv2d_bn_binary_unary`, we get a partition without output nodes:

```
find_sequential_partitions(
    gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU]
)
```

```
partition_list [
    SourcePartition(nodes=[conv_weight, conv2d], source=<class 'torch.nn.modules.conv.Conv2d'>, input_nodes=[x], output_nodes=[conv2d], params=[conv_weight]),
    SourcePartition(nodes=[bn_weight, bn_bias, bn_running_mean, bn_running_var, bn_num_batches_tracked, add_, batch_norm], source=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, input_nodes=[conv2d], output_nodes=[batch_norm], params=[bn_num_batches_tracked, bn_running_var, bn_bias, bn_weight, bn_running_mean]),
    SourcePartition(nodes=[add_], source='add_', input_nodes=[bn_num_batches_tracked], output_nodes=[], params=[])
]
```

The last partition should not exist: it has no output nodes.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_qat_conv2d
```

Differential Revision: D61569049

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134292
Approved by: https://github.com/angelayi
---
 test/fx/test_source_matcher_utils.py          | 20 +++++++++++++++++++
 torch/fx/passes/utils/source_matcher_utils.py |  7 ++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/test/fx/test_source_matcher_utils.py b/test/fx/test_source_matcher_utils.py
index 27120e0d99f8..e0882c75b634 100644
--- a/test/fx/test_source_matcher_utils.py
+++ b/test/fx/test_source_matcher_utils.py
@@ -253,6 +253,11 @@ class TestSourceMatcher(JitTestCase):
         gm = torch.export.export(M(), inputs, strict=strict).module()
         gm.graph.eliminate_dead_code()
 
+        # Remove "source_fn_stack" meta to let the partitioner use "torch_fn" only.
+        # TODO: remove this after we fix "torch_fn". T199561090
+        for node in gm.graph.nodes:
+            node.meta["source_fn_stack"] = None
+
         module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])
 
         self.assertEqual(len(module_partitions), 2)
@@ -310,6 +315,11 @@ class TestSourceMatcher(JitTestCase):
         ).module()
         gm.graph.eliminate_dead_code()
 
+        # Remove "source_fn_stack" meta to let the partitioner use "torch_fn" only.
+        # TODO: remove this after we fix "torch_fn". T199561090
+        for node in gm.graph.nodes:
+            node.meta["source_fn_stack"] = None
+
         module_partitions = get_source_partitions(
             gm.graph, ["conv2d", "relu", "max_pool2d"]
         )
@@ -390,6 +400,11 @@ class TestSourceMatcher(JitTestCase):
         gm = torch.export.export(M(), inputs, strict=strict).module()
         gm.graph.eliminate_dead_code()
 
+        # Remove "source_fn_stack" meta to let the partitioner use "torch_fn" only.
+        # TODO: remove this after we fix "torch_fn". T199561090
+        for node in gm.graph.nodes:
+            node.meta["source_fn_stack"] = None
+
         module_partitions = get_source_partitions(gm.graph, ["conv2d"])
 
         self.assertEqual(len(module_partitions), 1)
@@ -417,6 +432,11 @@ class TestSourceMatcher(JitTestCase):
         gm = torch.export.export(M(), inputs, strict=strict).module()
         gm.graph.eliminate_dead_code()
 
+        # Remove "source_fn_stack" meta to let the partitioner use "torch_fn" only.
+        # TODO: remove this after we fix "torch_fn". T199561090
+        for node in gm.graph.nodes:
+            node.meta["source_fn_stack"] = None
+
         module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])
 
         self.assertEqual(len(module_partitions), 2)
diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py
index ca3dacf47b6e..0a4f072644cd 100644
--- a/torch/fx/passes/utils/source_matcher_utils.py
+++ b/torch/fx/passes/utils/source_matcher_utils.py
@@ -73,7 +73,12 @@ def get_source_partitions(
         # function, or the type of module if the node is decomposed from a leaf
         # module
 
-        if (torch_fn := node.meta.get("torch_fn", None)) is not None:
+        # TODO: Bypass "torch_fn" when "source_fn_stack" is available, because
+        # "torch_fn" can differ from "source_fn_stack", e.g. for the add_ node
+        # decomposed from batch norm. We should remove the check on
+        # "source_fn_stack" after we fix "torch_fn". T199561090
+        if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and
+            (torch_fn := node.meta.get("torch_fn", None)) is not None):
             node_fqn, source_fn = torch_fn
             source_fn_name = source_fn.split(".")[1]
             if source_fn_name in wanted_sources:
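
For context (not part of the patch), here is a minimal, self-contained sketch of the pattern the test changes rely on: export a small model, drop the "source_fn_stack" meta so `get_source_partitions` falls back to "torch_fn", and check that every partition has output nodes. The module definition and input shapes are made up for illustration, and the model is exported in eval mode, so this mirrors the test setup rather than reproducing the QAT batch-norm add_ itself:

```python
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class M(torch.nn.Module):
    # Hypothetical module; shapes and channel counts are arbitrary.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return torch.nn.functional.relu(self.bn(self.conv(x)))


inputs = (torch.randn(1, 3, 8, 8),)
gm = torch.export.export(M().eval(), inputs).module()
gm.graph.eliminate_dead_code()

# Mirror the test changes above: drop "source_fn_stack" so that
# get_source_partitions falls back to "torch_fn" for every node.
for node in gm.graph.nodes:
    node.meta["source_fn_stack"] = None

module_partitions = get_source_partitions(gm.graph, ["conv2d", "relu"])
for source, partitions in module_partitions.items():
    for p in partitions:
        # A well-formed partition always has output nodes; the bug this
        # patch works around produced one with output_nodes == [].
        print(source, p.output_nodes)
```

With the new guard, nodes that still carry "source_fn_stack" are partitioned from that meta as before; only nodes without it fall back to "torch_fn", which keeps the bypass narrow until "torch_fn" is fixed.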