mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ao][fx] fixing public v private graph_module.py (#88395)
Summary: made _is_observed_module, _is_observed_standalone_module private Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D41015545](https://our.internmc.facebook.com/intern/diff/D41015545) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88395 Approved by: https://github.com/jcaip
This commit is contained in:
committed by
PyTorch MergeBot
parent
283cf718ed
commit
f286cbebce
@ -46,9 +46,9 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
|
||||
function_list = [
|
||||
'FusedGraphModule',
|
||||
'ObservedGraphModule',
|
||||
'is_observed_module',
|
||||
'_is_observed_module',
|
||||
'ObservedStandaloneGraphModule',
|
||||
'is_observed_standalone_module',
|
||||
'_is_observed_standalone_module',
|
||||
'QuantizedGraphModule'
|
||||
]
|
||||
self._test_function_import('fx.graph_module', function_list)
|
||||
|
@ -42,8 +42,8 @@ from torch.ao.quantization.backend_config import (
|
||||
)
|
||||
from .graph_module import (
|
||||
QuantizedGraphModule,
|
||||
is_observed_module,
|
||||
is_observed_standalone_module,
|
||||
_is_observed_module,
|
||||
_is_observed_standalone_module,
|
||||
)
|
||||
from ._equalize import update_obs_for_equalization, convert_eq_obs
|
||||
from torch.nn.utils.parametrize import type_before_parametrizations
|
||||
@ -450,7 +450,7 @@ def _restore_state(
|
||||
) -> Tuple[Dict[str, Tuple[str, type]],
|
||||
PrepareCustomConfig,
|
||||
Set[str]]:
|
||||
assert is_observed_module(observed), \
|
||||
assert _is_observed_module(observed), \
|
||||
'incoming model must be produced by prepare_fx'
|
||||
prepare_custom_config: PrepareCustomConfig = observed._prepare_custom_config # type: ignore[assignment]
|
||||
node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope # type: ignore[assignment]
|
||||
@ -1017,7 +1017,7 @@ def convert(
|
||||
node_name_to_qconfig)
|
||||
elif isinstance(mod, DeQuantStub):
|
||||
_replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
|
||||
elif is_observed_standalone_module(mod):
|
||||
elif _is_observed_standalone_module(mod):
|
||||
convert_standalone_module(
|
||||
node, modules, model, is_reference, backend_config)
|
||||
# below this point `type_before_parametrizations` is used
|
||||
|
@ -7,9 +7,7 @@ from typing import Union, Dict, Any, Set
|
||||
__all__ = [
|
||||
"FusedGraphModule",
|
||||
"ObservedGraphModule",
|
||||
"is_observed_module",
|
||||
"ObservedStandaloneGraphModule",
|
||||
"is_observed_standalone_module",
|
||||
"QuantizedGraphModule",
|
||||
]
|
||||
|
||||
@ -56,7 +54,7 @@ class ObservedGraphModule(GraphModule):
|
||||
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
|
||||
return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
|
||||
|
||||
def is_observed_module(module: Any) -> bool:
|
||||
def _is_observed_module(module: Any) -> bool:
|
||||
return isinstance(module, ObservedGraphModule)
|
||||
|
||||
class ObservedStandaloneGraphModule(ObservedGraphModule):
|
||||
@ -71,7 +69,7 @@ class ObservedStandaloneGraphModule(ObservedGraphModule):
|
||||
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
|
||||
return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
|
||||
|
||||
def is_observed_standalone_module(module: Any) -> bool:
|
||||
def _is_observed_standalone_module(module: Any) -> bool:
|
||||
return isinstance(module, ObservedStandaloneGraphModule)
|
||||
|
||||
def _save_packed_weight(self, destination, prefix, keep_vars):
|
||||
|
@ -15,7 +15,7 @@ from ..utils import (
|
||||
MatchAllNode
|
||||
)
|
||||
from .graph_module import (
|
||||
is_observed_standalone_module,
|
||||
_is_observed_standalone_module,
|
||||
)
|
||||
from torch.nn.utils.parametrize import type_before_parametrizations
|
||||
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
|
||||
@ -228,7 +228,7 @@ def _find_matches(
|
||||
for node in graph.nodes:
|
||||
if node.op == 'call_module' and \
|
||||
(is_standalone_module(node.target, modules) or
|
||||
is_observed_standalone_module(modules[node.target])):
|
||||
_is_observed_standalone_module(modules[node.target])):
|
||||
# add node to matched nodes
|
||||
match_map[node.name] = (
|
||||
node, node, None,
|
||||
|
@ -10,8 +10,8 @@ from torch.ao.quantization.fx.graph_module import (
|
||||
GraphModule,
|
||||
FusedGraphModule,
|
||||
ObservedGraphModule,
|
||||
is_observed_module,
|
||||
_is_observed_module,
|
||||
ObservedStandaloneGraphModule,
|
||||
is_observed_standalone_module,
|
||||
_is_observed_standalone_module,
|
||||
QuantizedGraphModule
|
||||
)
|
||||
|
Reference in New Issue
Block a user