Migrate to training ir in quantization_pt2e_qat unittests (#137232)
Summary: Change capture_pre_autograd_graph to export_for_training in unit tests.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:quantization_pt2e_qat
```

Reviewed By: tugsbayasgalan

Differential Revision: D63336660

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137232
Approved by: https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: b44f25e1ba
Commit: 4096ed7dc2
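In short: the tests previously captured models with the deprecated `torch._export.capture_pre_autograd_graph`, which returned a graph module directly. `torch.export.export_for_training` instead returns an `ExportedProgram`, so each migrated call site appends `.module()` to recover the `GraphModule` that the PT2E quantization APIs expect. A minimal sketch of the pattern (the `ConvBn` toy module and input shape below are illustrative, not taken from the diff):

```python
import torch
from torch.export import export_for_training


class ConvBn(torch.nn.Module):
    # Illustrative stand-in for the conv-bn models used in these tests.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


example_inputs = (torch.randn(1, 3, 16, 16),)

# Before (deprecated): capture_pre_autograd_graph returned a GraphModule.
#   m = capture_pre_autograd_graph(ConvBn(), example_inputs)
# After: export_for_training returns an ExportedProgram, so call .module()
# to get the GraphModule that prepare_qat_pt2e / prepare_pt2e expect.
m = export_for_training(ConvBn(), example_inputs).module()
```

The full diff follows.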
```diff
@@ -5,8 +5,6 @@ import unittest
 from typing import Any, Optional, Tuple, Type
 
 import torch
-from torch._export import capture_pre_autograd_graph
-from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
 from torch.ao.quantization import (
     default_fake_quant,
     FusedMovingAvgObsFakeQuantize,
@@ -36,6 +34,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.export import export_for_training
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
@@ -140,10 +139,10 @@ class PT2EQATTestCase(QuantizationTestCase):
                 is_per_channel=is_per_channel, is_qat=True
             )
         )
-        model_pt2e = capture_pre_autograd_graph(
+        model_pt2e = export_for_training(
             model_pt2e,
             example_inputs,
-        )
+        ).module()
         model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
         torch.manual_seed(MANUAL_SEED)
         after_prepare_result_pt2e = model_pt2e(*example_inputs)
@@ -230,10 +229,10 @@ class PT2EQATTestCase(QuantizationTestCase):
         quantizer.set_global(
             get_symmetric_quantization_config(is_per_channel, is_qat=True)
         )
-        m = capture_pre_autograd_graph(
+        m = export_for_training(
             m,
             example_inputs,
-        )
+        ).module()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
 
@@ -253,34 +252,15 @@ class PT2EQATTestCase(QuantizationTestCase):
         # Verify: getitem(bn, 0) or relu(getitem(bn, 0))
         if has_relu:
             relu_node = output_fq_node.args[0]
-            getitem_node = relu_node.args[0]
+            bn_node = relu_node.args[0]
             self.assertEqual(relu_node.target, torch.ops.aten.relu.default)
         else:
             relu_node = None
-            getitem_node = output_fq_node.args[0]
+            bn_node = output_fq_node.args[0]
 
-        is_training_ir_flag = capture_pre_autograd_graph_using_training_ir()
-        if is_training_ir_flag:
-            # The relu node takes in the output of bn.
-            # See NOTE [training ir has no getitem for bn node].
-            bn_node = getitem_node
-            self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)
-        else:
-            # TODO: This branch is going through a deprecated branch and should be deleted soon,
-            # after capture_pre_autograd_graph fully migrate to training IR
-            # T199018392
-            self.assertEqual(getitem_node.target, operator.getitem)
-            bn_node = getitem_node.args[0]
-
-        expected_bn_op = None
-        if is_cuda:
-            if torch.version.cuda is not None:
-                expected_bn_op = torch.ops.aten.cudnn_batch_norm.default
-            elif torch.version.hip is not None:
-                expected_bn_op = torch.ops.aten.miopen_batch_norm.default
-        else:
-            expected_bn_op = torch.ops.aten._native_batch_norm_legit.default
-        self.assertEqual(bn_node.target, expected_bn_op)
+        # The relu node takes in the output of bn.
+        # See NOTE [training ir has no getitem for bn node].
+        self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)
 
         # Verify: conv / scale_factor.reshape [+ bias.reshape]
         if has_bias:
@@ -366,12 +346,8 @@ class PT2EQATTestCase(QuantizationTestCase):
         bn_running_var_add_node = sqrt_node.args[0]
         (bn_running_var_node, eps) = bn_running_var_add_node.args
         self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
-        if is_training_ir_flag:
-            self.assertTrue("bn.weight" in bn_weight_node.target)
-            self.assertTrue("bn.running_var" in bn_running_var_node.target)
-        else:
-            self.assertTrue("bn_weight" in bn_weight_node.target)
-            self.assertTrue("bn_running_var" in bn_running_var_node.target)
+        self.assertTrue("bn.weight" in bn_weight_node.target)
+        self.assertTrue("bn.running_var" in bn_running_var_node.target)
         self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
         self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
         self.assertEqual(eps, 1e-5)
@@ -603,7 +579,10 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         the `unrelated_getitem` node, which is not part of the conv-bn pattern but
         is returned as part of the match anyway (as a placeholder).
         """
+        from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
 
+        # T199018392
+        # remove this test after we kill capture_pre_autograd_graph()
         if capture_pre_autograd_graph_using_training_ir():
             self.skipTest("Not applicable to training IR")
 
@@ -646,7 +625,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
 
         # Program capture
         m = M(self.conv_class, self.bn_class)
-        m = capture_pre_autograd_graph(m, self.example_inputs)
+        m = torch._export.capture_pre_autograd_graph(m, self.example_inputs)
         m.graph.eliminate_dead_code()
         m.recompile()
         (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)
@@ -720,7 +699,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         m = M(self.conv_class, self.bn_class, backbone)
         quantizer = XNNPACKQuantizer()
         quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
-        m = capture_pre_autograd_graph(m, example_inputs)
+        m = export_for_training(m, example_inputs).module()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
         m = convert_pt2e(m)
@@ -778,7 +757,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     def test_qat_conv_bn_bias_derived_qspec(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = capture_pre_autograd_graph(m, example_inputs)
+        m = export_for_training(m, example_inputs).module()
         quantizer = ConvBnDerivedBiasQuantizer()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -825,7 +804,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     def test_qat_per_channel_weight_custom_dtype(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = capture_pre_autograd_graph(m, example_inputs)
+        m = export_for_training(m, example_inputs).module()
         quantizer = ConvBnInt32WeightQuantizer()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -879,7 +858,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     def test_qat_conv_bn_per_channel_weight_bias(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = capture_pre_autograd_graph(m, example_inputs)
+        m = export_for_training(m, example_inputs).module()
         quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -936,7 +915,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         it into conv in `convert_pt2e` even in train mode.
         """
         m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
-        m = capture_pre_autograd_graph(m, self.example_inputs)
+        m = export_for_training(m, self.example_inputs).module()
         quantizer = XNNPACKQuantizer()
         quantizer.set_global(
             get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
@@ -953,6 +932,9 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         TODO: Remove this test after training IR migration.
         T199018392
         """
+        from torch._export import capture_pre_autograd_graph
+        from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
+
         if capture_pre_autograd_graph_using_training_ir():
             self.skipTest(
                 "test doesn't apply when capture_pre_autograd_graph is using training IR"
@@ -1061,21 +1043,12 @@ class ConvBnInt32WeightQuantizer(Quantizer):
             },
             _annotated=True,
         )
-        if getitem_node is not None:
-            # TODO: This branch is going through a deprecated branch and should be deleted soon,
-            # after capture_pre_autograd_graph fully migrate to training IR
-            # T199018392
-            getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                output_qspec=act_qspec,
-                _annotated=True,
-            )
-        else:
-            # See NOTE [training ir has no getitem for bn node].
-            assert capture_pre_autograd_graph_using_training_ir()
-            bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                output_qspec=act_qspec,
-                _annotated=True,
-            )
 
+        # See NOTE [training ir has no getitem for bn node].
+        bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=act_qspec,
+            _annotated=True,
+        )
         return model
 
     def validate(self, model: torch.fx.GraphModule):
@@ -1148,25 +1121,16 @@ class ConvBnDerivedBiasQuantizer(Quantizer):
             _annotated=True,
         )
 
-        if getitem_node is not None:
-            # TODO: This branch is going through a deprecated branch and should be deleted soon,
-            # after capture_pre_autograd_graph fully migrate to training IR
-            # T199018392
-            getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                output_qspec=act_qspec,
-                _annotated=True,
-            )
-        else:
-            # NOTE [training ir has no getitem for bn node].
-            # getitem is None when we use the training IR. It outputs
-            # aten.batch_norm.default, which do not need any getitem node.
-            # In this case, we need to annotate on the batch norm node.
-            # geteitem node should only be None if we are using training IR.
-            assert capture_pre_autograd_graph_using_training_ir()
-            bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                output_qspec=act_qspec,
-                _annotated=True,
-            )
+        # NOTE [training ir has no getitem for bn node].
+        # getitem is None when we use the training IR. It outputs
+        # aten.batch_norm.default, which do not need any getitem node.
+        # In this case, we need to annotate on the batch norm node.
+        # geteitem node should only be None if we are using training IR.
+
+        bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=act_qspec,
+            _annotated=True,
+        )
         return model
 
     def validate(self, model: torch.fx.GraphModule):
@@ -1232,7 +1196,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
                 in_channels = child.linear1.weight.size(1)
 
                 example_input = (torch.rand((1, in_channels)),)
-                traced_child = capture_pre_autograd_graph(child, example_input)
+                traced_child = export_for_training(child, example_input).module()
                 quantizer = XNNPACKQuantizer()
                 quantization_config = get_symmetric_quantization_config(
                     is_per_channel=True, is_qat=True
@@ -1263,10 +1227,10 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
         self._convert_qat_linears(model)
         quant_result_pt2e = model(*example_inputs)
 
-        model_pt2e = capture_pre_autograd_graph(
+        model_pt2e = export_for_training(
             model,
             example_inputs,
-        )
+        ).module()
 
         quantizer = XNNPACKQuantizer()
         quantizer.set_module_type(torch.nn.Linear, None)
```
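For context, the migrated call sites all follow the same PT2E QAT flow: capture with `export_for_training`, prepare with `prepare_qat_pt2e`, run the model, then `convert_pt2e`. A hedged end-to-end sketch (the `ConvBnReLU` module, input shape, and quantizer settings are illustrative; the real tests parametrize conv/bn classes and configs):

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export_for_training


class ConvBnReLU(torch.nn.Module):
    # Illustrative model; not one of the test fixtures.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return torch.nn.functional.relu(self.bn(self.conv(x)))


example_inputs = (torch.randn(1, 3, 16, 16),)

# 1. Capture with the training IR (bn is preserved as aten.batch_norm).
m = export_for_training(ConvBnReLU(), example_inputs).module()

# 2. Annotate the graph and insert fake-quantize ops for QAT.
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
)
m = prepare_qat_pt2e(m, quantizer)

# 3. Train / calibrate (a single forward stands in for a training loop here).
m(*example_inputs)

# 4. Convert: folds bn into conv and lowers to quantize/dequantize ops.
m = convert_pt2e(m)
```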