[BE] Delete all pre py-3.10 checks (#163653)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163653
Approved by: https://github.com/jansel
ghstack dependencies: #163648, #163649
This commit is contained in:
Nikita Shulga
2025-09-23 12:20:31 -07:00
committed by PyTorch MergeBot
parent f3f67ff43a
commit f9fa138a39
23 changed files with 64 additions and 488 deletions

View File

@ -6,7 +6,7 @@
# files. # files.
[mypy] [mypy]
python_version = 3.9 python_version = 3.10
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
cache_dir = .mypy_cache/strict cache_dir = .mypy_cache/strict

View File

@ -188,7 +188,6 @@ ignore = [
# TODO: Remove Python-3.10 specific suppressions # TODO: Remove Python-3.10 specific suppressions
"B905", "B905",
"UP035", "UP035",
"UP036",
"FURB161", "FURB161",
] ]
select = [ select = [

View File

@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"] # Owner(s): ["module: dynamo"]
import sys
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -305,23 +304,16 @@ class TestCustomBackendAPI(torch._dynamo.test_case.TestCase):
backends_group = "torch_dynamo_backends" backends_group = "torch_dynamo_backends"
name = "mycustombackend" name = "mycustombackend"
mock_3_9 = MagicMock()
mock_3_9.load.return_value = lambda: "mocked 3.9"
mock_3_9.name = name
mock_3_10 = MagicMock() mock_3_10 = MagicMock()
mock_3_10.load.return_value = lambda: "mocked 3.10" mock_3_10.load.return_value = lambda: "mocked 3.10"
def mock_eps(group=None): def mock_eps(group=None):
if sys.version_info < (3, 10): assert group == backends_group, group
return {backends_group: [mock_3_9]} mock_group = MagicMock()
else: mock_group.names = [name]
assert group == backends_group, group mock_group[name] = mock_3_10
mock_group = MagicMock() # mock_group[name].load.return_value = lambda: "mocked 3.10"
mock_group.names = [name] return mock_group
mock_group[name] = mock_3_10
# mock_group[name].load.return_value = lambda: "mocked 3.10"
return mock_group
with patch("importlib.metadata.entry_points", mock_eps): with patch("importlib.metadata.entry_points", mock_eps):
from torch._dynamo.backends import registry from torch._dynamo.backends import registry

View File

@ -50,7 +50,6 @@ from torch._dynamo.testing import (
CompileCounter, CompileCounter,
CompileCounterWithBackend, CompileCounterWithBackend,
expectedFailureDynamic, expectedFailureDynamic,
requiresPy310,
same, same,
skipIfNotPy311, skipIfNotPy311,
unsupported, unsupported,
@ -10615,7 +10614,6 @@ def ___make_guard_fn():
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
@requiresPy310
def test_frozen_dataclass_kw_only(self): def test_frozen_dataclass_kw_only(self):
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TestDataClass: class TestDataClass:

View File

@ -48,280 +48,7 @@ CURRENT_FILE_NAME = os.path.basename(__file__)
_MODULE_NAME_ALLOW_LIST: set[str] = set() _MODULE_NAME_ALLOW_LIST: set[str] = set()
# Add builtin modules. # Add builtin modules.
if sys.version_info >= (3, 10): _MODULE_NAME_ALLOW_LIST.update(sys.stdlib_module_names)
_MODULE_NAME_ALLOW_LIST.update(sys.stdlib_module_names)
else:
assert (sys.version_info.major, sys.version_info.minor) == (3, 9)
# Taken from `stdlib_list("3.9")` to avoid introducing a new dependency.
_MODULE_NAME_ALLOW_LIST.update(
[
"__future__",
"_abc",
"_aix_support",
"_ast",
"_bootlocale",
"_bootsubprocess",
"_codecs",
"_collections",
"_collections_abc",
"_compat_pickle",
"_compression",
"_crypt",
"_functools",
"_hashlib",
"_imp",
"_io",
"_locale",
"_lsprof",
"_markupbase",
"_operator",
"_osx_support",
"_peg_parser",
"_posixsubprocess",
"_py_abc",
"_pydecimal",
"_pyio",
"_random",
"_signal",
"_sitebuiltins",
"_socket",
"_sre",
"_ssl",
"_stat",
"_string",
"_strptime",
"_symtable",
"_sysconfigdata_x86_64_conda_cos6_linux_gnu",
"_sysconfigdata_x86_64_conda_linux_gnu",
"_thread",
"_threading_local",
"_tracemalloc",
"_uuid",
"_warnings",
"_weakref",
"_weakrefset",
"abc",
"aifc",
"antigravity",
"argparse",
"array",
"ast",
"asynchat",
"asyncio",
"asyncore",
"atexit",
"audioop",
"base64",
"bdb",
"binascii",
"binhex",
"bisect",
"builtins",
"bz2",
"cProfile",
"calendar",
"cgi",
"cgitb",
"chunk",
"cmath",
"cmd",
"code",
"codecs",
"codeop",
"collections",
"colorsys",
"compileall",
"concurrent",
"configparser",
"contextlib",
"contextvars",
"copy",
"copyreg",
"crypt",
"csv",
"ctypes",
"curses",
"dataclasses",
"datetime",
"dbm",
"decimal",
"difflib",
"dis",
"distutils",
"doctest",
"email",
"encodings",
"ensurepip",
"enum",
"errno",
"faulthandler",
"fcntl",
"filecmp",
"fileinput",
"fnmatch",
"formatter",
"fractions",
"ftplib",
"functools",
"gc",
"genericpath",
"getopt",
"getpass",
"gettext",
"glob",
"graphlib",
"grp",
"gzip",
"hashlib",
"heapq",
"hmac",
"html",
"http",
"idlelib",
"imaplib",
"imghdr",
"imp",
"importlib",
"inspect",
"io",
"ipaddress",
"itertools",
"json",
"keyword",
"lib2to3",
"linecache",
"locale",
"logging",
"lzma",
"mailbox",
"mailcap",
"marshal",
"math",
"mimetypes",
"mmap",
"modulefinder",
"msilib",
"msvcrt",
"multiprocessing",
"netrc",
"nis",
"nntplib",
"ntpath",
"nturl2path",
"numbers",
"opcode",
"operator",
"optparse",
"os",
"ossaudiodev",
"parser",
"pathlib",
"pdb",
"pickle",
"pickletools",
"pipes",
"pkgutil",
"platform",
"plistlib",
"poplib",
"posix",
"posixpath",
"pprint",
"profile",
"pstats",
"pty",
"pwd",
"py_compile",
"pyclbr",
"pydoc",
"pydoc_data",
"queue",
"quopri",
"random",
"re",
"readline",
"reprlib",
"resource",
"rlcompleter",
"runpy",
"sched",
"secrets",
"select",
"selectors",
"shelve",
"shlex",
"shutil",
"signal",
"site",
"smtpd",
"smtplib",
"sndhdr",
"socket",
"socketserver",
"spwd",
"sqlite3",
"sre_compile",
"sre_constants",
"sre_parse",
"ssl",
"stat",
"statistics",
"string",
"stringprep",
"struct",
"subprocess",
"sunau",
"symbol",
"symtable",
"sys",
"sysconfig",
"syslog",
"tabnanny",
"tarfile",
"telnetlib",
"tempfile",
"termios",
"test",
"textwrap",
"this",
"threading",
"time",
"timeit",
"tkinter",
"token",
"tokenize",
"trace",
"traceback",
"tracemalloc",
"tty",
"turtle",
"turtledemo",
"types",
"typing",
"unicodedata",
"unittest",
"urllib",
"uu",
"uuid",
"venv",
"warnings",
"wave",
"weakref",
"webbrowser",
"winreg",
"winsound",
"wsgiref",
"xdrlib",
"xml",
"xmlrpc",
"xxsubtype",
"zipapp",
"zipfile",
"zipimport",
"zlib",
"zoneinfo",
]
)
# Add the allowed third party libraries. Please avoid updating this unless you # Add the allowed third party libraries. Please avoid updating this unless you
# understand the risks -- see `_ERROR_MESSAGE` for why. # understand the risks -- see `_ERROR_MESSAGE` for why.

View File

@ -2819,10 +2819,7 @@ def _import_device_backends():
from importlib.metadata import entry_points from importlib.metadata import entry_points
group_name = "torch.backends" group_name = "torch.backends"
if sys.version_info < (3, 10): backend_extensions = entry_points(group=group_name)
backend_extensions = entry_points().get(group_name, ())
else:
backend_extensions = entry_points(group=group_name)
for backend_extension in backend_extensions: for backend_extension in backend_extensions:
try: try:

View File

@ -60,7 +60,6 @@ optimized_model = torch.compile(model, backend="my_compiler")
import functools import functools
import logging import logging
import sys
from collections.abc import Sequence from collections.abc import Sequence
from importlib.metadata import EntryPoint from importlib.metadata import EntryPoint
from typing import Any, Callable, Optional, Protocol, Union from typing import Any, Callable, Optional, Protocol, Union
@ -174,12 +173,7 @@ def _discover_entrypoint_backends() -> None:
from importlib.metadata import entry_points from importlib.metadata import entry_points
group_name = "torch_dynamo_backends" group_name = "torch_dynamo_backends"
if sys.version_info < (3, 10): eps = entry_points(group=group_name)
eps = entry_points() eps_dict = {name: eps[name] for name in eps.names}
eps = eps[group_name] if group_name in eps else []
eps_dict = {ep.name: ep for ep in eps}
else:
eps = entry_points(group=group_name)
eps_dict = {name: eps[name] for name in eps.names}
for backend_name in eps_dict: for backend_name in eps_dict:
_BACKENDS[backend_name] = eps_dict[backend_name] _BACKENDS[backend_name] = eps_dict[backend_name]

View File

@ -251,22 +251,6 @@ def create_rot_n(n: int) -> list[Instruction]:
# e.g. rotate 3 is equivalent to swap 3, swap 2 # e.g. rotate 3 is equivalent to swap 3, swap 2
return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)] return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)]
# ROT_N does not exist in Python <= 3.9, but we can simulate it
if sys.version_info < (3, 10) and n >= 5:
"""
0 1 2 3 4
[0 1 2 3 4]
4 3 2 1 0
4 [3 2 1 0]
4 0 1 2 3
"""
return [
create_instruction("BUILD_TUPLE", arg=n),
create_instruction("UNPACK_SEQUENCE", arg=n),
create_instruction("BUILD_TUPLE", arg=n - 1),
create_instruction("UNPACK_SEQUENCE", arg=n - 1),
]
if n <= 4: if n <= 4:
return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])] return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])]
return [create_instruction("ROT_N", arg=n)] return [create_instruction("ROT_N", arg=n)]
@ -545,30 +529,6 @@ def create_print_value(value: Any) -> list[Instruction]:
] ]
def lnotab_writer(
lineno: int, byteno: int = 0
) -> tuple[list[int], Callable[[int, int], None]]:
"""
Used to create typing.CodeType.co_lnotab
See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
This is the internal format of the line number table if Python < 3.10
"""
assert sys.version_info < (3, 10)
lnotab: list[int] = []
def update(lineno_new: int, byteno_new: int) -> None:
nonlocal byteno, lineno
while byteno_new != byteno or lineno_new != lineno:
byte_offset = max(0, min(byteno_new - byteno, 255))
line_offset = max(-128, min(lineno_new - lineno, 127))
assert byte_offset != 0 or line_offset != 0
byteno += byte_offset
lineno += line_offset
lnotab.extend((byte_offset, line_offset & 0xFF))
return lnotab, update
def linetable_310_writer( def linetable_310_writer(
first_lineno: int, first_lineno: int,
) -> tuple[list[int], Callable[[int, int], None], Callable[[int], None]]: ) -> tuple[list[int], Callable[[int, int], None], Callable[[int], None]]:
@ -577,7 +537,7 @@ def linetable_310_writer(
See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
This is the internal format of the line number table for Python 3.10 This is the internal format of the line number table for Python 3.10
""" """
assert sys.version_info >= (3, 10) and sys.version_info < (3, 11) assert sys.version_info[:2] == (3, 10)
linetable: list[int] = [] linetable: list[int] = []
lineno = first_lineno lineno = first_lineno
lineno_delta = 0 lineno_delta = 0
@ -799,10 +759,7 @@ def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes,
for _ in range(instruction_size(inst) // 2 - 1): for _ in range(instruction_size(inst) // 2 - 1):
code.extend((0, 0)) code.extend((0, 0))
else: else:
if sys.version_info < (3, 10): lnotab, update_lineno, end = linetable_310_writer(firstlineno)
lnotab, update_lineno = lnotab_writer(firstlineno)
else:
lnotab, update_lineno, end = linetable_310_writer(firstlineno)
for inst in instructions: for inst in instructions:
if inst.starts_line is not None: if inst.starts_line is not None:
@ -810,8 +767,7 @@ def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes,
arg = inst.arg or 0 arg = inst.arg or 0
code.extend((inst.opcode, arg & 0xFF)) code.extend((inst.opcode, arg & 0xFF))
if sys.version_info >= (3, 10): end(len(code))
end(len(code))
return bytes(code), bytes(lnotab) return bytes(code), bytes(lnotab)
@ -903,9 +859,7 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None:
assert inst.target is not None assert inst.target is not None
target = _get_instruction_front(instructions, indexof[inst.target]) target = _get_instruction_front(instructions, indexof[inst.target])
if inst.opcode in dis.hasjabs: if inst.opcode in dis.hasjabs:
if sys.version_info < (3, 10): if sys.version_info < (3, 11):
inst.arg = target.offset
elif sys.version_info < (3, 11):
# `arg` is expected to be bytecode offset, whereas `offset` is byte offset. # `arg` is expected to be bytecode offset, whereas `offset` is byte offset.
# Divide since bytecode is 2 bytes large. # Divide since bytecode is 2 bytes large.
inst.arg = int(target.offset / 2) inst.arg = int(target.offset / 2)
@ -917,9 +871,8 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None:
inst.arg = abs( inst.arg = abs(
int(target.offset - inst.offset - instruction_size(inst)) int(target.offset - inst.offset - instruction_size(inst))
) )
if sys.version_info >= (3, 10): # see bytecode size comment in the absolute jump case above
# see bytecode size comment in the absolute jump case above inst.arg //= 2
inst.arg //= 2
inst.argval = target.offset inst.argval = target.offset
inst.argrepr = f"to {target.offset}" inst.argrepr = f"to {target.offset}"
@ -1558,10 +1511,7 @@ def get_code_keys() -> list[str]:
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
keys.append("co_qualname") keys.append("co_qualname")
keys.append("co_firstlineno") keys.append("co_firstlineno")
if sys.version_info >= (3, 10): keys.append("co_linetable")
keys.append("co_linetable")
else:
keys.append("co_lnotab")
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
# not documented, but introduced in https://github.com/python/cpython/issues/84403 # not documented, but introduced in https://github.com/python/cpython/issues/84403
keys.append("co_exceptiontable") keys.append("co_exceptiontable")
@ -1618,11 +1568,8 @@ def clean_and_assemble_instructions(
remove_extra_line_nums(instructions) remove_extra_line_nums(instructions)
bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"]) bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"])
if sys.version_info < (3, 10):
code_options["co_lnotab"] = lnotab
else:
code_options["co_linetable"] = lnotab
code_options["co_linetable"] = lnotab
code_options["co_code"] = bytecode code_options["co_code"] = bytecode
code_options["co_stacksize"] = stacksize_analysis(instructions) code_options["co_stacksize"] = stacksize_analysis(instructions)
assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - { assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - {

View File

@ -6,7 +6,6 @@ from __future__ import annotations
import itertools import itertools
import operator import operator
import sys
from typing import Callable, Optional, overload, TYPE_CHECKING, TypeVar from typing import Callable, Optional, overload, TYPE_CHECKING, TypeVar
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
@ -26,6 +25,7 @@ __all__ = [
"dropwhile", "dropwhile",
"filterfalse", "filterfalse",
"islice", "islice",
"pairwise",
"tee", "tee",
"zip_longest", "zip_longest",
] ]
@ -163,20 +163,16 @@ def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]:
# Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise # Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise
if sys.version_info >= (3, 10): @substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type]
def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]:
@substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type] a = None
def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]: first = True
a = None for b in iterable:
first = True if first:
for b in iterable: first = False
if first: else:
first = False yield a, b # type: ignore[misc]
else: a = b
yield a, b # type: ignore[misc]
a = b
__all__ += ["pairwise"]
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee

