mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torch][ao] Properly strip tracking stats in _fold_conv_bn_qat for 1D (#152982)
Summary: _fold_conv_bn_qat has logic to remove the tracking stats. Currently, this includes a check that includes only torch.nn.modules.batchnorm.BatchNorm2d. As a result, the tracking stats are not properly removed when 1D is used. This diff updates to fix this. Test Plan: Run N7113483 without this fix. {F1977726982} ``` bento kernel build sensorml ``` Re-run with local version of kernel, containing this diff: {F1977727151} Notice that now, num_batches is removed. Differential Revision: D74269649 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152982 Approved by: https://github.com/andrewor14, https://github.com/yushangdi
This commit is contained in:
committed by
PyTorch MergeBot
parent
9c99ea2991
commit
b86d46ff21
@ -844,6 +844,39 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
|
||||
self.assertTrue(conv_node is not None)
|
||||
self.assertTrue(bn_node is None)
|
||||
|
||||
def test_fold_bn_erases_add_node(self):
|
||||
"""
|
||||
Test that batch norm stat tracking (which results in an add_ tensor) is removed when folding batch norm.
|
||||
"""
|
||||
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
|
||||
m = export_for_training(m, self.example_inputs, strict=True).module()
|
||||
|
||||
def _has_add_(graph):
|
||||
for node in graph.nodes:
|
||||
if node.target == torch.ops.aten.add_.Tensor:
|
||||
return True
|
||||
return False
|
||||
|
||||
# Verify that add_ tensor exists in the exported model (for tracking batch norm stats)
|
||||
has_add_tensor_before = _has_add_(m.graph)
|
||||
self.assertTrue(
|
||||
has_add_tensor_before, "Expected to find add_ tensor in the exported model"
|
||||
)
|
||||
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantizer.set_global(
|
||||
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
|
||||
)
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
m = convert_pt2e(m)
|
||||
|
||||
# Verify that add_ tensor is removed in the quantized model
|
||||
has_add_tensor_after = _has_add_(m.graph)
|
||||
self.assertFalse(
|
||||
has_add_tensor_after,
|
||||
"Expected add_ tensor to be removed in the quantized model",
|
||||
)
|
||||
|
||||
|
||||
@skipIfNoQNNPACK
|
||||
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
|
||||
|
@ -882,8 +882,12 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
|
||||
node.target == torch.ops.aten.add_.Tensor
|
||||
and node.args[0].op == "get_attr"
|
||||
and node.args[1] == 1
|
||||
and torch.nn.modules.batchnorm.BatchNorm2d
|
||||
in [val[1] for val in node.meta["source_fn_stack"]]
|
||||
and (
|
||||
torch.nn.modules.batchnorm.BatchNorm2d
|
||||
in [val[1] for val in node.meta["source_fn_stack"]]
|
||||
or torch.nn.modules.batchnorm.BatchNorm1d
|
||||
in [val[1] for val in node.meta["source_fn_stack"]]
|
||||
)
|
||||
):
|
||||
m.graph.erase_node(node)
|
||||
|
||||
|
Reference in New Issue
Block a user