From f9fa138a3910bd1de1e7acb95265fa040672a952 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 23 Sep 2025 12:20:31 -0700 Subject: [PATCH] [BE] Delete all pre py-3.10 checks (#163653) Pull Request resolved: https://github.com/pytorch/pytorch/pull/163653 Approved by: https://github.com/jansel ghstack dependencies: #163648, #163649 --- mypy-strict.ini | 2 +- pyproject.toml | 1 - test/dynamo/test_backends.py | 20 +- test/dynamo/test_misc.py | 2 - tools/linter/adapters/import_linter.py | 275 +----------------- torch/__init__.py | 5 +- torch/_dynamo/backends/registry.py | 10 +- torch/_dynamo/bytecode_transformation.py | 69 +---- torch/_dynamo/polyfills/itertools.py | 26 +- torch/_dynamo/symbolic_convert.py | 24 +- torch/_dynamo/testing.py | 7 - torch/_dynamo/variables/iter.py | 20 +- torch/_inductor/compile_fx.py | 6 +- torch/_inductor/fuzzer.py | 4 - torch/_inductor/tiling_utils.py | 23 +- torch/_inductor/utils.py | 7 +- torch/_jit_internal.py | 18 +- .../elastic/rendezvous/registry.py | 7 +- torch/jit/_recursive.py | 3 - torch/utils/_config_module.py | 6 +- torch/utils/_cxx_pytree.py | 5 +- torch/utils/_pytree.py | 5 +- torchgen/utils.py | 7 +- 23 files changed, 64 insertions(+), 488 deletions(-) diff --git a/mypy-strict.ini b/mypy-strict.ini index dddbb623047f..11e520d9ad82 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -6,7 +6,7 @@ # files. [mypy] -python_version = 3.9 +python_version = 3.10 plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin cache_dir = .mypy_cache/strict diff --git a/pyproject.toml b/pyproject.toml index 5e4202260e6a..321af034f854 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -188,7 +188,6 @@ ignore = [ # TODO: Remove Python-3.10 specific suppressions "B905", "UP035", - "UP036", "FURB161", ] select = [ diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index 30c354b6ec11..243bb7b94fca 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import sys import unittest from unittest.mock import MagicMock, patch @@ -305,23 +304,16 @@ class TestCustomBackendAPI(torch._dynamo.test_case.TestCase): backends_group = "torch_dynamo_backends" name = "mycustombackend" - mock_3_9 = MagicMock() - mock_3_9.load.return_value = lambda: "mocked 3.9" - mock_3_9.name = name - mock_3_10 = MagicMock() mock_3_10.load.return_value = lambda: "mocked 3.10" def mock_eps(group=None): - if sys.version_info < (3, 10): - return {backends_group: [mock_3_9]} - else: - assert group == backends_group, group - mock_group = MagicMock() - mock_group.names = [name] - mock_group[name] = mock_3_10 - # mock_group[name].load.return_value = lambda: "mocked 3.10" - return mock_group + assert group == backends_group, group + mock_group = MagicMock() + mock_group.names = [name] + mock_group[name] = mock_3_10 + # mock_group[name].load.return_value = lambda: "mocked 3.10" + return mock_group with patch("importlib.metadata.entry_points", mock_eps): from torch._dynamo.backends import registry diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 61ecf52c3e37..f4168f7754c3 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -50,7 +50,6 @@ from torch._dynamo.testing import ( CompileCounter, CompileCounterWithBackend, expectedFailureDynamic, - requiresPy310, same, skipIfNotPy311, unsupported, @@ -10615,7 +10614,6 @@ def ___make_guard_fn(): self.assertEqual(actual, expected) - @requiresPy310 def test_frozen_dataclass_kw_only(self): @dataclasses.dataclass(frozen=True) class TestDataClass: diff --git a/tools/linter/adapters/import_linter.py b/tools/linter/adapters/import_linter.py index 1b24556a03bd..90a485456b9a 100644 --- a/tools/linter/adapters/import_linter.py +++ b/tools/linter/adapters/import_linter.py @@ -48,280 +48,7 @@ CURRENT_FILE_NAME = os.path.basename(__file__) _MODULE_NAME_ALLOW_LIST: set[str] = set() # Add builtin modules. -if sys.version_info >= (3, 10): - _MODULE_NAME_ALLOW_LIST.update(sys.stdlib_module_names) -else: - assert (sys.version_info.major, sys.version_info.minor) == (3, 9) - # Taken from `stdlib_list("3.9")` to avoid introducing a new dependency. - _MODULE_NAME_ALLOW_LIST.update( - [ - "__future__", - "_abc", - "_aix_support", - "_ast", - "_bootlocale", - "_bootsubprocess", - "_codecs", - "_collections", - "_collections_abc", - "_compat_pickle", - "_compression", - "_crypt", - "_functools", - "_hashlib", - "_imp", - "_io", - "_locale", - "_lsprof", - "_markupbase", - "_operator", - "_osx_support", - "_peg_parser", - "_posixsubprocess", - "_py_abc", - "_pydecimal", - "_pyio", - "_random", - "_signal", - "_sitebuiltins", - "_socket", - "_sre", - "_ssl", - "_stat", - "_string", - "_strptime", - "_symtable", - "_sysconfigdata_x86_64_conda_cos6_linux_gnu", - "_sysconfigdata_x86_64_conda_linux_gnu", - "_thread", - "_threading_local", - "_tracemalloc", - "_uuid", - "_warnings", - "_weakref", - "_weakrefset", - "abc", - "aifc", - "antigravity", - "argparse", - "array", - "ast", - "asynchat", - "asyncio", - "asyncore", - "atexit", - "audioop", - "base64", - "bdb", - "binascii", - "binhex", - "bisect", - "builtins", - "bz2", - "cProfile", - "calendar", - "cgi", - "cgitb", - "chunk", - "cmath", - "cmd", - "code", - "codecs", - "codeop", - "collections", - "colorsys", - "compileall", - "concurrent", - "configparser", - "contextlib", - "contextvars", - "copy", - "copyreg", - "crypt", - "csv", - "ctypes", - "curses", - "dataclasses", - "datetime", - "dbm", - "decimal", - "difflib", - "dis", - "distutils", - "doctest", - "email", - "encodings", - "ensurepip", - "enum", - "errno", - "faulthandler", - "fcntl", - "filecmp", - "fileinput", - "fnmatch", - "formatter", - "fractions", - "ftplib", - "functools", - "gc", - "genericpath", - "getopt", - "getpass", - "gettext", - "glob", - "graphlib", - "grp", - "gzip", - "hashlib", - "heapq", - "hmac", - "html", - "http", - "idlelib", - "imaplib", - "imghdr", - "imp", - "importlib", - "inspect", - "io", - "ipaddress", - "itertools", - "json", - "keyword", - "lib2to3", - "linecache", - "locale", - "logging", - "lzma", - "mailbox", - "mailcap", - "marshal", - "math", - "mimetypes", - "mmap", - "modulefinder", - "msilib", - "msvcrt", - "multiprocessing", - "netrc", - "nis", - "nntplib", - "ntpath", - "nturl2path", - "numbers", - "opcode", - "operator", - "optparse", - "os", - "ossaudiodev", - "parser", - "pathlib", - "pdb", - "pickle", - "pickletools", - "pipes", - "pkgutil", - "platform", - "plistlib", - "poplib", - "posix", - "posixpath", - "pprint", - "profile", - "pstats", - "pty", - "pwd", - "py_compile", - "pyclbr", - "pydoc", - "pydoc_data", - "queue", - "quopri", - "random", - "re", - "readline", - "reprlib", - "resource", - "rlcompleter", - "runpy", - "sched", - "secrets", - "select", - "selectors", - "shelve", - "shlex", - "shutil", - "signal", - "site", - "smtpd", - "smtplib", - "sndhdr", - "socket", - "socketserver", - "spwd", - "sqlite3", - "sre_compile", - "sre_constants", - "sre_parse", - "ssl", - "stat", - "statistics", - "string", - "stringprep", - "struct", - "subprocess", - "sunau", - "symbol", - "symtable", - "sys", - "sysconfig", - "syslog", - "tabnanny", - "tarfile", - "telnetlib", - "tempfile", - "termios", - "test", - "textwrap", - "this", - "threading", - "time", - "timeit", - "tkinter", - "token", - "tokenize", - "trace", - "traceback", - "tracemalloc", - "tty", - "turtle", - "turtledemo", - "types", - "typing", - "unicodedata", - "unittest", - "urllib", - "uu", - "uuid", - "venv", - "warnings", - "wave", - "weakref", - "webbrowser", - "winreg", - "winsound", - "wsgiref", - "xdrlib", - "xml", - "xmlrpc", - "xxsubtype", - "zipapp", - "zipfile", - "zipimport", - "zlib", - "zoneinfo", - ] - ) +_MODULE_NAME_ALLOW_LIST.update(sys.stdlib_module_names) # Add the allowed third party libraries. Please avoid updating this unless you # understand the risks -- see `_ERROR_MESSAGE` for why. diff --git a/torch/__init__.py b/torch/__init__.py index 08dee0624350..957d6fc9581a 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2819,10 +2819,7 @@ def _import_device_backends(): from importlib.metadata import entry_points group_name = "torch.backends" - if sys.version_info < (3, 10): - backend_extensions = entry_points().get(group_name, ()) - else: - backend_extensions = entry_points(group=group_name) + backend_extensions = entry_points(group=group_name) for backend_extension in backend_extensions: try: diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index 699d82fff3f0..c6a334359d0e 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -60,7 +60,6 @@ optimized_model = torch.compile(model, backend="my_compiler") import functools import logging -import sys from collections.abc import Sequence from importlib.metadata import EntryPoint from typing import Any, Callable, Optional, Protocol, Union @@ -174,12 +173,7 @@ def _discover_entrypoint_backends() -> None: from importlib.metadata import entry_points group_name = "torch_dynamo_backends" - if sys.version_info < (3, 10): - eps = entry_points() - eps = eps[group_name] if group_name in eps else [] - eps_dict = {ep.name: ep for ep in eps} - else: - eps = entry_points(group=group_name) - eps_dict = {name: eps[name] for name in eps.names} + eps = entry_points(group=group_name) + eps_dict = {name: eps[name] for name in eps.names} for backend_name in eps_dict: _BACKENDS[backend_name] = eps_dict[backend_name] diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index 14a6f78bfcd4..d0e9297d33b5 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -251,22 +251,6 @@ def create_rot_n(n: int) -> list[Instruction]: # e.g. rotate 3 is equivalent to swap 3, swap 2 return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)] - # ROT_N does not exist in Python <= 3.9, but we can simulate it - if sys.version_info < (3, 10) and n >= 5: - """ - 0 1 2 3 4 - [0 1 2 3 4] - 4 3 2 1 0 - 4 [3 2 1 0] - 4 0 1 2 3 - """ - return [ - create_instruction("BUILD_TUPLE", arg=n), - create_instruction("UNPACK_SEQUENCE", arg=n), - create_instruction("BUILD_TUPLE", arg=n - 1), - create_instruction("UNPACK_SEQUENCE", arg=n - 1), - ] - if n <= 4: return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])] return [create_instruction("ROT_N", arg=n)] @@ -545,30 +529,6 @@ def create_print_value(value: Any) -> list[Instruction]: ] -def lnotab_writer( - lineno: int, byteno: int = 0 -) -> tuple[list[int], Callable[[int, int], None]]: - """ - Used to create typing.CodeType.co_lnotab - See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt - This is the internal format of the line number table if Python < 3.10 - """ - assert sys.version_info < (3, 10) - lnotab: list[int] = [] - - def update(lineno_new: int, byteno_new: int) -> None: - nonlocal byteno, lineno - while byteno_new != byteno or lineno_new != lineno: - byte_offset = max(0, min(byteno_new - byteno, 255)) - line_offset = max(-128, min(lineno_new - lineno, 127)) - assert byte_offset != 0 or line_offset != 0 - byteno += byte_offset - lineno += line_offset - lnotab.extend((byte_offset, line_offset & 0xFF)) - - return lnotab, update - - def linetable_310_writer( first_lineno: int, ) -> tuple[list[int], Callable[[int, int], None], Callable[[int], None]]: @@ -577,7 +537,7 @@ def linetable_310_writer( See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt This is the internal format of the line number table for Python 3.10 """ - assert sys.version_info >= (3, 10) and sys.version_info < (3, 11) + assert sys.version_info[:2] == (3, 10) linetable: list[int] = [] lineno = first_lineno lineno_delta = 0 @@ -799,10 +759,7 @@ def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes, for _ in range(instruction_size(inst) // 2 - 1): code.extend((0, 0)) else: - if sys.version_info < (3, 10): - lnotab, update_lineno = lnotab_writer(firstlineno) - else: - lnotab, update_lineno, end = linetable_310_writer(firstlineno) + lnotab, update_lineno, end = linetable_310_writer(firstlineno) for inst in instructions: if inst.starts_line is not None: @@ -810,8 +767,7 @@ def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes, arg = inst.arg or 0 code.extend((inst.opcode, arg & 0xFF)) - if sys.version_info >= (3, 10): - end(len(code)) + end(len(code)) return bytes(code), bytes(lnotab) @@ -903,9 +859,7 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None: assert inst.target is not None target = _get_instruction_front(instructions, indexof[inst.target]) if inst.opcode in dis.hasjabs: - if sys.version_info < (3, 10): - inst.arg = target.offset - elif sys.version_info < (3, 11): + if sys.version_info < (3, 11): # `arg` is expected to be bytecode offset, whereas `offset` is byte offset. # Divide since bytecode is 2 bytes large. inst.arg = int(target.offset / 2) @@ -917,9 +871,8 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None: inst.arg = abs( int(target.offset - inst.offset - instruction_size(inst)) ) - if sys.version_info >= (3, 10): - # see bytecode size comment in the absolute jump case above - inst.arg //= 2 + # see bytecode size comment in the absolute jump case above + inst.arg //= 2 inst.argval = target.offset inst.argrepr = f"to {target.offset}" @@ -1558,10 +1511,7 @@ def get_code_keys() -> list[str]: if sys.version_info >= (3, 11): keys.append("co_qualname") keys.append("co_firstlineno") - if sys.version_info >= (3, 10): - keys.append("co_linetable") - else: - keys.append("co_lnotab") + keys.append("co_linetable") if sys.version_info >= (3, 11): # not documented, but introduced in https://github.com/python/cpython/issues/84403 keys.append("co_exceptiontable") @@ -1618,11 +1568,8 @@ def clean_and_assemble_instructions( remove_extra_line_nums(instructions) bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"]) - if sys.version_info < (3, 10): - code_options["co_lnotab"] = lnotab - else: - code_options["co_linetable"] = lnotab + code_options["co_linetable"] = lnotab code_options["co_code"] = bytecode code_options["co_stacksize"] = stacksize_analysis(instructions) assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - { diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 2b64327b93de..ef9b2c28d603 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -6,7 +6,6 @@ from __future__ import annotations import itertools import operator -import sys from typing import Callable, Optional, overload, TYPE_CHECKING, TypeVar from typing_extensions import TypeAlias @@ -26,6 +25,7 @@ __all__ = [ "dropwhile", "filterfalse", "islice", + "pairwise", "tee", "zip_longest", ] @@ -163,20 +163,16 @@ def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: # Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise -if sys.version_info >= (3, 10): - - @substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type] - def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]: - a = None - first = True - for b in iterable: - if first: - first = False - else: - yield a, b # type: ignore[misc] - a = b - - __all__ += ["pairwise"] +@substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type] +def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]: + a = None + first = True + for b in iterable: + if first: + first = False + else: + yield a, b # type: ignore[misc] + a = b # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index b762bd95fa11..76bb976706f5 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -465,9 +465,6 @@ def stack_op(fn: Callable[..., object]) -> Callable[..., Any]: def is_stdlib(mod: object) -> bool: - if sys.version_info < (3, 10): - # For < 3.10, no easy way to identify a stdlib module name. - return False if not isinstance(mod, types.ModuleType): return False return mod.__name__.split(".")[0] in sys.stdlib_module_names @@ -3769,18 +3766,17 @@ class InstructionTranslatorBase( self.package = package - if sys.version_info >= (3, 10): - from .resume_execution import ( - CO_ASYNC_GENERATOR, - CO_COROUTINE, - CO_GENERATOR, - CO_ITERABLE_COROUTINE, - ) + from .resume_execution import ( + CO_ASYNC_GENERATOR, + CO_COROUTINE, + CO_GENERATOR, + CO_ITERABLE_COROUTINE, + ) - if f_code.co_flags & ( - CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR - ): - self.push(BuiltinVariable(None)) + if f_code.co_flags & ( + CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR + ): + self.push(BuiltinVariable(None)) self.inline_depth = inline_depth self.inconsistent_side_effects = False diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 805c3be524e8..1ce88f1d744c 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -517,13 +517,6 @@ def skipIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]: return fn -def requiresPy310(fn: Callable[_P, _T]) -> Callable[_P, _T]: - if sys.version_info >= (3, 10): - return fn - else: - return unittest.skip("Requires Python 3.10+")(fn) - - # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py # and test/dynamo/test_dynamic_shapes.py def expectedFailureDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]: diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 80b9915aaa21..cff52901126f 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -16,7 +16,6 @@ handling of iterator operations during code transformation and optimization. """ import itertools -import sys from typing import TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables @@ -442,17 +441,14 @@ class ZipVariable(IteratorVariable): codegen.append_output( create_instruction("BUILD_TUPLE", arg=len(self.iterables)) ) - if sys.version_info >= (3, 10): - codegen.extend_output( - [ - codegen.create_load_const("strict"), - codegen.create_load_const(self.strict), - create_instruction("BUILD_MAP", arg=1), - create_instruction("CALL_FUNCTION_EX", arg=1), - ] - ) - else: - codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0)) + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + create_instruction("CALL_FUNCTION_EX", arg=1), + ] + ) class MapVariable(ZipVariable): diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 5306919ecf6d..8520831750bf 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1326,11 +1326,7 @@ class _InProcessFxCompile(FxCompile): metrics_context = get_metrics_context() if metrics_context.in_progress(): - # TODO: Remove this when 3.9 is no longer supported - if sys.version_info < (3, 10): - num_graph_breaks = sum(counters["graph_break"].values()) - else: - num_graph_breaks = counters["graph_break"].total() + num_graph_breaks = counters["graph_break"].total() CompileEventLogger.compilation_metric( overwrite=True, num_graph_breaks=num_graph_breaks ) diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 8149bc7e98e7..b0c89caa89c3 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -5,7 +5,6 @@ import pickle import random import signal import string -import sys import traceback from collections.abc import KeysView, Sequence from enum import Enum @@ -610,9 +609,6 @@ class ConfigFuzzer: sm: How type value samples are generated, default TOGGLE. test_timeout: max time a test can take. """ - if sys.version_info < (3, 10): - log.error("Only python 3.10 and later supported") - return self.seed = seed self.test_timeout = test_timeout self.detailed_results: dict[ComboType, dict[str, Any]] = {} diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py index 4a1febe08e99..ea7d61cf9315 100644 --- a/torch/_inductor/tiling_utils.py +++ b/torch/_inductor/tiling_utils.py @@ -1,9 +1,6 @@ import dataclasses -import functools import itertools -import sys from collections import Counter, defaultdict -from collections.abc import Iterable, Iterator from typing import Callable, Literal, Optional, overload, TYPE_CHECKING, TypeVar, Union import sympy @@ -373,20 +370,6 @@ class NodeSplitGetter: return pw, red -if sys.version_info >= (3, 10): - # On Python 3.10+ we can use zip(strict=True) - zip_equal = functools.partial(zip, strict=True) -else: - # Fallback for older versions - def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]: - """ - Zip two iterables, raising ValueError if their lengths differ. - """ - if len(it1) != len(it2): - raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}") - return zip(it1, it2) - - def apply_var_mapping( iter_vars: list[sympy.Symbol], red_vars: list[sympy.Symbol], @@ -424,7 +407,7 @@ def apply_var_mapping( iter_vars_to_flat_vars = {} for i, (group, var_group) in enumerate( - zip_equal(apply_groups, ((iter_vars, red_vars))) + zip(apply_groups, (iter_vars, red_vars), strict=True) ): # if the node has sizes (p0, 1) and the fused node is (p0, r0) # the reduction var gets filled in for split_iteration_range @@ -437,7 +420,9 @@ def apply_var_mapping( count = 0 flat_vars_to_new_vars = {} - for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars): + for new_range, new_var in zip( + new_ranges, norm_pw_vars + norm_red_vars, strict=True + ): range_vars = [] for i in range(len(new_range)): range_vars.append(flat_vars[count]) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 4e0bc3d19444..f64213c26383 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3335,12 +3335,7 @@ class ScopedDict(MutableMapping[KeyType, ValType]): @dataclass_transform(frozen_default=True) def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any: def wrap(cls: _T) -> _T: - if sys.version_info >= (3, 10): - return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload] - else: - # Polyfill for python=3.9. kw_only simply introduces an extra check - # that only kwargs are used (and is not available on 3.9) - return dataclasses.dataclass(cls, frozen=frozen) + return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload] if cls is None: return wrap diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index be6d23bbbc53..928e0781c857 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -52,15 +52,8 @@ from torch.futures import Future _P = ParamSpec("_P") _R = TypeVar("_R") -IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10) - BuiltinUnionType: Union[type, tuple[type, ...]] -if sys.version_info >= (3, 10): - # NOTE: IS_PY310_PLUS doesn't work with mypy. - # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks - BuiltinUnionType = types.UnionType -else: - BuiltinUnionType = () # trick: this makes isinstance short circuit. +BuiltinUnionType = types.UnionType LockType: type try: @@ -1257,12 +1250,9 @@ def _get_named_tuple_properties( defaults = [] # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function # Also, annotations from base class are not inherited so they need to be queried explicitly - if sys.version_info[:2] < (3, 10): - obj_annotations = getattr(obj, "__annotations__", {}) - else: - obj_annotations = inspect.get_annotations(obj) - if len(obj_annotations) == 0 and hasattr(obj, "__base__"): - obj_annotations = inspect.get_annotations(obj.__base__) + obj_annotations = inspect.get_annotations(obj) + if len(obj_annotations) == 0 and hasattr(obj, "__base__"): + obj_annotations = inspect.get_annotations(obj.__base__) annotations = [] for field in obj._fields: diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index 75f0d16f7d19..ebada4623a81 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -import sys +from importlib.metadata import entry_points from .api import ( rendezvous_handler_registry as handler_registry, @@ -15,11 +15,6 @@ from .api import ( from .dynamic_rendezvous import create_handler -if sys.version_info < (3, 10): - from importlib_metadata import entry_points -else: - from importlib.metadata import entry_points - log = logging.getLogger(__name__) __all__ = ["get_rendezvous_handler"] diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index e89bcc47dff6..aa213dcad35f 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -2,7 +2,6 @@ import collections import functools import inspect -import sys import textwrap import types import warnings @@ -158,8 +157,6 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): def get_annotations(obj): - if sys.version_info < (3, 10): - return getattr(obj, "__annotations__", {}) # In Python-3.10+ it is recommended to use inspect.get_annotations # See https://docs.python.org/3.10/howto/annotations.html # But also, in 3.10 annotations from base class are not inherited diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 811b45fd1d69..8b87d80002e5 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -6,7 +6,6 @@ import inspect import io import os import pickle -import sys import tokenize import unittest from dataclasses import dataclass @@ -178,10 +177,7 @@ def install_config_module(module: ModuleType) -> None: prefix: str, ) -> None: """Walk the module structure and move everything to module._config""" - if sys.version_info[:2] < (3, 10): - type_hints = getattr(source, "__annotations__", {}) - else: - type_hints = inspect.get_annotations(source) + type_hints = inspect.get_annotations(source) for key, value in list(source.__dict__.items()): if ( key.startswith("__") diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index efe140f10f01..f9b1390d02dc 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -570,10 +570,7 @@ def tree_map_( Type2 = tuple[type[T], type[S]] Type3 = tuple[type[T], type[S], type[U]] -if sys.version_info >= (3, 10): - TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType] -else: - TypeAny = Union[type[Any], tuple[type[Any], ...]] +TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType] Fn2 = Callable[[Union[T, S]], R] Fn3 = Callable[[Union[T, S, U]], R] diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 773e9f00e3d1..ff3ea897e4b1 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -1412,10 +1412,7 @@ def tree_map_( Type2 = tuple[type[T], type[S]] Type3 = tuple[type[T], type[S], type[U]] -if sys.version_info >= (3, 10): - TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType] -else: - TypeAny = Union[type[Any], tuple[type[Any], ...]] +TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType] Fn2 = Callable[[Union[T, S]], R] Fn3 = Callable[[Union[T, S, U]], R] diff --git a/torchgen/utils.py b/torchgen/utils.py index 905d6fd0c0b6..ffb59557233b 100644 --- a/torchgen/utils.py +++ b/torchgen/utils.py @@ -355,12 +355,9 @@ def dataclass_repr( width: int = 80, ) -> str: # built-in pprint module support dataclasses from python 3.10 - if sys.version_info >= (3, 10): - from pprint import pformat + from pprint import pformat - return pformat(obj, indent, width) - - return _pformat(obj, indent=indent, width=width) + return pformat(obj, indent, width) def _pformat(