Update ruff linter for PEP585 (#147540)

This turns on PEP 585 enforcement in ruff.

- Updates the target Python version
- Stops ignoring UP006 warnings (PEP 585)
- Fixes a few issues that crept into the tree in the last day
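
For context: PEP 585 (Python 3.9) made the builtin containers usable as generic types directly, so most `typing` aliases become unnecessary. A minimal sketch of the rewrite UP006 enforces — illustrative only; the names below are made up and not taken from this diff:

    # Before: PEP 484 aliases, now flagged by UP006
    from typing import Dict, List

    def bucket(xs: List[int]) -> Dict[int, List[int]]:
        out: Dict[int, List[int]] = {}
        for x in xs:
            out.setdefault(x % 2, []).append(x)
        return out

    # After: PEP 585 builtin generics, no typing import needed
    def bucket(xs: list[int]) -> dict[int, list[int]]:
        out: dict[int, list[int]] = {}
        for x in xs:
            out.setdefault(x % 2, []).append(x)
        return out

Abstract types such as `Sequence`, `Iterable`, and `Mapping` move to `collections.abc` instead of dropping to builtins (that rewrite is UP035, which is why the final hunk below changes a `# noqa: UP006` to `# noqa: UP035`).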

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147540
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
commit 086d146f6f
parent 77d2780657
Author:     Aaron Orenstein
Date:       2025-02-21 07:57:30 -08:00
Committed:  PyTorch MergeBot

19 changed files with 131 additions and 88 deletions

View File

@@ -916,10 +916,13 @@ def latency_experiment(args, model_iter_fn, model, example_inputs, mark, **kwarg
         # inputs will incur high penalty then the next one.
         maybe_mark_step(args)
-        with maybe_mark_profile(p=p, mark=mark), maybe_enable_compiled_autograd(
-            args.compiled_autograd,
-            fullgraph=args.nopython,
-            dynamic=args.dynamic_shapes,
+        with (
+            maybe_mark_profile(p=p, mark=mark),
+            maybe_enable_compiled_autograd(
+                args.compiled_autograd,
+                fullgraph=args.nopython,
+                dynamic=args.dynamic_shapes,
+            ),
         ):
             timings[rep], actual_output = timed(
                 model,
@@ -1090,10 +1093,13 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
         # call mark_step between the 2 calls to make the comparison fair.
         maybe_mark_step(args)
-        with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
-            args.compiled_autograd,
-            fullgraph=args.nopython,
-            dynamic=args.dynamic_shapes,
+        with (
+            maybe_mark_profile(p=p, mark="actual"),
+            maybe_enable_compiled_autograd(
+                args.compiled_autograd,
+                fullgraph=args.nopython,
+                dynamic=args.dynamic_shapes,
+            ),
         ):
             timings[rep, 1], actual_output = timed(
                 model,
@@ -2445,12 +2451,15 @@ class BenchmarkRunner:
         else:
             optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
-        with maybe_enable_compiled_autograd(
-            self.args.compiled_autograd,
-            fullgraph=self.args.nopython,
-            dynamic=self.args.dynamic_shapes,
-        ), maybe_snapshot_memory(
-            self.args.snapshot_memory, f"compiled_{self.args.only}"
+        with (
+            maybe_enable_compiled_autograd(
+                self.args.compiled_autograd,
+                fullgraph=self.args.nopython,
+                dynamic=self.args.dynamic_shapes,
+            ),
+            maybe_snapshot_memory(
+                self.args.snapshot_memory, f"compiled_{self.args.only}"
+            ),
         ):
             dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
                 optimized_model_iter_fn, model, example_inputs, "dynamo"
@@ -2598,12 +2607,15 @@ class BenchmarkRunner:
         else:
             optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
-        with maybe_enable_compiled_autograd(
-            self.args.compiled_autograd,
-            fullgraph=self.args.nopython,
-            dynamic=self.args.dynamic_shapes,
-        ), maybe_snapshot_memory(
-            self.args.snapshot_memory, f"compiled_{self.args.only}"
+        with (
+            maybe_enable_compiled_autograd(
+                self.args.compiled_autograd,
+                fullgraph=self.args.nopython,
+                dynamic=self.args.dynamic_shapes,
+            ),
+            maybe_snapshot_memory(
+                self.args.snapshot_memory, f"compiled_{self.args.only}"
+            ),
         ):
             dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
                 optimized_model_iter_fn, model, example_inputs, "dynamo"
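
Most of the mechanical churn in this commit is stacked context managers being reflowed into the parenthesized `with` form, as in the hunks above. This appears to be a knock-on effect of the py39 target: formatters only emit that grammar once the minimum supported CPython accepts it (CPython's parser has allowed parenthesized context managers since 3.9, though the syntax was only standardized in 3.10). A small self-contained sketch of the two equivalent spellings, using a made-up `tag` helper:

    import contextlib

    @contextlib.contextmanager
    def tag(name):
        # hypothetical context manager, for illustration only
        print(f"enter {name}")
        try:
            yield
        finally:
            print(f"exit {name}")

    # Old spelling: continuation hangs off an open call parenthesis
    with tag("a"), tag("b"):
        pass

    # New spelling: one pair of parentheses groups all managers
    with (
        tag("a"),
        tag("b"),
    ):
        pass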

View File

@@ -56,9 +56,11 @@ class Benchmark(BenchmarkBase):
     def _work(self):
         # enable_cpp_symbolic_shape_guards has impact on this benchmark
         # Keep using False value for consistency.
-        with fresh_inductor_cache(), torch._inductor.config.patch(
-            force_shape_pad=self._force_shape_pad
-        ), torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False):
+        with (
+            fresh_inductor_cache(),
+            torch._inductor.config.patch(force_shape_pad=self._force_shape_pad),
+            torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False),
+        ):
             opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(
                 self.m.cuda() if self._is_gpu else self.m
             )

View File

@@ -40,7 +40,7 @@ standard_library = ["typing_extensions"]

 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 88
 src = ["caffe2", "torch", "torchgen", "functorch", "test"]
@@ -85,7 +85,6 @@ ignore = [
     "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
     "SIM117",
     "SIM118",
-    "UP006", # keep-runtime-typing
     "UP007", # keep-runtime-typing
 ]
 select = [
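
The `target-version` bump matters here because ruff only suggests builtin generics when the minimum supported interpreter can evaluate them at runtime. A quick sketch of the behavioral difference:

    # Python 3.9+ (PEP 585): subscripting builtins is valid at runtime
    x: list[int] = [1, 2, 3]
    print(list[int])  # prints "list[int]", a types.GenericAlias

    # Python 3.8: the same subscription raises when evaluated:
    #   TypeError: 'type' object is not subscriptable

(Annotations behind `from __future__ import annotations` are never evaluated, which is why some files could already use the new spelling under the old py38 target.)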

View File

@@ -12,16 +12,8 @@ import pprint
 import sys
 import unittest
 import warnings
-from typing import (
-    Any,
-    Callable,
-    Collection,
-    Iterable,
-    Mapping,
-    Optional,
-    Sequence,
-    TypeVar,
-)
+from collections.abc import Collection, Iterable, Mapping, Sequence
+from typing import Any, Callable, Optional, TypeVar

 import error_reproduction
 import numpy as np

View File

@@ -25,7 +25,7 @@ errors.
 from __future__ import annotations

 import os
-from typing import Callable, Optional, Sequence, Tuple, TYPE_CHECKING
+from typing import Callable, Optional, TYPE_CHECKING

 import error_reproduction
 import numpy as np
@@ -44,6 +44,7 @@ from torch.utils import _pytree as pytree

 if TYPE_CHECKING:
     import unittest
+    from collections.abc import Sequence

     from torch.testing._internal.opinfo import core as opinfo_core
@@ -73,7 +74,7 @@ def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
 def _should_skip_xfail_test_sample(
     op_name: str, sample, dtype: torch.dtype, device_type: str
-) -> Tuple[Optional[str], Optional[str]]:
+) -> tuple[Optional[str], Optional[str]]:
     """Returns a reason if a test sample should be skipped."""
     if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
         return None, None

View File

@@ -513,8 +513,9 @@ def verify(
         "could mean that your network is numerically unstable. Otherwise\n"
         "it indicates a bug in PyTorch/ONNX; please file a bug report."
     )
-    with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt(
-        result_hint
+    with (
+        Errors(msg, rtol=rtol, atol=atol) as errs,
+        errs.addErrCtxt(result_hint),
     ):
         for i, (x, y) in enumerate(zip(torch_out, backend_out)):
             errs.checkAlmostEqual(x.data.cpu().numpy(), y, f"In output {i}")

View File

@@ -580,9 +580,12 @@ class TestDecomp(TestCase):
                 args = [sample_input.input] + list(sample_input.args)
                 kwargs = sample_input.kwargs
                 func = partial(op.get_op(), **kwargs)
-                with self.DecompCrossRefMode(
-                    self, self.precision, self.rel_tol, dtype, run_all=False
-                ) as mode, enable_python_dispatcher():
+                with (
+                    self.DecompCrossRefMode(
+                        self, self.precision, self.rel_tol, dtype, run_all=False
+                    ) as mode,
+                    enable_python_dispatcher(),
+                ):
                     torch.autograd.gradcheck(func, args)

                 self.check_decomposed(aten_name, mode)
@@ -677,9 +680,12 @@ class TestDecomp(TestCase):
                 module_input.forward_input.args,
                 module_input.forward_input.kwargs,
             )
-            with self.DecompCrossRefMode(
-                self, self.precision, self.rel_tol, dtype, run_all=True
-            ), enable_python_dispatcher():
+            with (
+                self.DecompCrossRefMode(
+                    self, self.precision, self.rel_tol, dtype, run_all=True
+                ),
+                enable_python_dispatcher(),
+            ):
                 decomp_out = m(*args, **kwargs)

             non_decomp_out = m(*args, **kwargs)
@@ -955,9 +961,12 @@ def forward(self, scores_1, mask_1, value_1):
             # store the called list on the mode object instance and no
             # explicit clearing is necessary as I will create a fresh mode
             # for each region
-            with self.DecompCrossRefMode(
-                self, self.precision, self.rel_tol, dtype, run_all
-            ) as mode, enable_python_dispatcher():
+            with (
+                self.DecompCrossRefMode(
+                    self, self.precision, self.rel_tol, dtype, run_all
+                ) as mode,
+                enable_python_dispatcher(),
+            ):
                 decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
             if run_without_python_dispatcher(mode):
                 # without this check, incorrect decomps at the python dispatcher level can still pass because
@@ -974,9 +983,12 @@ def forward(self, scores_1, mask_1, value_1):
             ):
                 cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

-                with self.DecompCrossRefMode(
-                    self, self.precision, self.rel_tol, dtype, run_all
-                ) as mode, enable_python_dispatcher():
+                with (
+                    self.DecompCrossRefMode(
+                        self, self.precision, self.rel_tol, dtype, run_all
+                    ) as mode,
+                    enable_python_dispatcher(),
+                ):
                     decomp_vjp_fn(cotangents)
                 if run_without_python_dispatcher(mode):
                     # without this check, incorrect decomps at the python dispatcher level can still pass because
@@ -993,9 +1005,12 @@ def forward(self, scores_1, mask_1, value_1):
             kwargs = sample_input.kwargs
             # A failure here might be because the decomposition for the op is wrong or because a
             # decomposition used by the particular op is wrong.
-            with self.DecompCrossRefMode(
-                self, self.precision, self.rel_tol, dtype, run_all
-            ) as mode, enable_python_dispatcher():
+            with (
+                self.DecompCrossRefMode(
+                    self, self.precision, self.rel_tol, dtype, run_all
+                ) as mode,
+                enable_python_dispatcher(),
+            ):
                 func(*args, **kwargs)

             if run_without_python_dispatcher(mode):

View File

@@ -255,9 +255,11 @@ class TestForeach(TestCase):
                 else inputs
             )
             try:
-                with InplaceForeachVersionBumpCheck(
-                    self, inputs[0]
-                ) if op.is_inplace else nullcontext():
+                with (
+                    InplaceForeachVersionBumpCheck(self, inputs[0])
+                    if op.is_inplace
+                    else nullcontext()
+                ):
                     actual = op(inputs, self.is_cuda, is_fastpath)
             except RuntimeError as e:
                 with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
@@ -278,9 +280,11 @@ class TestForeach(TestCase):
             try:
                 op_kwargs = {}
                 op_kwargs.update(kwargs)
-                with InplaceForeachVersionBumpCheck(
-                    self, inputs[0]
-                ) if op.is_inplace else nullcontext():
+                with (
+                    InplaceForeachVersionBumpCheck(self, inputs[0])
+                    if op.is_inplace
+                    else nullcontext()
+                ):
                     actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
             except RuntimeError as e:
                 with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):

View File

@@ -2782,8 +2782,10 @@ class TestFakeTensor(TestCase):
         with torch._subclasses.CrossRefFakeMode(
             ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True
         ):
-            with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(
-                False
+            with (
+                warnings.catch_warnings(),
+                context(),
+                torch.autograd.set_multithreading_enabled(False),
             ):
                 composite_compliance.compute_expected_grads(
                     op.get_op(),

View File

@@ -866,9 +866,10 @@ class TracingContext:
     @contextlib.contextmanager
     def clear_frame():
         tc = TracingContext.get()
-        with unittest.mock.patch.object(
-            tc, "frame_summary_stack", []
-        ), unittest.mock.patch.object(tc, "loc_in_frame", None):
+        with (
+            unittest.mock.patch.object(tc, "frame_summary_stack", []),
+            unittest.mock.patch.object(tc, "loc_in_frame", None),
+        ):
             try:
                 yield
             except Exception as e:

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, List, Literal, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union

 from sympy import Expr, symbols
@@ -160,8 +160,8 @@ class CUDATemplateKernel(CUDAKernel):
     def __init__(
         self,
         kernel_name: str,
-        runtime_arg_info: List["ArgInfo"],
-        runtime_arg_values: List[Any],
+        runtime_arg_info: list["ArgInfo"],
+        runtime_arg_values: list[Any],
     ) -> None:
         """
         Initializes a new instance of the CUDATemplateKernel class.

View File

@@ -834,7 +834,9 @@ class OpOverload(OperatorBase):
         if curr_mode not in self.python_key_table:
             if isinstance(self, TorchBindOpOverload):
-                with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
+                with (
+                    torch.utils._python_dispatch._pop_mode_temporarily() as mode
+                ):
                     return torch._library.utils.handle_dispatch_mode(
                         mode, self, *args, **kwargs
                     )

View File

@@ -569,9 +569,13 @@ class Exporter:
             # https://github.com/pytorch/pytorch/issues/103764
            from torch.onnx._internal.fx import decomposition_skip

-            with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
-                self.options
-            ), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
+            with (
+                self.options.diagnostic_context,
+                decomposition_skip.enable_decomposition_skips(self.options),
+                torch._dynamo.config.patch(
+                    dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
+                ),
+            ):
                 graph_module = self.options.fx_tracer.generate_fx(
                     self.options, self.model, self.model_args, self.model_kwargs
                 )

View File

@@ -8,7 +8,8 @@ from __future__ import annotations
 __all__ = ["onnx_impl", "get_torchlib_ops"]

 import logging
-from typing import Any, Callable, Sequence, TypeVar
+from collections.abc import Sequence
+from typing import Any, Callable, TypeVar

 import onnxscript

View File

@@ -68,7 +68,11 @@ class Decompose(_pass.Transform):
         # Apply decomposition table to the input graph.
         assert fake_mode is not None  # for mypy

-        with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), fake_mode:
+        with (
+            fake_tensor.unset_fake_temporarily(),
+            python_dispatch.enable_python_dispatcher(),
+            fake_mode,
+        ):
             decomposed_module = proxy_tensor.make_fx(
                 module,
                 decomposition_table=self.decomposition_table,

View File

@@ -179,13 +179,12 @@ def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
     .. deprecated:: 2.7
         Please set training mode before exporting the model.
     """
-    with select_model_mode_for_export(
-        model, mode
-    ) as mode_ctx, disable_apex_o2_state_dict_hook(
-        model
-    ) as apex_ctx, setup_onnx_logging(
-        verbose
-    ) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx:
+    with (
+        select_model_mode_for_export(model, mode) as mode_ctx,
+        disable_apex_o2_state_dict_hook(model) as apex_ctx,
+        setup_onnx_logging(verbose) as log_ctx,
+        diagnostics.create_export_diagnostic_context() as diagnostic_ctx,
+    ):
         yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx)

View File

@@ -1606,9 +1606,12 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
             return saved_id[0]
         return deserialized_objects[int(saved_id)]

-    with closing(
-        tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
-    ) as tar, mkdtemp() as tmpdir:
+    with (
+        closing(
+            tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
+        ) as tar,
+        mkdtemp() as tmpdir,
+    ):
         if pickle_module is _weights_only_unpickler:
             raise RuntimeError(
                 "Cannot use ``weights_only=True`` with files saved in the "

View File

@@ -6,7 +6,8 @@ import os
 import contextlib
 import torch._logging
 import torch._logging._internal
-from typing import Callable, ContextManager, List, Tuple
+from contextlib import AbstractContextManager
+from typing import Callable
 from torch._dynamo.utils import LazyString
 from torch._inductor import config as inductor_config
 import logging
@@ -214,7 +215,7 @@ def logs_to_string(module, log_option):
     return log_stream, ctx_manager


-def multiple_logs_to_string(module: str, *log_options: str) -> Tuple[List[io.StringIO], Callable[[], ContextManager[None]]]:
+def multiple_logs_to_string(module: str, *log_options: str) -> tuple[list[io.StringIO], Callable[[], AbstractContextManager[None]]]:
     """Example:
     multiple_logs_to_string("torch._inductor.compile_fx", "pre_grad_graphs", "post_grad_graphs")
     returns the output of TORCH_LOGS="pre_graph_graphs, post_grad_graphs" from the
@@ -234,7 +235,7 @@ def multiple_logs_to_string(module: str, *log_options: str) -> Tuple[List[io.Str
     for logger, handler in zip(loggers, handlers):
         logger.removeHandler(handler)

-    def ctx_manager() -> ContextManager[None]:
+    def ctx_manager() -> AbstractContextManager[None]:
         exit_stack = log_settings(", ".join(log_options))
         exit_stack.enter_context(tmp_redirect_logs())
         return exit_stack  # type: ignore[return-value]

View File

@@ -10,7 +10,7 @@ from collections.abc import Sequence
 # targets fail to typecheck with:
 #   TypeError: Cannot create a consistent method resolution order (MRO) for
 #   bases Iterable, Generic
-from typing import cast, Generic, Iterable, Optional, TypeVar, Union  # noqa: UP006
+from typing import cast, Generic, Iterable, Optional, TypeVar, Union  # noqa: UP035
 from typing_extensions import deprecated

 # No 'default_generator' in torch/__init__.pyi