Migrate to training IR in quantization_pt2e_qat unit tests (#137232)

Summary: Replace capture_pre_autograd_graph with export_for_training in the quantization_pt2e_qat unit tests, and remove the branches that only existed to support the old non-training-IR capture path.
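
A minimal sketch of the call-site change made throughout these tests (the module M and example_inputs below are illustrative, not taken from the test suite): export_for_training returns an ExportedProgram, so each call site adds .module() to recover the GraphModule that prepare_qat_pt2e expects.

```
import torch
from torch.export import export_for_training

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

example_inputs = (torch.randn(1, 3),)

# Before (deprecated):
#   from torch._export import capture_pre_autograd_graph
#   m = capture_pre_autograd_graph(M(), example_inputs)

# After: export_for_training returns an ExportedProgram; .module() yields the
# GraphModule that prepare_qat_pt2e / convert_pt2e operate on.
m = export_for_training(M(), example_inputs).module()
```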

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:quantization_pt2e_qat
```

Reviewed By: tugsbayasgalan

Differential Revision: D63336660

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137232
Approved by: https://github.com/angelayi
Author: Shangdi Yu
Date: 2024-10-03 22:57:04 +00:00
Committed by: PyTorch MergeBot
Parent: b44f25e1ba
Commit: 4096ed7dc2


@@ -5,8 +5,6 @@ import unittest
from typing import Any, Optional, Tuple, Type
import torch
from torch._export import capture_pre_autograd_graph
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
from torch.ao.quantization import (
default_fake_quant,
FusedMovingAvgObsFakeQuantize,
@@ -36,6 +34,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@@ -140,10 +139,10 @@ class PT2EQATTestCase(QuantizationTestCase):
is_per_channel=is_per_channel, is_qat=True
)
)
model_pt2e = capture_pre_autograd_graph(
model_pt2e = export_for_training(
model_pt2e,
example_inputs,
)
).module()
model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
torch.manual_seed(MANUAL_SEED)
after_prepare_result_pt2e = model_pt2e(*example_inputs)
@@ -230,10 +229,10 @@ class PT2EQATTestCase(QuantizationTestCase):
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel, is_qat=True)
)
m = capture_pre_autograd_graph(
m = export_for_training(
m,
example_inputs,
)
).module()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@@ -253,34 +252,15 @@ class PT2EQATTestCase(QuantizationTestCase):
# Verify: getitem(bn, 0) or relu(getitem(bn, 0))
if has_relu:
relu_node = output_fq_node.args[0]
getitem_node = relu_node.args[0]
bn_node = relu_node.args[0]
self.assertEqual(relu_node.target, torch.ops.aten.relu.default)
else:
relu_node = None
getitem_node = output_fq_node.args[0]
bn_node = output_fq_node.args[0]
is_training_ir_flag = capture_pre_autograd_graph_using_training_ir()
if is_training_ir_flag:
# The relu node takes in the output of bn.
# See NOTE [training ir has no getitem for bn node].
bn_node = getitem_node
self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)
else:
# TODO: This branch is going through a deprecated branch and should be deleted soon,
# after capture_pre_autograd_graph fully migrate to training IR
# T199018392
self.assertEqual(getitem_node.target, operator.getitem)
bn_node = getitem_node.args[0]
expected_bn_op = None
if is_cuda:
if torch.version.cuda is not None:
expected_bn_op = torch.ops.aten.cudnn_batch_norm.default
elif torch.version.hip is not None:
expected_bn_op = torch.ops.aten.miopen_batch_norm.default
else:
expected_bn_op = torch.ops.aten._native_batch_norm_legit.default
self.assertEqual(bn_node.target, expected_bn_op)
# With the training IR, relu (if present) takes the output of bn directly.
# See NOTE [training ir has no getitem for bn node].
self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)
# Verify: conv / scale_factor.reshape [+ bias.reshape]
if has_bias:
@@ -366,12 +346,8 @@ class PT2EQATTestCase(QuantizationTestCase):
bn_running_var_add_node = sqrt_node.args[0]
(bn_running_var_node, eps) = bn_running_var_add_node.args
self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
if is_training_ir_flag:
self.assertTrue("bn.weight" in bn_weight_node.target)
self.assertTrue("bn.running_var" in bn_running_var_node.target)
else:
self.assertTrue("bn_weight" in bn_weight_node.target)
self.assertTrue("bn_running_var" in bn_running_var_node.target)
self.assertTrue("bn.weight" in bn_weight_node.target)
self.assertTrue("bn.running_var" in bn_running_var_node.target)
self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
self.assertEqual(eps, 1e-5)
@@ -603,7 +579,10 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
the `unrelated_getitem` node, which is not part of the conv-bn pattern but
is returned as part of the match anyway (as a placeholder).
"""
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
# T199018392
# Remove this test once capture_pre_autograd_graph() is removed.
if capture_pre_autograd_graph_using_training_ir():
self.skipTest("Not applicable to training IR")
@@ -646,7 +625,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
# Program capture
m = M(self.conv_class, self.bn_class)
m = capture_pre_autograd_graph(m, self.example_inputs)
m = torch._export.capture_pre_autograd_graph(m, self.example_inputs)
m.graph.eliminate_dead_code()
m.recompile()
(_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)
@@ -720,7 +699,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
m = M(self.conv_class, self.bn_class, backbone)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
m = capture_pre_autograd_graph(m, example_inputs)
m = export_for_training(m, example_inputs).module()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
@@ -778,7 +757,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_conv_bn_bias_derived_qspec(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = capture_pre_autograd_graph(m, example_inputs)
m = export_for_training(m, example_inputs).module()
quantizer = ConvBnDerivedBiasQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@@ -825,7 +804,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_per_channel_weight_custom_dtype(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = capture_pre_autograd_graph(m, example_inputs)
m = export_for_training(m, example_inputs).module()
quantizer = ConvBnInt32WeightQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@@ -879,7 +858,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_conv_bn_per_channel_weight_bias(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = capture_pre_autograd_graph(m, example_inputs)
m = export_for_training(m, example_inputs).module()
quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@@ -936,7 +915,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
it into conv in `convert_pt2e` even in train mode.
"""
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
m = capture_pre_autograd_graph(m, self.example_inputs)
m = export_for_training(m, self.example_inputs).module()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
@@ -953,6 +932,9 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
TODO: Remove this test after training IR migration.
T199018392
"""
from torch._export import capture_pre_autograd_graph
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
if capture_pre_autograd_graph_using_training_ir():
self.skipTest(
"test doesn't apply when capture_pre_autograd_graph is using training IR"
@@ -1061,21 +1043,12 @@ class ConvBnInt32WeightQuantizer(Quantizer):
},
_annotated=True,
)
if getitem_node is not None:
# TODO: This branch is going through a deprecated branch and should be deleted soon,
# after capture_pre_autograd_graph fully migrate to training IR
# T199018392
getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
else:
# See NOTE [training ir has no getitem for bn node].
assert capture_pre_autograd_graph_using_training_ir()
bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
# See NOTE [training ir has no getitem for bn node].
bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
return model
def validate(self, model: torch.fx.GraphModule):
@@ -1148,25 +1121,16 @@ class ConvBnDerivedBiasQuantizer(Quantizer):
_annotated=True,
)
if getitem_node is not None:
# TODO: This branch is going through a deprecated branch and should be deleted soon,
# after capture_pre_autograd_graph fully migrate to training IR
# T199018392
getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
else:
# NOTE [training ir has no getitem for bn node].
# getitem is None when we use the training IR. It outputs
# aten.batch_norm.default, which do not need any getitem node.
# In this case, we need to annotate on the batch norm node.
# geteitem node should only be None if we are using training IR.
assert capture_pre_autograd_graph_using_training_ir()
bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
# NOTE [training ir has no getitem for bn node].
# getitem is None when we use the training IR: the graph contains
# aten.batch_norm.default, which does not need a getitem node on its output.
# In this case, we annotate the batch norm node directly.
# The getitem node should only be None when we are using the training IR.
bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
return model
def validate(self, model: torch.fx.GraphModule):
@@ -1232,7 +1196,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
in_channels = child.linear1.weight.size(1)
example_input = (torch.rand((1, in_channels)),)
traced_child = capture_pre_autograd_graph(child, example_input)
traced_child = export_for_training(child, example_input).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(
is_per_channel=True, is_qat=True
@@ -1263,10 +1227,10 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
self._convert_qat_linears(model)
quant_result_pt2e = model(*example_inputs)
model_pt2e = capture_pre_autograd_graph(
model_pt2e = export_for_training(
model,
example_inputs,
)
).module()
quantizer = XNNPACKQuantizer()
quantizer.set_module_type(torch.nn.Linear, None)