Change to export_for_training in XNNPACK tests (#137238)

Summary: Switch the XNNPACK/PT2E quantization test helpers from capture_pre_autograd_graph to torch.export.export_for_training(...).module(), and rename the capture_pre_autograd_graph_node_occurrence parameter to training_ir_node_occurrence.
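
For reference, the call-site pattern changed throughout the test helpers is sketched below; m and example_inputs stand in for whatever module and inputs a given test uses:

    # Before: trace with the pre-autograd capture API
    m = capture_pre_autograd_graph(
        m,
        example_inputs,
        dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
    )

    # After: export a training-IR ExportedProgram, then unwrap the GraphModule
    m = export_for_training(
        m,
        example_inputs,
        dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
    ).module()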

Test Plan: CI

Differential Revision: D63344674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137238
Approved by: https://github.com/tugsbayasgalan
Author: Shangdi Yu
Date: 2024-10-03 21:28:05 +00:00
Committed by: PyTorch MergeBot
Parent: ce14f1f0c9
Commit: c83178d894
2 changed files with 24 additions and 27 deletions


@@ -14,7 +14,7 @@ from torch.ao.nn.intrinsic import _FusedModule
import torch.distributed as dist
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM
-from torch._export import capture_pre_autograd_graph
+from torch.export import export_for_training
from torch.ao.quantization import (
QuantType,
default_dynamic_qat_qconfig,
@@ -1247,7 +1247,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
export_with_dynamic_shape=False,
is_qat=False,
is_debug_mode=False,
-capture_pre_autograd_graph_node_occurrence=None,
+training_ir_node_occurrence=None,
):
# resetting dynamo cache
torch._dynamo.reset()
@@ -1259,11 +1259,11 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
{0: torch.export.Dim("dim")} if i == 0 else None
for i in range(len(example_inputs))
)
-m = capture_pre_autograd_graph(
+m = export_for_training(
m,
example_inputs,
dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
-)
+).module()
if is_qat:
m = prepare_qat_pt2e(m, quantizer)
@@ -1297,18 +1297,18 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
m_fx = _convert_to_reference_decomposed_fx(
m_fx, backend_config=backend_config
)
-m_fx = capture_pre_autograd_graph(
+m_fx = export_for_training(
m_fx,
example_inputs,
dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
-)
+).module()
node_occurrence = {}
for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
if k in expected_node_occurrence:
node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
-if capture_pre_autograd_graph_node_occurrence is not None:
+if training_ir_node_occurrence is not None:
node_occurrence = {
-ns.call_function(k): v for k, v in capture_pre_autograd_graph_node_occurrence.items()
+ns.call_function(k): v for k, v in training_ir_node_occurrence.items()
}
self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
fx_quant_output = m_fx(*example_inputs)
@@ -1319,10 +1319,10 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
# resetting dynamo cache
torch._dynamo.reset()
-m = capture_pre_autograd_graph(
+m = export_for_training(
m,
example_inputs,
-)
+).module()
if is_qat:
m = prepare_qat_pt2e(m, quantizer)
else:
@@ -2953,10 +2953,10 @@ def _generate_qdq_quantized_model(
maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
with maybe_no_grad:
-export_model = capture_pre_autograd_graph(
+export_model = export_for_training(
mod,
inputs,
-)
+).module()
quantizer = (
quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic)
)
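
For context, the flow these updated helpers exercise looks roughly like the sketch below. It is illustrative only: the toy module, input shapes, and the XNNPACKQuantizer configuration are assumptions, not part of this PR.

    import torch
    from torch.export import export_for_training
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):  # toy model, for illustration only
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(2, 8),)

    # Export a training-IR graph and unwrap it, as the updated tests now do
    m = export_for_training(M(), example_inputs).module()

    # Prepare with the XNNPACK quantizer, calibrate, then convert
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibration pass
    m = convert_pt2e(m)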