Use new type statement to fix public API of types (#158487)

Since type statement breaks older python version, trying to find equivalent behavior without the type mechanics.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158487
Approved by: https://github.com/andrewor14
This commit is contained in:
albanD
2025-07-17 18:46:39 +00:00
committed by PyTorch MergeBot
parent ad223a6c5f
commit 25f4d7e482
3 changed files with 43 additions and 10 deletions

View File

@ -33,9 +33,15 @@ from .stubs import * # noqa: F403
# ensure __module__ is set correctly for public APIs
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
if sys.version_info < (3, 14):
if sys.version_info < (3, 12):
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
else:
from typing import TypeAliasType
ObserverOrFakeQuantize = TypeAliasType(
"ObserverOrFakeQuantize", Union[ObserverBase, FakeQuantizeBase]
)
for _f in [
compare_results,

View File

@ -568,9 +568,13 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N
)
QConfigAny = Optional[QConfig]
if sys.version_info < (3, 14):
if sys.version_info < (3, 12):
QConfigAny = Optional[QConfig]
QConfigAny.__module__ = "torch.ao.quantization.qconfig"
else:
from typing import TypeAliasType
QConfigAny = TypeAliasType("QConfigAny", Optional[QConfig])
def _add_module_to_qconfig_obs_ctr(

View File

@ -16,9 +16,16 @@ from torch.fx import Node
from torch.nn.utils.parametrize import is_parametrized
NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any]
if sys.version_info < (3, 14):
if sys.version_info < (3, 12):
NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any]
NodePattern.__module__ = "torch.ao.quantization.utils"
else:
from typing import TypeAliasType
NodePattern = TypeAliasType(
"NodePattern", Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any]
)
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# Define separately to prevent circular imports.
@ -30,11 +37,27 @@ QuantizerCls = Any
# Type for fusion patterns, it can be more complicated than the following actually,
# see pattern.md for docs
# TODO: not sure if typing supports recursive data types
Pattern = Union[
Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any
]
if sys.version_info < (3, 14):
if sys.version_info < (3, 12):
Pattern = Union[
Callable,
tuple[Callable, Callable],
tuple[Callable, tuple[Callable, Callable]],
Any,
]
Pattern.__module__ = "torch.ao.quantization.utils"
else:
from typing import TypeAliasType
Pattern = TypeAliasType(
"Pattern",
Union[
Callable,
tuple[Callable, Callable],
tuple[Callable, tuple[Callable, Callable]],
Any,
],
)
# TODO: maybe rename this to MatchInputNode