Change to export_for_training in XNNPACK tests (#137238)

Summary: as title

Test Plan: CI

Differential Revision: D63344674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137238
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Shangdi Yu
2024-10-03 21:28:05 +00:00
committed by PyTorch MergeBot
parent ce14f1f0c9
commit c83178d894
2 changed files with 24 additions and 27 deletions

View File

@ -4,7 +4,6 @@ import operator
import torch
import torch._dynamo as torchdynamo
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import (
default_dynamic_fake_quant,
@ -682,19 +681,17 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
}
capture_pre_autograd_graph_node_occurrence = None
if capture_pre_autograd_graph_using_training_ir():
capture_pre_autograd_graph_node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
# In training IR, the decomposition is different.
# `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes become
# `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes.
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
# note: quantize ops for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
}
training_ir_node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
# In training IR, the decomposition is different.
# `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes become
# `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes.
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
# note: quantize ops for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
}
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
dtype=torch.qint8,
qscheme=torch.per_tensor_affine,
@ -718,7 +715,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
[],
True,
qconfig_mapping,
capture_pre_autograd_graph_node_occurrence=capture_pre_autograd_graph_node_occurrence,
training_ir_node_occurrence=training_ir_node_occurrence,
)
def test_gru(self):