mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Change to export_for_training in XNNPACK tests (#137238)
Summary: as title Test Plan: CI Differential Revision: D63344674 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137238 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
ce14f1f0c9
commit
c83178d894
@ -4,7 +4,6 @@ import operator
|
||||
|
||||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
|
||||
from torch.ao.ns.fx.utils import compute_sqnr
|
||||
from torch.ao.quantization import (
|
||||
default_dynamic_fake_quant,
|
||||
@ -682,19 +681,17 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
|
||||
}
|
||||
|
||||
capture_pre_autograd_graph_node_occurrence = None
|
||||
if capture_pre_autograd_graph_using_training_ir():
|
||||
capture_pre_autograd_graph_node_occurrence = {
|
||||
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
|
||||
# In training IR, the decomposition is different.
|
||||
# `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes
|
||||
# `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes.
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
|
||||
# note: quantize op for weights are const propagated
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
|
||||
}
|
||||
training_ir_node_occurrence = {
|
||||
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
|
||||
# In training IR, the decomposition is different.
|
||||
# `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes
|
||||
# `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes.
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
|
||||
# note: quantize op for weights are const propagated
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
|
||||
}
|
||||
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
|
||||
dtype=torch.qint8,
|
||||
qscheme=torch.per_tensor_affine,
|
||||
@ -718,7 +715,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
|
||||
[],
|
||||
True,
|
||||
qconfig_mapping,
|
||||
capture_pre_autograd_graph_node_occurrence=capture_pre_autograd_graph_node_occurrence,
|
||||
training_ir_node_occurrence=training_ir_node_occurrence,
|
||||
)
|
||||
|
||||
def test_gru(self):
|
||||
|
Reference in New Issue
Block a user