mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ao] fixing public v private for quantization_types (#86031)
Summary: the main problem with this was that the different objects defined simply as 'Any' should theoretically be public but making them public either A) results in an error about the module being 'typing' rather than whatever module it should be or B) you set the module manually, thereby changing the module for the original 'Any' class. note: QuantizeHandler has a similar issue where its simply defined as 'Any' Pattern was defined in multiple places which was causing issues so i just moved it to a single place given the note at the top of quantization_types.py indicating these definitions should be moved to utils at some point anyway. Finally i changed any references to these objects to point at the correct locations. Note: i didn't see any fb internal references to NodePattern or QuantizerCls that would cause issues. Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/86031 Approved by: https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
ef58a132f2
commit
25476f2e4b
@ -35,15 +35,6 @@
|
||||
"torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
|
||||
"torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn"
|
||||
},
|
||||
"torch.ao.quantization.quantization_types": [
|
||||
"Any",
|
||||
"Node",
|
||||
"NodePattern",
|
||||
"Pattern",
|
||||
"QuantizerCls",
|
||||
"Tuple",
|
||||
"Union"
|
||||
],
|
||||
"torch.ao.quantization.fx.graph_module": [
|
||||
"Any",
|
||||
"Dict",
|
||||
|
@ -169,7 +169,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
|
||||
|
||||
# we removed matching test for torch.quantization.fx.quantization_types
|
||||
# old: torch.quantization.fx.quantization_types
|
||||
# new: torch.ao.quantization.quantization_types
|
||||
# new: torch.ao.quantization.utils
|
||||
# both are valid, but we'll deprecate the old path in the future
|
||||
|
||||
def test_package_import_fx_utils(self):
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .backend_config import BackendConfig, DTypeConfig
|
||||
from ..quantization_types import Pattern
|
||||
from ..utils import Pattern
|
||||
|
||||
__all__ = [
|
||||
"get_pattern_to_dtype_configs",
|
||||
|
@ -2,10 +2,7 @@ import torch.nn as nn
|
||||
import torch.nn.intrinsic as nni
|
||||
|
||||
from typing import Union, Callable, Tuple, Dict, Optional, Type
|
||||
from torch.ao.quantization.utils import Pattern
|
||||
|
||||
from torch.ao.quantization.utils import get_combined_dict
|
||||
from torch.ao.quantization.utils import MatchAllNode
|
||||
from torch.ao.quantization.utils import Pattern, get_combined_dict, MatchAllNode
|
||||
import itertools
|
||||
|
||||
__all__ = [
|
||||
|
@ -4,14 +4,12 @@ from torch.ao.quantization.backend_config import (
|
||||
get_native_backend_config,
|
||||
ObservationType,
|
||||
)
|
||||
from torch.ao.quantization.quantization_types import (
|
||||
Pattern,
|
||||
NodePattern,
|
||||
QuantizerCls,
|
||||
)
|
||||
from torch.ao.quantization.utils import (
|
||||
activation_dtype,
|
||||
get_combined_dict,
|
||||
Pattern,
|
||||
NodePattern,
|
||||
QuantizerCls,
|
||||
)
|
||||
|
||||
from ..backend_config import BackendConfig
|
||||
|
@ -33,7 +33,7 @@ from .fusion_patterns import * # noqa: F401,F403
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
import warnings
|
||||
|
||||
from torch.ao.quantization.quantization_types import Pattern, NodePattern
|
||||
from torch.ao.quantization.utils import Pattern, NodePattern
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
from torch.fx.graph import Node, Graph
|
||||
from ..utils import _parent_name
|
||||
from torch.ao.quantization.quantization_types import NodePattern, Pattern
|
||||
from ..utils import _parent_name, NodePattern, Pattern
|
||||
from ..fuser_method_mappings import get_fuser_method_new
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, Union, List
|
||||
|
@ -4,7 +4,7 @@ from torch.fx.graph import (
|
||||
Graph,
|
||||
Node,
|
||||
)
|
||||
from torch.ao.quantization.quantization_types import Pattern
|
||||
from torch.ao.quantization.utils import Pattern
|
||||
from .quantization_patterns import (
|
||||
QuantizeHandler,
|
||||
)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Any
|
||||
from torch.ao.quantization.quantization_types import Pattern
|
||||
from torch.ao.quantization.utils import Pattern
|
||||
from ..fake_quantize import FixedQParamsFakeQuantize
|
||||
# from .quantization_patterns import BinaryOpQuantizeHandler
|
||||
from ..observer import ObserverBase
|
||||
|
@ -41,7 +41,7 @@ from .quantization_patterns import (
|
||||
QuantizeHandler,
|
||||
)
|
||||
|
||||
from torch.ao.quantization.quantization_types import (
|
||||
from torch.ao.quantization.utils import (
|
||||
Pattern,
|
||||
NodePattern,
|
||||
)
|
||||
|
@ -6,7 +6,7 @@ from torch.fx.graph import (
|
||||
from .utils import (
|
||||
all_node_args_have_no_tensors,
|
||||
)
|
||||
from torch.ao.quantization.quantization_types import (
|
||||
from torch.ao.quantization.utils import (
|
||||
Pattern,
|
||||
NodePattern,
|
||||
)
|
||||
|
@ -1,18 +0,0 @@
|
||||
# TODO: the name of this file is probably confusing, remove this file and move the type
|
||||
# definitions to somewhere else, e.g. to .utils
|
||||
from typing import Any, Tuple, Union
|
||||
from torch.fx import Node
|
||||
from .utils import Pattern # noqa: F401
|
||||
|
||||
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
|
||||
|
||||
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
|
||||
# Define separately to prevent circular imports.
|
||||
# TODO(future PR): improve this.
|
||||
QuantizerCls = Any
|
||||
|
||||
__all__ = [
|
||||
"Pattern",
|
||||
"NodePattern",
|
||||
"QuantizerCls",
|
||||
]
|
@ -4,6 +4,7 @@ Utils shared by different modes of quantization (eager/graph)
|
||||
import warnings
|
||||
import functools
|
||||
import torch
|
||||
from torch.fx import Node
|
||||
from torch.ao.quantization.quant_type import QuantType
|
||||
from typing import Tuple, Any, Union, Callable, Dict, Optional
|
||||
from torch.nn.utils.parametrize import is_parametrized
|
||||
@ -11,10 +12,22 @@ from collections import OrderedDict
|
||||
from inspect import signature
|
||||
from inspect import getfullargspec
|
||||
|
||||
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
|
||||
NodePattern.__module__ = "torch.ao.quantization.utils"
|
||||
|
||||
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
|
||||
# Define separately to prevent circular imports.
|
||||
# TODO(future PR): improve this.
|
||||
# make this public once fixed (can't be public as is because setting the module directly
|
||||
# doesn't work)
|
||||
QuantizerCls = Any
|
||||
|
||||
# Type for fusion patterns, it can be more complicated than the following actually,
|
||||
# see pattern.md for docs
|
||||
# TODO: not sure if typing supports recursive data types
|
||||
Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any]
|
||||
Pattern = Union[
|
||||
Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
|
||||
]
|
||||
Pattern.__module__ = "torch.ao.quantization.utils"
|
||||
|
||||
# TODO: maybe rename this to MatchInputNode
|
||||
@ -524,6 +537,7 @@ def get_fqn_to_example_inputs(
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NodePattern",
|
||||
"Pattern",
|
||||
"MatchAllNode",
|
||||
"check_node",
|
||||
|
@ -6,7 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
|
||||
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
|
||||
here.
|
||||
"""
|
||||
from torch.ao.quantization.quantization_types import (
|
||||
from torch.ao.quantization.utils import (
|
||||
Pattern,
|
||||
QuantizerCls
|
||||
)
|
||||
|
Reference in New Issue
Block a user