[ao][fx] fixing public v private graph_module.py (#88395)

Summary: made _is_observed_module, _is_observed_standalone_module
private

Test Plan: python test/test_public_bindings.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D41015545](https://our.internmc.facebook.com/intern/diff/D41015545)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88395
Approved by: https://github.com/jcaip
This commit is contained in:
HDCharles
2022-12-13 18:20:12 -08:00
committed by PyTorch MergeBot
parent 283cf718ed
commit f286cbebce
5 changed files with 12 additions and 14 deletions

View File

@ -46,9 +46,9 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
function_list = [
'FusedGraphModule',
'ObservedGraphModule',
'is_observed_module',
'_is_observed_module',
'ObservedStandaloneGraphModule',
'is_observed_standalone_module',
'_is_observed_standalone_module',
'QuantizedGraphModule'
]
self._test_function_import('fx.graph_module', function_list)

View File

@ -42,8 +42,8 @@ from torch.ao.quantization.backend_config import (
)
from .graph_module import (
QuantizedGraphModule,
is_observed_module,
is_observed_standalone_module,
_is_observed_module,
_is_observed_standalone_module,
)
from ._equalize import update_obs_for_equalization, convert_eq_obs
from torch.nn.utils.parametrize import type_before_parametrizations
@ -450,7 +450,7 @@ def _restore_state(
) -> Tuple[Dict[str, Tuple[str, type]],
PrepareCustomConfig,
Set[str]]:
assert is_observed_module(observed), \
assert _is_observed_module(observed), \
'incoming model must be produced by prepare_fx'
prepare_custom_config: PrepareCustomConfig = observed._prepare_custom_config # type: ignore[assignment]
node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope # type: ignore[assignment]
@ -1017,7 +1017,7 @@ def convert(
node_name_to_qconfig)
elif isinstance(mod, DeQuantStub):
_replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
elif is_observed_standalone_module(mod):
elif _is_observed_standalone_module(mod):
convert_standalone_module(
node, modules, model, is_reference, backend_config)
# below this point `type_before_parametrizations` is used

View File

@ -7,9 +7,7 @@ from typing import Union, Dict, Any, Set
__all__ = [
"FusedGraphModule",
"ObservedGraphModule",
"is_observed_module",
"ObservedStandaloneGraphModule",
"is_observed_standalone_module",
"QuantizedGraphModule",
]
@ -56,7 +54,7 @@ class ObservedGraphModule(GraphModule):
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
def is_observed_module(module: Any) -> bool:
def _is_observed_module(module: Any) -> bool:
return isinstance(module, ObservedGraphModule)
class ObservedStandaloneGraphModule(ObservedGraphModule):
@ -71,7 +69,7 @@ class ObservedStandaloneGraphModule(ObservedGraphModule):
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
def is_observed_standalone_module(module: Any) -> bool:
def _is_observed_standalone_module(module: Any) -> bool:
return isinstance(module, ObservedStandaloneGraphModule)
def _save_packed_weight(self, destination, prefix, keep_vars):

View File

@ -15,7 +15,7 @@ from ..utils import (
MatchAllNode
)
from .graph_module import (
is_observed_standalone_module,
_is_observed_standalone_module,
)
from torch.nn.utils.parametrize import type_before_parametrizations
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
@ -228,7 +228,7 @@ def _find_matches(
for node in graph.nodes:
if node.op == 'call_module' and \
(is_standalone_module(node.target, modules) or
is_observed_standalone_module(modules[node.target])):
_is_observed_standalone_module(modules[node.target])):
# add node to matched nodes
match_map[node.name] = (
node, node, None,

View File

@ -10,8 +10,8 @@ from torch.ao.quantization.fx.graph_module import (
GraphModule,
FusedGraphModule,
ObservedGraphModule,
is_observed_module,
_is_observed_module,
ObservedStandaloneGraphModule,
is_observed_standalone_module,
_is_observed_standalone_module,
QuantizedGraphModule
)