mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
[quant][graphmode][fx] Support preserving attributes in deepcopy of observed/quantized graphmodule (#56550)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56550 Add support for preserving a list of attributes on observed/quantized GraphModule Test Plan: python test/test_quantization.py TestQuantizeFx.test_deepcopy_preserve_attributes Imported from OSS Reviewed By: vkuzo, kazhang Differential Revision: D27899317 fbshipit-source-id: ebf21334715e5ab764aaa27eed534cc0cdf9f2b5
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3a44d269ac
commit
1719cb82f3
@ -8,7 +8,7 @@ from .fx.utils import graph_pretty_str # noqa: F401
|
||||
from .fx.utils import get_custom_module_class_keys # noqa: F401
|
||||
from .fx.graph_module import ObservedGraphModule, QuantizedGraphModule
|
||||
from torch.nn.intrinsic import _FusedModule
|
||||
from typing import Dict, Any, List, Callable, Tuple, Optional
|
||||
from typing import Dict, Any, List, Callable, Tuple, Optional, Set
|
||||
|
||||
def _check_is_graph_module(model: torch.nn.Module) -> None:
|
||||
if not isinstance(model, GraphModule):
|
||||
@ -167,9 +167,13 @@ forward graph of the parent module,
|
||||
float_custom_module_classes = get_custom_module_class_keys(
|
||||
prepare_custom_config_dict, "float_to_observed_custom_module_class")
|
||||
skipped_module_classes += float_custom_module_classes
|
||||
|
||||
preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", [])
|
||||
tracer = QuantizationTracer(
|
||||
skipped_module_names, skipped_module_classes)
|
||||
graph_module = GraphModule(model, tracer.trace(model))
|
||||
for attr_name in preserved_attributes:
|
||||
setattr(graph_module, attr_name, getattr(model, attr_name))
|
||||
graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
|
||||
quantizer = Quantizer()
|
||||
prepared = quantizer.prepare(
|
||||
@ -179,7 +183,6 @@ forward graph of the parent module,
|
||||
prepare_custom_config_dict=prepare_custom_config_dict,
|
||||
is_standalone_module=is_standalone_module)
|
||||
|
||||
preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", [])
|
||||
for attr_name in preserved_attributes:
|
||||
setattr(prepared, attr_name, getattr(model, attr_name))
|
||||
return prepared
|
||||
@ -221,6 +224,12 @@ def fuse_fx(model: torch.nn.Module,
|
||||
"additional_fuser_method_mapping": {
|
||||
(Module1, Module2): fuse_module1_module2
|
||||
}
|
||||
|
||||
# Attributes that are not used in forward function will
|
||||
# be removed when constructing GraphModule, this is a list of attributes
|
||||
# to preserve as an attribute of the GraphModule even when they are
|
||||
# not used in the code, these attributes will also persist through deepcopy
|
||||
"preserved_attributes": ["preserved_attr"],
|
||||
}
|
||||
|
||||
Example:
|
||||
@ -233,6 +242,11 @@ def fuse_fx(model: torch.nn.Module,
|
||||
torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
|
||||
assert not model.training, 'fuse_fx only works on models in eval mode'
|
||||
graph_module = torch.fx.symbolic_trace(model)
|
||||
preserved_attributes: Set[str] = set()
|
||||
if fuse_custom_config_dict:
|
||||
preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
|
||||
for attr_name in preserved_attributes:
|
||||
setattr(graph_module, attr_name, getattr(model, attr_name))
|
||||
return _fuse_fx(graph_module, fuse_custom_config_dict)
|
||||
|
||||
def prepare_fx(
|
||||
@ -344,7 +358,7 @@ def prepare_fx(
|
||||
# Attributes that are not used in forward function will
|
||||
# be removed when constructing GraphModule, this is a list of attributes
|
||||
# to preserve as an attribute of the GraphModule even when they are
|
||||
# not used in the code
|
||||
# not used in the code, these attributes will also persist through deepcopy
|
||||
"preserved_attributes": ["preserved_attr"],
|
||||
}
|
||||
|
||||
@ -359,7 +373,6 @@ def prepare_fx(
|
||||
from torch.quantization import prepare_fx
|
||||
|
||||
float_model.eval()
|
||||
graph_module = torch.fx.symbolic_trace(float_model)
|
||||
qconfig = get_default_qconfig('fbgemm')
|
||||
def calibrate(model, data_loader):
|
||||
model.eval()
|
||||
@ -368,7 +381,7 @@ def prepare_fx(
|
||||
model(image)
|
||||
|
||||
qconfig_dict = {"": qconfig}
|
||||
prepared_model = prepare_fx(graph_module, qconfig_dict)
|
||||
prepared_model = prepare_fx(float_model, qconfig_dict)
|
||||
# Run calibration
|
||||
calibrate(prepared_model, sample_inference_data)
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user