mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export][training ir migration] fix batch norm pattern match in quantization (#134157)
Summary:
In the new training IR, we produce `torch.ops.aten.batch_norm.default` instead of `torch.ops.aten._native_batch_norm_legit.default` or `torch.ops.aten._native_batch_norm_legit_no_training.default`, so we need to change the pattern match to accommodate the new op.
- Add `torch.ops.aten.batch_norm.default` to the pattern matcher list so it is identified as a batch norm node.
- `torch.ops.aten.batch_norm.default` doesn't have a getitem user anymore, so when removing the batch norm node we need to do `bn_node.replace_all_uses_with(conv_node)` instead of `getitem_node.replace_all_uses_with(conv_node)`.

The behavior of capture_pre_autograd_graph is consistent for each run. If the run is an fbcode test, then capture_pre_autograd_graph uses the training IR. This means both _get_aten_graph_module_for_pattern and replace_pattern_with_filters see the same training IR. If the run is not an fbcode test, then both would see the old IR.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_conv2d_binary2
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_conv2d_unary
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_linear_unary
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_dynamic_quant_linear
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_qat_dynamic_quant_linear
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_flatten_recipe
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_linear_unary
```

Reviewed By: andrewor14, tugsbayasgalan
Differential Revision: D61291077
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134157
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
fee677eeb6
commit
978c5a80a0
@ -171,6 +171,7 @@ def _is_supported_batch_norm_for_training(node: Node):
|
||||
Return True if the given node refers to an aten batch norm op QAT supports.
|
||||
"""
|
||||
supported_ops = [
|
||||
torch.ops.aten.batch_norm.default,
|
||||
torch.ops.aten._native_batch_norm_legit.default,
|
||||
# Note: we won't need this op anymore after batch norm consolidation
|
||||
# For now, we need to continue to support it because it gives better
|
||||
@ -279,25 +280,35 @@ def fold_bn_weights_into_conv_node(
|
||||
# native_batch_norm has 3 outputs, we expect getitem calls on the output
|
||||
# and we want to replace the uses of getitem 0 with the output of conv
|
||||
#
|
||||
# Before:
|
||||
# conv -> bn - (first output) -> users1
|
||||
# \ - (second output) -> users2
|
||||
# \ - (third output) -> users3
|
||||
# After:
|
||||
# conv -> (first output) -> users1
|
||||
# bn -
|
||||
# \ - (second output) -> users2
|
||||
# \ - (third output) -> users3
|
||||
# if users2 and users3 are empty then bn will be removed through dead code elimination
|
||||
|
||||
for user in bn_node.users:
|
||||
if (
|
||||
user.op != "call_function"
|
||||
or user.target != operator.getitem
|
||||
or user.args[1] != 0
|
||||
):
|
||||
continue
|
||||
user.replace_all_uses_with(conv_node)
|
||||
if bn_node.target == torch.ops.aten.batch_norm.default:
|
||||
# With the new training ir, instead of batch_norm + getitem,
|
||||
# we only have the batch_norm node.
|
||||
#
|
||||
# Before:
|
||||
# conv -> bn -> users
|
||||
# After:
|
||||
# conv -> users
|
||||
# bn has no users now
|
||||
bn_node.replace_all_uses_with(conv_node)
|
||||
else:
|
||||
# Before:
|
||||
# conv -> bn - (first output) -> users1
|
||||
# \ - (second output) -> users2
|
||||
# \ - (third output) -> users3
|
||||
# After:
|
||||
# conv -> (first output) -> users1
|
||||
# bn -
|
||||
# \ - (second output) -> users2
|
||||
# \ - (third output) -> users3
|
||||
# if users2 and users3 are empty then bn will be removed through dead code elimination
|
||||
for user in bn_node.users:
|
||||
if (
|
||||
user.op != "call_function"
|
||||
or user.target != operator.getitem
|
||||
or user.args[1] != 0
|
||||
):
|
||||
continue
|
||||
user.replace_all_uses_with(conv_node)
|
||||
|
||||
# If the BN node does not have users, erase it from the graph
|
||||
# Note: we need to do this manually because the model can still be in train
|
||||
@ -315,9 +326,9 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
|
||||
if not has_bn:
|
||||
return
|
||||
for n in m.graph.nodes:
|
||||
if (
|
||||
n.op != "call_function"
|
||||
or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default
|
||||
if n.op != "call_function" or n.target not in (
|
||||
torch.ops.aten._native_batch_norm_legit_no_training.default,
|
||||
torch.ops.aten.batch_norm.default,
|
||||
):
|
||||
continue
|
||||
bn_node = n
|
||||
|
Reference in New Issue
Block a user