[quant][fx][graphmode][be] Change the type for output of convert to be torch.nn.Module (#69959)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69959

GraphModule is an implementation detail; we don't want to expose it in the quantization APIs.

Test Plan: python test/test_quantization.py TestQuantizeFx.test_quantized_model_type

Imported from OSS

Reviewed By: supriyar

Differential Revision: D33119103

fbshipit-source-id: d8736ff08b42ee009d6cfd74dcb3f6150f71f3d2
Committed by: Facebook GitHub Bot
Parent: fb78a31916
Commit: c627211651
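For illustration (not part of the commit): the round trip below mirrors the new test_quantized_model_type test from the diff and shows how the public API reads once convert_fx is annotated as returning torch.nn.Module. The model, data shapes, and qconfig are taken from the test; the isinstance check and torch.equal assertion are added for illustration.

    import copy
    import torch
    from torch.ao.quantization import default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x)

    data = torch.rand(8, 5)
    m = prepare_fx(M().eval(), {"": default_qconfig})  # insert observers
    m = convert_fx(m)  # annotated as torch.nn.Module, not QuantizedGraphModule

    # The quantized model still supports deepcopy and state_dict round trips,
    # which is what the new test verifies.
    m_copy = copy.deepcopy(m)
    assert torch.equal(m_copy(data), m(data))
    assert isinstance(m, torch.nn.Module)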
@@ -2034,6 +2034,33 @@ class TestQuantizeFx(QuantizationTestCase):
         # quantize, should run with no errors
         quantized = convert_fx(prepared_copy)
 
+    def test_quantized_model_type(self):
+        """ Test state_dict and deepcopy works properly in the quantized model
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        data = torch.rand(8, 5)
+        m = M().eval()
+        m = prepare_fx(m, {"": default_qconfig})
+        m = convert_fx(m)
+        # test deepcopy
+        m_copy = copy.deepcopy(m)
+        self.assertEqual(m_copy(data), m(data))
+
+        # test state_dict
+        state_dict = m.state_dict()
+        m_new = M().eval()
+        m_new = prepare_fx(m_new, {"": default_qconfig})
+        m_new = convert_fx(m_new)
+        m_new.load_state_dict(state_dict)
+        self.assertEqual(m_new(data), m(data))
+
     def test_dequantize(self):
         r""" Test to make sure dequantize node are placed before
         non-quantizable node
@@ -1,17 +1,17 @@
+from typing import Dict, Any, Optional
+import torch
 from torch.fx import GraphModule
-from typing import Dict, Any, Optional
 from .quantize_fx import (
     _check_is_graph_module,
     check_is_valid_convert_custom_config_dict
 )
-from .fx.graph_module import QuantizedGraphModule
 from .fx._convert_do_not_use import _convert_do_not_use
 
 def _convert_fx_do_not_use(
         graph_module: GraphModule, is_reference: bool = False,
         convert_custom_config_dict: Dict[str, Any] = None,
         _remove_qconfig: bool = True,
-        backend_config_dict: Optional[Dict[str, Any]] = None) -> QuantizedGraphModule:
+        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
     """
     Please do not use, this is a temporary function to migrate convert_fx
     to a new implementation
@@ -59,7 +59,7 @@ def _convert_do_not_use(
         convert_custom_config_dict: Dict[str, Any] = None,
         is_standalone_module: bool = False,
         _remove_qconfig_flag: bool = True,
-        backend_config_dict: Optional[Dict[str, Any]] = None) -> QuantizedGraphModule:
+        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
     """
     We will convert an observed model (a module with observer calls) to a reference
     quantized model, the rule is simple:
@@ -206,7 +206,7 @@ def remove_extra_dequantize(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
 
 
 def restore_state(
-        observed: GraphModule
+        observed: torch.nn.Module
 ) -> Tuple[Dict[Pattern, QuantizeHandler],
            Dict[str, Tuple[str, type]],
            Dict[str, Any],
@@ -224,7 +224,7 @@ def convert(model: GraphModule, is_reference: bool = False,
             convert_custom_config_dict: Dict[str, Any] = None,
             is_standalone_module: bool = False,
             _remove_qconfig_flag: bool = True,
-            convert_qconfig_dict: Dict[str, Any] = None) -> QuantizedGraphModule:
+            convert_qconfig_dict: Dict[str, Any] = None) -> torch.nn.Module:
     """ standalone_module means it a submodule that is not inlined in
     parent module, and will be quantized separately as one unit.
 
@@ -8,7 +8,7 @@ from torch.nn.intrinsic import _FusedModule
 from .fx import Fuser  # noqa: F401
 from .fx import prepare, convert  # noqa: F401
 from .fx import get_tensorrt_backend_config_dict  # noqa: F401
-from .fx.graph_module import ObservedGraphModule, QuantizedGraphModule
+from .fx.graph_module import ObservedGraphModule
 from .fx.qconfig_utils import (
     check_is_valid_convert_custom_config_dict,
     check_is_valid_fuse_custom_config_dict,
@@ -570,7 +570,7 @@ def _convert_fx(
         is_standalone_module: bool = False,
         _remove_qconfig: bool = True,
         qconfig_dict: Dict[str, Any] = None,
-) -> QuantizedGraphModule:
+) -> torch.nn.Module:
     """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
     """
     if convert_custom_config_dict is None:
@@ -600,7 +600,7 @@ def convert_fx(
         convert_custom_config_dict: Optional[Dict[str, Any]] = None,
         _remove_qconfig: bool = True,
         qconfig_dict: Dict[str, Any] = None,
-) -> QuantizedGraphModule:
+) -> torch.nn.Module:
     r""" Convert a calibrated or trained model to a quantized model
 
     Args:
@@ -694,7 +694,7 @@ def _convert_standalone_module_fx(
     graph_module: GraphModule,
     is_reference: bool = False,
     convert_custom_config_dict: Optional[Dict[str, Any]] = None,
-) -> QuantizedGraphModule:
+) -> torch.nn.Module:
     r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
     and convert it to a quantized model
 
@@ -26,5 +26,4 @@ from torch.ao.quantization.quantize_fx import (
 
 from torch.ao.quantization.fx.graph_module import (
     ObservedGraphModule,
-    QuantizedGraphModule
 )
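A note on downstream code (an illustration, not part of the commit): callers that previously annotated variables or helpers with the internal QuantizedGraphModule return type should now accept the base class, since torch.nn.Module is all the API guarantees after this change. The helper name below is hypothetical.

    import torch

    # Hypothetical downstream helper: annotate with the base class instead of
    # the internal QuantizedGraphModule, which is no longer exposed by the
    # convert_fx signature.
    def run_quantized(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        return model(x)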