[torch][ao] Properly strip tracking stats in _fold_conv_bn_qat for 1D (#152982)

Summary: _fold_conv_bn_qat has logic to remove the tracking stats. Currently, this logic checks only for torch.nn.modules.batchnorm.BatchNorm2d. As a result, the tracking stats are not properly removed when 1D is used. This diff updates the check to fix this.

Test Plan:
Run N7113483 without this fix.

{F1977726982}

```
bento kernel build sensorml
```

Re-run with local version of kernel, containing this diff:

{F1977727151}

Notice that now, num_batches is removed.

Differential Revision: D74269649

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152982
Approved by: https://github.com/andrewor14, https://github.com/yushangdi
This commit is contained in:
Jake Stevens
2025-05-10 01:20:15 +00:00
committed by PyTorch MergeBot
parent 9c99ea2991
commit b86d46ff21
2 changed files with 39 additions and 2 deletions

View File

@ -844,6 +844,39 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
self.assertTrue(conv_node is not None)
self.assertTrue(bn_node is None)
def test_fold_bn_erases_add_node(self):
    """
    Verify that the in-place add_ node introduced for batch norm stat
    tracking is erased once batch norm folding runs during quantization.
    """
    model = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
    model = export_for_training(model, self.example_inputs, strict=True).module()

    def _contains_inplace_add(graph):
        # True iff any node in the graph targets aten.add_.Tensor.
        return any(
            node.target == torch.ops.aten.add_.Tensor for node in graph.nodes
        )

    # The freshly exported model must still carry the add_ node that
    # updates the batch norm running stats.
    self.assertTrue(
        _contains_inplace_add(model.graph),
        "Expected to find add_ tensor in the exported model",
    )

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(
        get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
    )
    model = prepare_qat_pt2e(model, quantizer)
    model = convert_pt2e(model)

    # After convert_pt2e folds batch norm, the stat-tracking add_ node
    # must have been removed from the graph.
    self.assertFalse(
        _contains_inplace_add(model.graph),
        "Expected add_ tensor to be removed in the quantized model",
    )
@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):

View File

@ -882,8 +882,12 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
node.target == torch.ops.aten.add_.Tensor
and node.args[0].op == "get_attr"
and node.args[1] == 1
and torch.nn.modules.batchnorm.BatchNorm2d
in [val[1] for val in node.meta["source_fn_stack"]]
and (
torch.nn.modules.batchnorm.BatchNorm2d
in [val[1] for val in node.meta["source_fn_stack"]]
or torch.nn.modules.batchnorm.BatchNorm1d
in [val[1] for val in node.meta["source_fn_stack"]]
)
):
m.graph.erase_node(node)