mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Change to export_for_training in XNNPACK tests (#137238)
Summary: as title Test Plan: CI Differential Revision: D63344674 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137238 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
ce14f1f0c9
commit
c83178d894
@ -14,7 +14,7 @@ from torch.ao.nn.intrinsic import _FusedModule
|
||||
import torch.distributed as dist
|
||||
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM
|
||||
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
from torch.export import export_for_training
|
||||
from torch.ao.quantization import (
|
||||
QuantType,
|
||||
default_dynamic_qat_qconfig,
|
||||
@ -1247,7 +1247,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
|
||||
export_with_dynamic_shape=False,
|
||||
is_qat=False,
|
||||
is_debug_mode=False,
|
||||
capture_pre_autograd_graph_node_occurrence=None,
|
||||
training_ir_node_occurrence=None,
|
||||
):
|
||||
# resetting dynamo cache
|
||||
torch._dynamo.reset()
|
||||
@ -1259,11 +1259,11 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
|
||||
{0: torch.export.Dim("dim")} if i == 0 else None
|
||||
for i in range(len(example_inputs))
|
||||
)
|
||||
m = capture_pre_autograd_graph(
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
|
||||
)
|
||||
).module()
|
||||
|
||||
if is_qat:
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
@ -1297,18 +1297,18 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
|
||||
m_fx = _convert_to_reference_decomposed_fx(
|
||||
m_fx, backend_config=backend_config
|
||||
)
|
||||
m_fx = capture_pre_autograd_graph(
|
||||
m_fx = export_for_training(
|
||||
m_fx,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
|
||||
)
|
||||
).module()
|
||||
node_occurrence = {}
|
||||
for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
|
||||
if k in expected_node_occurrence:
|
||||
node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
|
||||
if capture_pre_autograd_graph_node_occurrence is not None:
|
||||
if training_ir_node_occurrence is not None:
|
||||
node_occurrence = {
|
||||
ns.call_function(k): v for k, v in capture_pre_autograd_graph_node_occurrence.items()
|
||||
ns.call_function(k): v for k, v in training_ir_node_occurrence.items()
|
||||
}
|
||||
self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
|
||||
fx_quant_output = m_fx(*example_inputs)
|
||||
@ -1319,10 +1319,10 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
|
||||
# resetting dynamo cache
|
||||
torch._dynamo.reset()
|
||||
|
||||
m = capture_pre_autograd_graph(
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
)
|
||||
).module()
|
||||
if is_qat:
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
else:
|
||||
@ -2953,10 +2953,10 @@ def _generate_qdq_quantized_model(
|
||||
|
||||
maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
|
||||
with maybe_no_grad:
|
||||
export_model = capture_pre_autograd_graph(
|
||||
export_model = export_for_training(
|
||||
mod,
|
||||
inputs,
|
||||
)
|
||||
).module()
|
||||
quantizer = (
|
||||
quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic)
|
||||
)
|
||||
|
Reference in New Issue
Block a user