View File

@ -465,9 +465,6 @@ def stack_op(fn: Callable[..., object]) -> Callable[..., Any]:
def is_stdlib(mod: object) -> bool: def is_stdlib(mod: object) -> bool:
if sys.version_info < (3, 10):
# For < 3.10, no easy way to identify a stdlib module name.
return False
if not isinstance(mod, types.ModuleType): if not isinstance(mod, types.ModuleType):
return False return False
return mod.__name__.split(".")[0] in sys.stdlib_module_names return mod.__name__.split(".")[0] in sys.stdlib_module_names
@ -3769,18 +3766,17 @@ class InstructionTranslatorBase(
self.package = package self.package = package
if sys.version_info >= (3, 10): from .resume_execution import (
from .resume_execution import ( CO_ASYNC_GENERATOR,
CO_ASYNC_GENERATOR, CO_COROUTINE,
CO_COROUTINE, CO_GENERATOR,
CO_GENERATOR, CO_ITERABLE_COROUTINE,
CO_ITERABLE_COROUTINE, )
)
if f_code.co_flags & ( if f_code.co_flags & (
CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
): ):
self.push(BuiltinVariable(None)) self.push(BuiltinVariable(None))
self.inline_depth = inline_depth self.inline_depth = inline_depth
self.inconsistent_side_effects = False self.inconsistent_side_effects = False

View File

@ -517,13 +517,6 @@ def skipIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
return fn return fn
def requiresPy310(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 10):
return fn
else:
return unittest.skip("Requires Python 3.10+")(fn)
# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py # and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]: def expectedFailureDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]:

View File

@ -16,7 +16,6 @@ handling of iterator operations during code transformation and optimization.
""" """
import itertools import itertools
import sys
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
from .. import graph_break_hints, polyfills, variables from .. import graph_break_hints, polyfills, variables
@ -442,17 +441,14 @@ class ZipVariable(IteratorVariable):
codegen.append_output( codegen.append_output(
create_instruction("BUILD_TUPLE", arg=len(self.iterables)) create_instruction("BUILD_TUPLE", arg=len(self.iterables))
) )
if sys.version_info >= (3, 10): codegen.extend_output(
codegen.extend_output( [
[ codegen.create_load_const("strict"),
codegen.create_load_const("strict"), codegen.create_load_const(self.strict),
codegen.create_load_const(self.strict), create_instruction("BUILD_MAP", arg=1),
create_instruction("BUILD_MAP", arg=1), create_instruction("CALL_FUNCTION_EX", arg=1),
create_instruction("CALL_FUNCTION_EX", arg=1), ]
] )
)
else:
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
class MapVariable(ZipVariable): class MapVariable(ZipVariable):

View File

@ -1326,11 +1326,7 @@ class _InProcessFxCompile(FxCompile):
metrics_context = get_metrics_context() metrics_context = get_metrics_context()
if metrics_context.in_progress(): if metrics_context.in_progress():
# TODO: Remove this when 3.9 is no longer supported num_graph_breaks = counters["graph_break"].total()
if sys.version_info < (3, 10):
num_graph_breaks = sum(counters["graph_break"].values())
else:
num_graph_breaks = counters["graph_break"].total()
CompileEventLogger.compilation_metric( CompileEventLogger.compilation_metric(
overwrite=True, num_graph_breaks=num_graph_breaks overwrite=True, num_graph_breaks=num_graph_breaks
) )

View File

@ -5,7 +5,6 @@ import pickle
import random import random
import signal import signal
import string import string
import sys
import traceback import traceback
from collections.abc import KeysView, Sequence from collections.abc import KeysView, Sequence
from enum import Enum from enum import Enum
@ -610,9 +609,6 @@ class ConfigFuzzer:
sm: How type value samples are generated, default TOGGLE. sm: How type value samples are generated, default TOGGLE.
test_timeout: max time a test can take. test_timeout: max time a test can take.
""" """
if sys.version_info < (3, 10):
log.error("Only python 3.10 and later supported")
return
self.seed = seed self.seed = seed
self.test_timeout = test_timeout self.test_timeout = test_timeout
self.detailed_results: dict[ComboType, dict[str, Any]] = {} self.detailed_results: dict[ComboType, dict[str, Any]] = {}

View File

@ -1,9 +1,6 @@
import dataclasses import dataclasses
import functools
import itertools import itertools
import sys
from collections import Counter, defaultdict from collections import Counter, defaultdict
from collections.abc import Iterable, Iterator
from typing import Callable, Literal, Optional, overload, TYPE_CHECKING, TypeVar, Union from typing import Callable, Literal, Optional, overload, TYPE_CHECKING, TypeVar, Union
import sympy import sympy
@ -373,20 +370,6 @@ class NodeSplitGetter:
return pw, red return pw, red
if sys.version_info >= (3, 10):
# On Python 3.10+ we can use zip(strict=True)
zip_equal = functools.partial(zip, strict=True)
else:
# Fallback for older versions
def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]:
"""
Zip two iterables, raising ValueError if their lengths differ.
"""
if len(it1) != len(it2):
raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}")
return zip(it1, it2)
def apply_var_mapping( def apply_var_mapping(
iter_vars: list[sympy.Symbol], iter_vars: list[sympy.Symbol],
red_vars: list[sympy.Symbol], red_vars: list[sympy.Symbol],
@ -424,7 +407,7 @@ def apply_var_mapping(
iter_vars_to_flat_vars = {} iter_vars_to_flat_vars = {}
for i, (group, var_group) in enumerate( for i, (group, var_group) in enumerate(
zip_equal(apply_groups, ((iter_vars, red_vars))) zip(apply_groups, (iter_vars, red_vars), strict=True)
): ):
# if the node has sizes (p0, 1) and the fused node is (p0, r0) # if the node has sizes (p0, 1) and the fused node is (p0, r0)
# the reduction var gets filled in for split_iteration_range # the reduction var gets filled in for split_iteration_range
@ -437,7 +420,9 @@ def apply_var_mapping(
count = 0 count = 0
flat_vars_to_new_vars = {} flat_vars_to_new_vars = {}
for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars): for new_range, new_var in zip(
new_ranges, norm_pw_vars + norm_red_vars, strict=True
):
range_vars = [] range_vars = []
for i in range(len(new_range)): for i in range(len(new_range)):
range_vars.append(flat_vars[count]) range_vars.append(flat_vars[count])

View File

@ -3335,12 +3335,7 @@ class ScopedDict(MutableMapping[KeyType, ValType]):
@dataclass_transform(frozen_default=True) @dataclass_transform(frozen_default=True)
def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any: def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any:
def wrap(cls: _T) -> _T: def wrap(cls: _T) -> _T:
if sys.version_info >= (3, 10): return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload]
return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload]
else:
# Polyfill for python=3.9. kw_only simply introduces an extra check
# that only kwargs are used (and is not available on 3.9)
return dataclasses.dataclass(cls, frozen=frozen)
if cls is None: if cls is None:
return wrap return wrap

View File

@ -52,15 +52,8 @@ from torch.futures import Future
_P = ParamSpec("_P") _P = ParamSpec("_P")
_R = TypeVar("_R") _R = TypeVar("_R")
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
BuiltinUnionType: Union[type, tuple[type, ...]] BuiltinUnionType: Union[type, tuple[type, ...]]
if sys.version_info >= (3, 10): BuiltinUnionType = types.UnionType
# NOTE: IS_PY310_PLUS doesn't work with mypy.
# cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
BuiltinUnionType = types.UnionType
else:
BuiltinUnionType = () # trick: this makes isinstance short circuit.
LockType: type LockType: type
try: try:
@ -1257,12 +1250,9 @@ def _get_named_tuple_properties(
defaults = [] defaults = []
# In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
# Also, annotations from base class are not inherited so they need to be queried explicitly # Also, annotations from base class are not inherited so they need to be queried explicitly
if sys.version_info[:2] < (3, 10): obj_annotations = inspect.get_annotations(obj)
obj_annotations = getattr(obj, "__annotations__", {}) if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
else: obj_annotations = inspect.get_annotations(obj.__base__)
obj_annotations = inspect.get_annotations(obj)
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
obj_annotations = inspect.get_annotations(obj.__base__)
annotations = [] annotations = []
for field in obj._fields: for field in obj._fields:

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging import logging
import sys from importlib.metadata import entry_points
from .api import ( from .api import (
rendezvous_handler_registry as handler_registry, rendezvous_handler_registry as handler_registry,
@ -15,11 +15,6 @@ from .api import (
from .dynamic_rendezvous import create_handler from .dynamic_rendezvous import create_handler
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
__all__ = ["get_rendezvous_handler"] __all__ = ["get_rendezvous_handler"]

View File

@ -2,7 +2,6 @@
import collections import collections
import functools import functools
import inspect import inspect
import sys
import textwrap import textwrap
import types import types
import warnings import warnings
@ -158,8 +157,6 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
def get_annotations(obj): def get_annotations(obj):
if sys.version_info < (3, 10):
return getattr(obj, "__annotations__", {})
# In Python-3.10+ it is recommended to use inspect.get_annotations # In Python-3.10+ it is recommended to use inspect.get_annotations
# See https://docs.python.org/3.10/howto/annotations.html # See https://docs.python.org/3.10/howto/annotations.html
# But also, in 3.10 annotations from base class are not inherited # But also, in 3.10 annotations from base class are not inherited

View File

@ -6,7 +6,6 @@ import inspect
import io import io
import os import os
import pickle import pickle
import sys
import tokenize import tokenize
import unittest import unittest
from dataclasses import dataclass from dataclasses import dataclass
@ -178,10 +177,7 @@ def install_config_module(module: ModuleType) -> None:
prefix: str, prefix: str,
) -> None: ) -> None:
"""Walk the module structure and move everything to module._config""" """Walk the module structure and move everything to module._config"""
if sys.version_info[:2] < (3, 10): type_hints = inspect.get_annotations(source)
type_hints = getattr(source, "__annotations__", {})
else:
type_hints = inspect.get_annotations(source)
for key, value in list(source.__dict__.items()): for key, value in list(source.__dict__.items()):
if ( if (
key.startswith("__") key.startswith("__")

View File

@ -570,10 +570,7 @@ def tree_map_(
Type2 = tuple[type[T], type[S]] Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]] Type3 = tuple[type[T], type[S], type[U]]
if sys.version_info >= (3, 10): TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
else:
TypeAny = Union[type[Any], tuple[type[Any], ...]]
Fn2 = Callable[[Union[T, S]], R] Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R] Fn3 = Callable[[Union[T, S, U]], R]

View File

@ -1412,10 +1412,7 @@ def tree_map_(
Type2 = tuple[type[T], type[S]] Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]] Type3 = tuple[type[T], type[S], type[U]]
if sys.version_info >= (3, 10): TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
else:
TypeAny = Union[type[Any], tuple[type[Any], ...]]
Fn2 = Callable[[Union[T, S]], R] Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R] Fn3 = Callable[[Union[T, S, U]], R]

View File

@ -355,12 +355,9 @@ def dataclass_repr(
width: int = 80, width: int = 80,
) -> str: ) -> str:
# built-in pprint module support dataclasses from python 3.10 # built-in pprint module support dataclasses from python 3.10
if sys.version_info >= (3, 10): from pprint import pformat
from pprint import pformat
return pformat(obj, indent, width) return pformat(obj, indent, width)
return _pformat(obj, indent=indent, width=width)
def _pformat( def _pformat(