mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] Temporarily bypass torch_fn in partitioner (#134292)
Summary: "torch_fn" is not correct for the decomposed add node from batch norm. This is a temporary workaround to bypass torch fn. For example, for the graph below (test_qat_conv2d_unary graph): ``` graph(): %conv_weight : [num_users=1] = get_attr[target=conv.weight] %bn_weight : [num_users=1] = get_attr[target=bn.weight] %bn_bias : [num_users=1] = get_attr[target=bn.bias] %bn_running_mean : [num_users=1] = get_attr[target=bn.running_mean] %bn_running_var : [num_users=1] = get_attr[target=bn.running_var] %bn_num_batches_tracked : [num_users=1] = get_attr[target=bn.num_batches_tracked] %x : [num_users=1] = placeholder[target=x] %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %conv_weight, None, [1, 1], [1, 1]), kwargs = {}) %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%bn_num_batches_tracked, 1), kwargs = {}) %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %bn_weight, %bn_bias, %bn_running_mean, %bn_running_var, True, 0.1, 1e-05, True), kwargs = {}) %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%batch_norm,), kwargs = {}) %max_pool2d : [num_users=1] = call_function[target=torch.ops.aten.max_pool2d.default](args = (%relu, [3, 3], [3, 3]), kwargs = {}) return (max_pool2d,) ``` the add_ node has `'torch_fn': ('add__1', 'method_descriptor.add_'),` in its meta. If we run the line below in `_annotate_qat_conv2d_bn_binary_unary`, we'll have a partition without output nodes. ``` find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] ) ```` ``` partition_list [ SourcePartition(nodes=[conv_weight, conv2d], source=<class 'torch.nn.modules.conv.Conv2d'>, input_nodes=[x], output_nodes=[conv2d], params=[conv_weight]), SourcePartition(nodes=[bn_weight, bn_bias, bn_running_mean, bn_running_var, bn_num_batches_tracked, add_, batch_norm], source=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, input_nodes=[conv2d], output_nodes=[batch_norm], params=[bn_num_batches_tracked, bn_running_var, bn_bias, bn_weight, bn_running_mean]), SourcePartition(nodes=[add_], source='add_', input_nodes=[bn_num_batches_tracked], output_nodes=[], params=[]) ] ``` We should not have the last partition. Test Plan: ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_qat_conv2d ``` Differential Revision: D61569049 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134292 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
f260cc2edf
commit
0694918aeb
@ -253,6 +253,11 @@ class TestSourceMatcher(JitTestCase):
|
||||
gm = torch.export.export(M(), inputs, strict=strict).module()
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
# Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
|
||||
# TODO: remove this after we fix "torch_fn". T199561090
|
||||
for node in gm.graph.nodes:
|
||||
node.meta["source_fn_stack"] = None
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])
|
||||
|
||||
self.assertEqual(len(module_partitions), 2)
|
||||
@ -310,6 +315,11 @@ class TestSourceMatcher(JitTestCase):
|
||||
).module()
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
# Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
|
||||
# TODO: remove this after we fix "torch_fn". T199561090
|
||||
for node in gm.graph.nodes:
|
||||
node.meta["source_fn_stack"] = None
|
||||
|
||||
module_partitions = get_source_partitions(
|
||||
gm.graph, ["conv2d", "relu", "max_pool2d"]
|
||||
)
|
||||
@ -390,6 +400,11 @@ class TestSourceMatcher(JitTestCase):
|
||||
gm = torch.export.export(M(), inputs, strict=strict).module()
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
# Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
|
||||
# TODO: remove this after we fix "torch_fn". T199561090
|
||||
for node in gm.graph.nodes:
|
||||
node.meta["source_fn_stack"] = None
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, ["conv2d"])
|
||||
|
||||
self.assertEqual(len(module_partitions), 1)
|
||||
@ -417,6 +432,11 @@ class TestSourceMatcher(JitTestCase):
|
||||
gm = torch.export.export(M(), inputs, strict=strict).module()
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
# Remove "source_fn_stack" meta to let partitioner use "torch_fn" only.
|
||||
# TODO: remove this after we fix "torch_fn". T199561090
|
||||
for node in gm.graph.nodes:
|
||||
node.meta["source_fn_stack"] = None
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])
|
||||
|
||||
self.assertEqual(len(module_partitions), 2)
|
||||
|
@ -73,7 +73,12 @@ def get_source_partitions(
|
||||
# function, or the type of module if the node is decomposed from a leaf
|
||||
# module
|
||||
|
||||
if (torch_fn := node.meta.get("torch_fn", None)) is not None:
|
||||
# TODO: Bypass "torch_fn" when "source_fn_stack" because now "torch_fn" can
|
||||
# be different from "source_fn_stack", for example for the add_ node
|
||||
# decomposed from batch norm. We should remove the check on "source_fn_stack"
|
||||
# after we fix "torch_fn". T199561090
|
||||
if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and
|
||||
(torch_fn := node.meta.get("torch_fn", None)) is not None):
|
||||
node_fqn, source_fn = torch_fn
|
||||
source_fn_name = source_fn.split(".")[1]
|
||||
if source_fn_name in wanted_sources:
|
||||
|
Reference in New Issue
Block a user