mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use new type statement to fix public API of types (#158487)
Since type statement breaks older python version, trying to find equivalent behavior without the type mechanics. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158487 Approved by: https://github.com/andrewor14
This commit is contained in:
@ -33,9 +33,15 @@ from .stubs import * # noqa: F403
|
||||
|
||||
|
||||
# ensure __module__ is set correctly for public APIs
|
||||
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
|
||||
if sys.version_info < (3, 14):
|
||||
if sys.version_info < (3, 12):
|
||||
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
|
||||
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
|
||||
else:
|
||||
from typing import TypeAliasType
|
||||
|
||||
ObserverOrFakeQuantize = TypeAliasType(
|
||||
"ObserverOrFakeQuantize", Union[ObserverBase, FakeQuantizeBase]
|
||||
)
|
||||
|
||||
for _f in [
|
||||
compare_results,
|
||||
|
@ -568,9 +568,13 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N
|
||||
)
|
||||
|
||||
|
||||
QConfigAny = Optional[QConfig]
|
||||
if sys.version_info < (3, 14):
|
||||
if sys.version_info < (3, 12):
|
||||
QConfigAny = Optional[QConfig]
|
||||
QConfigAny.__module__ = "torch.ao.quantization.qconfig"
|
||||
else:
|
||||
from typing import TypeAliasType
|
||||
|
||||
QConfigAny = TypeAliasType("QConfigAny", Optional[QConfig])
|
||||
|
||||
|
||||
def _add_module_to_qconfig_obs_ctr(
|
||||
|
@ -16,9 +16,16 @@ from torch.fx import Node
|
||||
from torch.nn.utils.parametrize import is_parametrized
|
||||
|
||||
|
||||
NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any]
|
||||
if sys.version_info < (3, 14):
|
||||
if sys.version_info < (3, 12):
|
||||
NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any]
|
||||
NodePattern.__module__ = "torch.ao.quantization.utils"
|
||||
else:
|
||||
from typing import TypeAliasType
|
||||
|
||||
NodePattern = TypeAliasType(
|
||||
"NodePattern", Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any]
|
||||
)
|
||||
|
||||
|
||||
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
|
||||
# Define separately to prevent circular imports.
|
||||
@ -30,11 +37,27 @@ QuantizerCls = Any
|
||||
# Type for fusion patterns, it can be more complicated than the following actually,
|
||||
# see pattern.md for docs
|
||||
# TODO: not sure if typing supports recursive data types
|
||||
Pattern = Union[
|
||||
Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any
|
||||
]
|
||||
if sys.version_info < (3, 14):
|
||||
|
||||
if sys.version_info < (3, 12):
|
||||
Pattern = Union[
|
||||
Callable,
|
||||
tuple[Callable, Callable],
|
||||
tuple[Callable, tuple[Callable, Callable]],
|
||||
Any,
|
||||
]
|
||||
Pattern.__module__ = "torch.ao.quantization.utils"
|
||||
else:
|
||||
from typing import TypeAliasType
|
||||
|
||||
Pattern = TypeAliasType(
|
||||
"Pattern",
|
||||
Union[
|
||||
Callable,
|
||||
tuple[Callable, Callable],
|
||||
tuple[Callable, tuple[Callable, Callable]],
|
||||
Any,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# TODO: maybe rename this to MatchInputNode
|
||||
|
Reference in New Issue
Block a user