[quant][fx][graphmode][be] Change the type for output of convert to be torch.nn.Module (#69959)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69959

GraphModule is an implementation detail, so we don't want to expose it in the quantization APIs.
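
For context, a minimal usage sketch (it mirrors the new test added below and uses the prepare_fx/convert_fx API as it exists at this commit; it is illustrative, not part of the change): callers should rely only on the torch.nn.Module interface of the converted model.

import copy
import torch
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)

m = M().eval()
m = prepare_fx(m, {"": default_qconfig})  # insert observers
quantized = convert_fx(m)                 # annotated as torch.nn.Module after this change

# Only the nn.Module interface is relied on; whether the returned object is a
# GraphModule underneath stays an implementation detail.
assert isinstance(quantized, torch.nn.Module)
m_copy = copy.deepcopy(quantized)
state_dict = quantized.state_dict()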

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_quantized_model_type

Imported from OSS

Reviewed By: supriyar

Differential Revision: D33119103

fbshipit-source-id: d8736ff08b42ee009d6cfd74dcb3f6150f71f3d2
Author: Jerry Zhang
Date: 2021-12-29 20:31:48 -08:00
Committed by: Facebook GitHub Bot
Parent: fb78a31916
Commit: c627211651
6 changed files with 37 additions and 11 deletions


@@ -2034,6 +2034,33 @@ class TestQuantizeFx(QuantizationTestCase):
        # quantize, should run with no errors
        quantized = convert_fx(prepared_copy)

    def test_quantized_model_type(self):
        """ Test state_dict and deepcopy works properly in the quantized model
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        data = torch.rand(8, 5)
        m = M().eval()
        m = prepare_fx(m, {"": default_qconfig})
        m = convert_fx(m)
        # test deepcopy
        m_copy = copy.deepcopy(m)
        self.assertEqual(m_copy(data), m(data))
        # test state_dict
        state_dict = m.state_dict()
        m_new = M().eval()
        m_new = prepare_fx(m_new, {"": default_qconfig})
        m_new = convert_fx(m_new)
        m_new.load_state_dict(state_dict)
        self.assertEqual(m_new(data), m(data))

    def test_dequantize(self):
        r""" Test to make sure dequantize node are placed before
        non-quantizable node


@@ -1,17 +1,17 @@
from typing import Dict, Any, Optional
import torch
from torch.fx import GraphModule
from typing import Dict, Any, Optional
from .quantize_fx import (
    _check_is_graph_module,
    check_is_valid_convert_custom_config_dict
)
from .fx.graph_module import QuantizedGraphModule
from .fx._convert_do_not_use import _convert_do_not_use

def _convert_fx_do_not_use(
        graph_module: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Dict[str, Any] = None,
        _remove_qconfig: bool = True,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> QuantizedGraphModule:
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    Please do not use, this is a temporary function to migrate convert_fx
    to a new implementation


@@ -59,7 +59,7 @@ def _convert_do_not_use(
        convert_custom_config_dict: Dict[str, Any] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> QuantizedGraphModule:
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:


@@ -206,7 +206,7 @@ def remove_extra_dequantize(quantized: QuantizedGraphModule) -> QuantizedGraphMo
def restore_state(
        observed: GraphModule
        observed: torch.nn.Module
) -> Tuple[Dict[Pattern, QuantizeHandler],
           Dict[str, Tuple[str, type]],
           Dict[str, Any],
@@ -224,7 +224,7 @@ def convert(model: GraphModule, is_reference: bool = False,
            convert_custom_config_dict: Dict[str, Any] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True,
            convert_qconfig_dict: Dict[str, Any] = None) -> QuantizedGraphModule:
            convert_qconfig_dict: Dict[str, Any] = None) -> torch.nn.Module:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.


@@ -8,7 +8,7 @@ from torch.nn.intrinsic import _FusedModule
from .fx import Fuser # noqa: F401
from .fx import prepare, convert # noqa: F401
from .fx import get_tensorrt_backend_config_dict # noqa: F401
from .fx.graph_module import ObservedGraphModule, QuantizedGraphModule
from .fx.graph_module import ObservedGraphModule
from .fx.qconfig_utils import (
    check_is_valid_convert_custom_config_dict,
    check_is_valid_fuse_custom_config_dict,
@@ -570,7 +570,7 @@ def _convert_fx(
        is_standalone_module: bool = False,
        _remove_qconfig: bool = True,
        qconfig_dict: Dict[str, Any] = None,
) -> QuantizedGraphModule:
) -> torch.nn.Module:
    """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
    """
    if convert_custom_config_dict is None:
@@ -600,7 +600,7 @@ def convert_fx(
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        _remove_qconfig: bool = True,
        qconfig_dict: Dict[str, Any] = None,
) -> QuantizedGraphModule:
) -> torch.nn.Module:
    r""" Convert a calibrated or trained model to a quantized model

    Args:
@@ -694,7 +694,7 @@ def _convert_standalone_module_fx(
        graph_module: GraphModule,
        is_reference: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
) -> QuantizedGraphModule:
) -> torch.nn.Module:
    r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
    and convert it to a quantized model


@@ -26,5 +26,4 @@ from torch.ao.quantization.quantize_fx import (
from torch.ao.quantization.fx.graph_module import (
    ObservedGraphModule,
    QuantizedGraphModule
)