[quant][graphmode][fx] Support preserving attributes in deepcopy of observed/quantized graphmodule (#56550)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56550

Add support for preserving a user-specified list of attributes on observed/quantized GraphModules: attributes listed under the "preserved_attributes" key of the custom config dicts are kept on the traced module even when the forward function does not use them, and they now also persist through deepcopy.
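
For illustration, a minimal usage sketch (the toy module and attribute name below are made up, not taken from this PR):

```python
import copy
import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.preserved_attr = "metadata"  # never used in forward

    def forward(self, x):
        return self.linear(x)

prepared = prepare_fx(
    M().eval(), {"": default_qconfig},
    prepare_custom_config_dict={"preserved_attributes": ["preserved_attr"]})
prepared(torch.randn(2, 4))  # calibrate
# Before this change the attribute survived tracing but was lost on deepcopy;
# with it, the copy keeps the attribute as well:
assert copy.deepcopy(prepared).preserved_attr == "metadata"
```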

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_deepcopy_preserve_attributes
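
For reference, a sketch of the converted-model half of that check (assuming convert_custom_config_dict accepts the same "preserved_attributes" key as the prepare/fuse config dicts; this is not the verbatim test code):

```python
import copy
from torch.quantization.quantize_fx import convert_fx

# `prepared` is an observed GraphModule built as in the snippet above
quantized = convert_fx(
    prepared,
    convert_custom_config_dict={"preserved_attributes": ["preserved_attr"]})
assert copy.deepcopy(quantized).preserved_attr == "metadata"
```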

Imported from OSS

Reviewed By: vkuzo, kazhang

Differential Revision: D27899317

fbshipit-source-id: ebf21334715e5ab764aaa27eed534cc0cdf9f2b5
Author: Jerry Zhang
Date: 2021-04-22 15:01:07 -07:00
Committed by: Facebook GitHub Bot
Parent: 3a44d269ac
Commit: 1719cb82f3

5 changed files with 95 additions and 32 deletions

torch/quantization/quantize_fx.py

@@ -8,7 +8,7 @@ from .fx.utils import graph_pretty_str  # noqa: F401
 from .fx.utils import get_custom_module_class_keys  # noqa: F401
 from .fx.graph_module import ObservedGraphModule, QuantizedGraphModule
 from torch.nn.intrinsic import _FusedModule
-from typing import Dict, Any, List, Callable, Tuple, Optional
+from typing import Dict, Any, List, Callable, Tuple, Optional, Set

 def _check_is_graph_module(model: torch.nn.Module) -> None:
     if not isinstance(model, GraphModule):
@@ -167,9 +167,13 @@ forward graph of the parent module,
         float_custom_module_classes = get_custom_module_class_keys(
             prepare_custom_config_dict, "float_to_observed_custom_module_class")
         skipped_module_classes += float_custom_module_classes
+
+    preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", [])
     tracer = QuantizationTracer(
         skipped_module_names, skipped_module_classes)
     graph_module = GraphModule(model, tracer.trace(model))
+    for attr_name in preserved_attributes:
+        setattr(graph_module, attr_name, getattr(model, attr_name))
     graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
     quantizer = Quantizer()
     prepared = quantizer.prepare(
@@ -179,7 +183,6 @@ forward graph of the parent module,
         prepare_custom_config_dict=prepare_custom_config_dict,
         is_standalone_module=is_standalone_module)
-    preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", [])
     for attr_name in preserved_attributes:
         setattr(prepared, attr_name, getattr(model, attr_name))
     return prepared
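
Why the attribute copying happens right after tracing: constructing a GraphModule keeps only the attributes that the traced forward actually references, so anything else is silently dropped. A self-contained demonstration (the toy module and attribute name are illustrative):

```python
import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.note = "metadata"  # never touched in forward

    def forward(self, x):
        return self.linear(x)

m = M()
gm = torch.fx.symbolic_trace(m)
print(hasattr(gm, "note"))   # False: unused attributes are dropped by tracing
setattr(gm, "note", m.note)  # mirrors what _prepare_fx does for preserved_attributes
print(hasattr(gm, "note"))   # True
```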
@@ -221,6 +224,12 @@ def fuse_fx(model: torch.nn.Module,
          "additional_fuser_method_mapping": {
            (Module1, Module2): fuse_module1_module2
          }
+
+         # Attributes that are not used in forward function will
+         # be removed when constructing GraphModule, this is a list of attributes
+         # to preserve as an attribute of the GraphModule even when they are
+         # not used in the code, these attributes will also persist through deepcopy
+         "preserved_attributes": ["preserved_attr"],
       }

     Example:
@@ -233,6 +242,11 @@ def fuse_fx(model: torch.nn.Module,
     torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
     assert not model.training, 'fuse_fx only works on models in eval mode'
     graph_module = torch.fx.symbolic_trace(model)
+    preserved_attributes: Set[str] = set()
+    if fuse_custom_config_dict:
+        preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
+    for attr_name in preserved_attributes:
+        setattr(graph_module, attr_name, getattr(model, attr_name))
     return _fuse_fx(graph_module, fuse_custom_config_dict)

 def prepare_fx(
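
A quick illustration of the fuse_fx path above (the toy module is made up; conv-bn fusion is just a convenient example):

```python
import torch
from torch.quantization.quantize_fx import fuse_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)
        self.preserved_attr = "metadata"  # unused in forward

    def forward(self, x):
        return self.bn(self.conv(x))

fused = fuse_fx(
    M().eval(),
    fuse_custom_config_dict={"preserved_attributes": ["preserved_attr"]})
assert fused.preserved_attr == "metadata"  # re-attached before _fuse_fx runs
```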
@@ -344,7 +358,7 @@ def prepare_fx(
       # Attributes that are not used in forward function will
       # be removed when constructing GraphModule, this is a list of attributes
       # to preserve as an attribute of the GraphModule even when they are
-      # not used in the code
+      # not used in the code, these attributes will also persist through deepcopy
       "preserved_attributes": ["preserved_attr"],
       }
@@ -359,7 +373,6 @@ def prepare_fx(
     from torch.quantization import prepare_fx

     float_model.eval()
-    graph_module = torch.fx.symbolic_trace(float_model)
     qconfig = get_default_qconfig('fbgemm')

     def calibrate(model, data_loader):
         model.eval()
@@ -368,7 +381,7 @@ def prepare_fx(
                 model(image)

     qconfig_dict = {"": qconfig}
-    prepared_model = prepare_fx(graph_module, qconfig_dict)
+    prepared_model = prepare_fx(float_model, qconfig_dict)
     # Run calibration
     calibrate(prepared_model, sample_inference_data)
     ```
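
The docstring fix above also reflects that prepare_fx traces the float model itself; a runnable version of the corrected example (the toy model and data stand in for float_model and sample_inference_data):

```python
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig('fbgemm')}

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image in data_loader:
            model(image)

# no manual torch.fx.symbolic_trace step: prepare_fx handles tracing
prepared_model = prepare_fx(float_model, qconfig_dict)
calibrate(prepared_model, [torch.randn(2, 4) for _ in range(4)])
```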