[export] Temporarily bypass torch_fn in partitioner (#134292)

Summary:
"torch_fn" is not correct for the decomposed add node from batch norm. This is a temporary workaround to bypass torch fn.

For example, consider the graph below (from `test_qat_conv2d_unary`):
```
graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %bn_weight : [num_users=1] = get_attr[target=bn.weight]
    %bn_bias : [num_users=1] = get_attr[target=bn.bias]
    %bn_running_mean : [num_users=1] = get_attr[target=bn.running_mean]
    %bn_running_var : [num_users=1] = get_attr[target=bn.running_var]
    %bn_num_batches_tracked : [num_users=1] = get_attr[target=bn.num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %conv_weight, None, [1, 1], [1, 1]), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%bn_num_batches_tracked, 1), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %bn_weight, %bn_bias, %bn_running_mean, %bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%batch_norm,), kwargs = {})
    %max_pool2d : [num_users=1] = call_function[target=torch.ops.aten.max_pool2d.default](args = (%relu, [3, 3], [3, 3]), kwargs = {})
    return (max_pool2d,)
```

the `add_` node has `'torch_fn': ('add__1', 'method_descriptor.add_')` in its meta, which does not reflect its batch-norm origin.
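
Both provenance metas can be dumped side by side for comparison; a minimal sketch, assuming `gm` is the exported `GraphModule` for the graph above (the example "source_fn_stack" value is an assumption, not captured from this run):

```
for node in gm.graph.nodes:
    # For the add_ node, "torch_fn" is ('add__1', 'method_descriptor.add_'),
    # while "source_fn_stack", when present, points at the originating module,
    # e.g. [('bn', <class 'torch.nn.modules.batchnorm.BatchNorm2d'>)].
    print(node.name, node.meta.get("torch_fn"), node.meta.get("source_fn_stack"))
```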

If we then run the call below from `_annotate_qat_conv2d_bn_binary_unary`, we end up with a partition that has no output nodes.

```
find_sequential_partitions(
    gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU]
)
```

```
partition_list
[
SourcePartition(nodes=[conv_weight, conv2d], source=<class 'torch.nn.modules.conv.Conv2d'>, input_nodes=[x], output_nodes=[conv2d], params=[conv_weight]),

SourcePartition(nodes=[bn_weight, bn_bias, bn_running_mean, bn_running_var, bn_num_batches_tracked, add_, batch_norm], source=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, input_nodes=[conv2d], output_nodes=[batch_norm], params=[bn_num_batches_tracked, bn_running_var, bn_bias, bn_weight, bn_running_mean]),

SourcePartition(nodes=[add_], source='add_', input_nodes=[bn_num_batches_tracked], output_nodes=[], params=[])
]
```
We should not have the last partition: it is built from the stray `add_` node alone and has an empty `output_nodes` list.
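
A degenerate partition like this can be flagged mechanically, since it is the only one whose `output_nodes` list is empty. A hypothetical sanity check (the helper is ours for illustration, not part of the partitioner):

```
from torch.fx.passes.utils.source_matcher_utils import SourcePartition

def is_degenerate(partition: SourcePartition) -> bool:
    # A partition that exposes no outputs can never take part in a
    # sequential match such as Conv2d -> BatchNorm2d -> add -> ReLU.
    return len(partition.output_nodes) == 0

# [p for p in partition_list if is_degenerate(p)] -> [the stray add_ partition]
```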

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_qat_conv2d
```

Differential Revision: D61569049

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134292
Approved by: https://github.com/angelayi
Author: Shangdi Yu
Date: 2024-08-24 05:50:16 +00:00
Committed by: PyTorch MergeBot
Parent: f260cc2edf
Commit: 0694918aeb

2 changed files with 26 additions and 1 deletion

In `TestSourceMatcher`, the same change is applied at four call sites:

```
@@ -253,6 +253,11 @@ class TestSourceMatcher(JitTestCase):
         gm = torch.export.export(M(), inputs, strict=strict).module()
         gm.graph.eliminate_dead_code()
 
+        # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
+        # TODO: remove this after we fix "torch_fn". T199561090
+        for node in gm.graph.nodes:
+            node.meta["source_fn_stack"] = None
+
         module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])
         self.assertEqual(len(module_partitions), 2)
@@ -310,6 +315,11 @@ class TestSourceMatcher(JitTestCase):
         ).module()
         gm.graph.eliminate_dead_code()
 
+        # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
+        # TODO: remove this after we fix "torch_fn". T199561090
+        for node in gm.graph.nodes:
+            node.meta["source_fn_stack"] = None
+
         module_partitions = get_source_partitions(
             gm.graph, ["conv2d", "relu", "max_pool2d"]
         )
@@ -390,6 +400,11 @@ class TestSourceMatcher(JitTestCase):
         gm = torch.export.export(M(), inputs, strict=strict).module()
         gm.graph.eliminate_dead_code()
 
+        # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
+        # TODO: remove this after we fix "torch_fn". T199561090
+        for node in gm.graph.nodes:
+            node.meta["source_fn_stack"] = None
+
         module_partitions = get_source_partitions(gm.graph, ["conv2d"])
         self.assertEqual(len(module_partitions), 1)
@@ -417,6 +432,11 @@ class TestSourceMatcher(JitTestCase):
         gm = torch.export.export(M(), inputs, strict=strict).module()
         gm.graph.eliminate_dead_code()
 
+        # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
+        # TODO: remove this after we fix "torch_fn". T199561090
+        for node in gm.graph.nodes:
+            node.meta["source_fn_stack"] = None
+
         module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])
         self.assertEqual(len(module_partitions), 2)
```
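
In all four tests the effect is the same: with "source_fn_stack" cleared, `get_source_partitions` can only take its "torch_fn" branch, which is exactly the code path the partitioner change below guards.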

In `get_source_partitions`, the partitioner-side change:

```
@@ -73,7 +73,12 @@ def get_source_partitions(
         # function, or the type of module if the node is decomposed from a leaf
         # module
-        if (torch_fn := node.meta.get("torch_fn", None)) is not None:
+        # TODO: Bypass "torch_fn" when "source_fn_stack" is present, because
+        # "torch_fn" can currently differ from "source_fn_stack", for example
+        # for the add_ node decomposed from batch norm. We should remove the
+        # check on "source_fn_stack" after we fix "torch_fn". T199561090
+        if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and
+                (torch_fn := node.meta.get("torch_fn", None)) is not None):
             node_fqn, source_fn = torch_fn
             source_fn_name = source_fn.split(".")[1]
             if source_fn_name in wanted_sources:
```
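
Pulled out of the diff, the new gating amounts to the following standalone sketch (the function name and return shape are ours for illustration; the rest of `get_source_partitions`, not shown here, presumably handles the "source_fn_stack" case): "source_fn_stack" wins whenever it is present, and "torch_fn" is consulted only as a fallback.

```
def pick_provenance(node_meta: dict):
    # Prefer "source_fn_stack": its entries name the module or function the
    # node was decomposed from, e.g. ('batch_norm', <class 'BatchNorm2d'>).
    if (source_fn_st := node_meta.get("source_fn_stack", None)) is not None:
        return ("source_fn_stack", source_fn_st[-1])
    # Fall back to "torch_fn", e.g. ('add__1', 'method_descriptor.add_').
    if (torch_fn := node_meta.get("torch_fn", None)) is not None:
        return ("torch_fn", torch_fn)
    return None
```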