[export][training ir migration] fix batch norm pattern match in quantization (#134157)

Summary:
In the new training IR, we produce `torch.ops.aten.batch_norm.default` instead of `torch.ops.aten._native_batch_norm_legit.default` or `torch.ops.aten._native_batch_norm_legit_no_training.default`.

So we need to change the pattern match to accommodate the new op.

- Add `torch.ops.aten.batch_norm.default` to the pattern matcher list so it is identified as a batch norm node.
- `torch.ops.aten.batch_norm.default` no longer has a getitem user, so when removing the batch norm node we need to call `bn_node.replace_all_uses_with(conv_node)` instead of `getitem_node.replace_all_uses_with(conv_node)` (see the sketch below).
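
Below is a minimal sketch of the replacement logic described above; the helper name `_replace_bn_output_with_conv` is illustrative only, and the actual change lives in `fold_bn_weights_into_conv_node` (see the diff below):

```python
import operator

import torch
from torch.fx import Node


def _replace_bn_output_with_conv(bn_node: Node, conv_node: Node) -> None:
    # Sketch only: mirrors the branching added in fold_bn_weights_into_conv_node.
    if bn_node.target == torch.ops.aten.batch_norm.default:
        # New training IR: batch_norm returns a single tensor and has no
        # getitem users, so rewire its uses to the conv output directly.
        bn_node.replace_all_uses_with(conv_node)
    else:
        # Old IR: _native_batch_norm_legit(_no_training) returns a tuple, so
        # only redirect the users of getitem(bn, 0) to the conv output.
        for user in list(bn_node.users):
            if (
                user.op == "call_function"
                and user.target == operator.getitem
                and user.args[1] == 0
            ):
                user.replace_all_uses_with(conv_node)
```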

The behavior of capture_pre_autograd_graph is consistent within each run.

If the run is an fbcode test, capture_pre_autograd_graph uses the training IR, so both _get_aten_graph_module_for_pattern and replace_pattern_with_filters see the same training IR.

If the run is not an fbcode test, both see the old IR.
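
As context for the IR difference, here is a minimal sketch (not part of this PR) of how one could inspect which batch norm target a captured graph contains; the `ConvBN` module and the filtering expression are illustrative assumptions:

```python
import torch
from torch._export import capture_pre_autograd_graph


class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))


gm = capture_pre_autograd_graph(ConvBN(), (torch.randn(1, 3, 8, 8),))

# With the training IR this prints torch.ops.aten.batch_norm.default; with the
# old IR it prints one of the _native_batch_norm_legit* overloads instead.
print([
    n.target
    for n in gm.graph.nodes
    if n.op == "call_function" and "batch_norm" in str(n.target)
])
```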

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_conv2d_binary2
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_conv2d_unary
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_linear_unary
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_dynamic_quant_linear
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_qat_dynamic_quant_linear
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_flatten_recipe
```

Reviewed By: andrewor14, tugsbayasgalan

Differential Revision: D61291077

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134157
Approved by: https://github.com/tugsbayasgalan
Author: Shangdi Yu
Date: 2024-08-22 18:25:45 +00:00
Committed by: PyTorch MergeBot
Parent: fee677eeb6
Commit: 978c5a80a0

@@ -171,6 +171,7 @@ def _is_supported_batch_norm_for_training(node: Node):
     Return True if the given node refers to an aten batch norm op QAT supports.
     """
     supported_ops = [
+        torch.ops.aten.batch_norm.default,
         torch.ops.aten._native_batch_norm_legit.default,
         # Note: we won't need this op anymore after batch norm consolidation
         # For now, we need to continue to support it because it gives better
@@ -279,25 +280,35 @@ fold_bn_weights_into_conv_node(
     # native_batch_norm has 3 outputs, we expect getitem calls on the output
     # and we want to replace the uses of getitem 0 with the output of conv
     #
-    # Before:
-    # conv -> bn - (first output) -> users1
-    #        \ - (second output) -> users2
-    #        \ - (third output) -> users3
-    # After:
-    # conv -> (first output) -> users1
-    # bn -
-    #        \ - (second output) -> users2
-    #        \ - (third output) -> users3
-    # if users2 and users3 are empty then bn will be removed through dead code elimination
-    for user in bn_node.users:
-        if (
-            user.op != "call_function"
-            or user.target != operator.getitem
-            or user.args[1] != 0
-        ):
-            continue
-        user.replace_all_uses_with(conv_node)
+    if bn_node.target == torch.ops.aten.batch_norm.default:
+        # With the new training ir, instead of batch_norm + getitem,
+        # we only have the batch_norm node.
+        #
+        # Before:
+        # conv -> bn -> users
+        # After:
+        # conv -> users
+        # bn has no users now
+        bn_node.replace_all_uses_with(conv_node)
+    else:
+        # Before:
+        # conv -> bn - (first output) -> users1
+        #        \ - (second output) -> users2
+        #        \ - (third output) -> users3
+        # After:
+        # conv -> (first output) -> users1
+        # bn -
+        #        \ - (second output) -> users2
+        #        \ - (third output) -> users3
+        # if users2 and users3 are empty then bn will be removed through dead code elimination
+        for user in bn_node.users:
+            if (
+                user.op != "call_function"
+                or user.target != operator.getitem
+                or user.args[1] != 0
+            ):
+                continue
+            user.replace_all_uses_with(conv_node)

     # If the BN node does not have users, erase it from the graph
     # Note: we need to do this manually because the model can still be in train
@@ -315,9 +326,9 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
     if not has_bn:
         return
     for n in m.graph.nodes:
-        if (
-            n.op != "call_function"
-            or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default
+        if n.op != "call_function" or n.target not in (
+            torch.ops.aten._native_batch_norm_legit_no_training.default,
+            torch.ops.aten.batch_norm.default,
         ):
             continue
         bn_node = n