[ao] fixing public v private for quantization_types (#86031)

Summary: The main problem here was that the various objects defined simply as
'Any' should theoretically be public, but making them public either A) results
in an error about the module being 'typing' rather than whatever module it
should actually belong to, or B) requires setting the module manually, thereby
changing the module for the original 'Any' class.
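
For context, a minimal sketch of the two failure modes (illustrative snippet,
not part of the patch):

    from typing import Any, Callable, Tuple, Union

    # A Union alias is a distinct typing object, so its __module__ can be
    # re-pointed at the defining module without touching typing itself:
    Pattern = Union[Callable, Tuple[Callable, Callable]]
    print(Pattern.__module__)  # 'typing' -- this is what trips the check (A)
    Pattern.__module__ = "torch.ao.quantization.utils"  # fine for Union aliases

    # A bare 'Any' alias is just another name for the typing.Any object, so
    # there is no separate object whose module can be fixed up:
    QuantizerCls = Any
    print(QuantizerCls is Any)  # True -- setting __module__ here would change
                                # (or fail on) typing.Any itself (B)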

Note: QuantizeHandler has a similar issue, where it is simply defined as
'Any'.

Pattern was defined in multiple places, which was causing issues, so I moved
it to a single place, given the note at the top of quantization_types.py
indicating that these definitions should be moved to utils at some point anyway.

Finally, I updated all references to these objects to point at the correct
locations. Note: I didn't see any fb-internal references to NodePattern or
QuantizerCls that would cause issues.
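
In other words, the consolidated import now looks like this (as the diffs
below show):

    # removed in this PR:
    # from torch.ao.quantization.quantization_types import Pattern, NodePattern, QuantizerCls
    # now:
    from torch.ao.quantization.utils import Pattern, NodePattern, QuantizerCls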

Test Plan: python test/test_public_bindings.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86031
Approved by: https://github.com/jerryzh168
Author: HDCharles
Date: 2022-10-12 10:04:04 -07:00
Committed by: PyTorch MergeBot
Parent: ef58a132f2
Commit: 25476f2e4b
14 changed files with 28 additions and 47 deletions

View File

@@ -35,15 +35,6 @@
     "torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
     "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn"
   },
-  "torch.ao.quantization.quantization_types": [
-    "Any",
-    "Node",
-    "NodePattern",
-    "Pattern",
-    "QuantizerCls",
-    "Tuple",
-    "Union"
-  ],
   "torch.ao.quantization.fx.graph_module": [
     "Any",
     "Dict",

View File

@@ -169,7 +169,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
     # we removed matching test for torch.quantization.fx.quantization_types
     # old: torch.quantization.fx.quantization_types
-    # new: torch.ao.quantization.quantization_types
+    # new: torch.ao.quantization.utils
     # both are valid, but we'll deprecate the old path in the future
     def test_package_import_fx_utils(self):

View File

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .backend_config import BackendConfig, DTypeConfig
-from ..quantization_types import Pattern
+from ..utils import Pattern
 __all__ = [
     "get_pattern_to_dtype_configs",

View File

@@ -2,10 +2,7 @@ import torch.nn as nn
 import torch.nn.intrinsic as nni
 from typing import Union, Callable, Tuple, Dict, Optional, Type
-from torch.ao.quantization.utils import Pattern
-from torch.ao.quantization.utils import get_combined_dict
-from torch.ao.quantization.utils import MatchAllNode
+from torch.ao.quantization.utils import Pattern, get_combined_dict, MatchAllNode
 import itertools
 __all__ = [

View File

@@ -4,14 +4,12 @@ from torch.ao.quantization.backend_config import (
     get_native_backend_config,
     ObservationType,
 )
-from torch.ao.quantization.quantization_types import (
-    Pattern,
-    NodePattern,
-    QuantizerCls,
-)
 from torch.ao.quantization.utils import (
     activation_dtype,
     get_combined_dict,
+    Pattern,
+    NodePattern,
+    QuantizerCls,
 )
 from ..backend_config import BackendConfig

View File

@@ -33,7 +33,7 @@ from .fusion_patterns import * # noqa: F401,F403
 from typing import Any, Callable, Dict, List, Tuple, Union
 import warnings
-from torch.ao.quantization.quantization_types import Pattern, NodePattern
+from torch.ao.quantization.utils import Pattern, NodePattern
 __all__ = [

View File

@@ -1,7 +1,6 @@
 import torch
 from torch.fx.graph import Node, Graph
-from ..utils import _parent_name
-from torch.ao.quantization.quantization_types import NodePattern, Pattern
+from ..utils import _parent_name, NodePattern, Pattern
 from ..fuser_method_mappings import get_fuser_method_new
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Union, List

View File

@@ -4,7 +4,7 @@ from torch.fx.graph import (
     Graph,
     Node,
 )
-from torch.ao.quantization.quantization_types import Pattern
+from torch.ao.quantization.utils import Pattern
 from .quantization_patterns import (
     QuantizeHandler,
 )

View File

@@ -1,6 +1,6 @@
 from collections import OrderedDict
 from typing import Dict, Any
-from torch.ao.quantization.quantization_types import Pattern
+from torch.ao.quantization.utils import Pattern
 from ..fake_quantize import FixedQParamsFakeQuantize
 # from .quantization_patterns import BinaryOpQuantizeHandler
 from ..observer import ObserverBase

View File

@@ -41,7 +41,7 @@ from .quantization_patterns import (
     QuantizeHandler,
 )
-from torch.ao.quantization.quantization_types import (
+from torch.ao.quantization.utils import (
     Pattern,
     NodePattern,
 )

View File

@@ -6,7 +6,7 @@ from torch.fx.graph import (
 from .utils import (
     all_node_args_have_no_tensors,
 )
-from torch.ao.quantization.quantization_types import (
+from torch.ao.quantization.utils import (
     Pattern,
     NodePattern,
 )

View File

@ -1,18 +0,0 @@
# TODO: the name of this file is probably confusing, remove this file and move the type
# definitions to somewhere else, e.g. to .utils
from typing import Any, Tuple, Union
from torch.fx import Node
from .utils import Pattern # noqa: F401
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# Define separately to prevent circular imports.
# TODO(future PR): improve this.
QuantizerCls = Any
__all__ = [
"Pattern",
"NodePattern",
"QuantizerCls",
]

View File

@@ -4,6 +4,7 @@ Utils shared by different modes of quantization (eager/graph)
 import warnings
 import functools
 import torch
+from torch.fx import Node
 from torch.ao.quantization.quant_type import QuantType
 from typing import Tuple, Any, Union, Callable, Dict, Optional
 from torch.nn.utils.parametrize import is_parametrized
@@ -11,10 +12,22 @@ from collections import OrderedDict
 from inspect import signature
 from inspect import getfullargspec
+NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
+NodePattern.__module__ = "torch.ao.quantization.utils"
+# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
+# Define separately to prevent circular imports.
+# TODO(future PR): improve this.
+# make this public once fixed (can't be public as is because setting the module directly
+# doesn't work)
+QuantizerCls = Any
 # Type for fusion patterns, it can be more complicated than the following actually,
 # see pattern.md for docs
 # TODO: not sure if typing supports recursive data types
-Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any]
+Pattern = Union[
+    Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
+]
+Pattern.__module__ = "torch.ao.quantization.utils"
 # TODO: maybe rename this to MatchAllNode
@@ -524,6 +537,7 @@ def get_fqn_to_example_inputs(
 __all__ = [
+    "NodePattern",
     "Pattern",
     "MatchAllNode",
     "check_node",

View File

@@ -6,7 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
-from torch.ao.quantization.quantization_types import (
+from torch.ao.quantization.utils import (
     Pattern,
     QuantizerCls
 )