Update ruff linter for PEP585 (#147540)

This turns on PEP585 enforcement in RUFF.

- Updates the target python version
- Stops ignoring UP006 warnings (PEP585)
- Fixes a few issues which crept into the tree in the last day

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147540
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein
2025-02-21 07:57:30 -08:00
committed by PyTorch MergeBot
parent 77d2780657
commit 086d146f6f
19 changed files with 131 additions and 88 deletions

View File

@ -916,10 +916,13 @@ def latency_experiment(args, model_iter_fn, model, example_inputs, mark, **kwarg
# inputs will incur high penalty then the next one.
maybe_mark_step(args)
with maybe_mark_profile(p=p, mark=mark), maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
with (
maybe_mark_profile(p=p, mark=mark),
maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
),
):
timings[rep], actual_output = timed(
model,
@ -1090,10 +1093,13 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
# call mark_step between the 2 calls to make the comparison fair.
maybe_mark_step(args)
with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
with (
maybe_mark_profile(p=p, mark="actual"),
maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
),
):
timings[rep, 1], actual_output = timed(
model,
@ -2445,12 +2451,15 @@ class BenchmarkRunner:
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
with maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
), maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
with (
maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
),
maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
),
):
dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
optimized_model_iter_fn, model, example_inputs, "dynamo"
@ -2598,12 +2607,15 @@ class BenchmarkRunner:
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
with maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
), maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
with (
maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
),
maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
),
):
dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
optimized_model_iter_fn, model, example_inputs, "dynamo"

View File

@ -56,9 +56,11 @@ class Benchmark(BenchmarkBase):
def _work(self):
# enable_cpp_symbolic_shape_guards has impact on this benchmark
# Keep using False value for consistency.
with fresh_inductor_cache(), torch._inductor.config.patch(
force_shape_pad=self._force_shape_pad
), torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False):
with (
fresh_inductor_cache(),
torch._inductor.config.patch(force_shape_pad=self._force_shape_pad),
torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False),
):
opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(
self.m.cuda() if self._is_gpu else self.m
)

View File

@ -40,7 +40,7 @@ standard_library = ["typing_extensions"]
[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 88
src = ["caffe2", "torch", "torchgen", "functorch", "test"]
@ -85,7 +85,6 @@ ignore = [
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
"SIM117",
"SIM118",
"UP006", # keep-runtime-typing
"UP007", # keep-runtime-typing
]
select = [

View File

@ -12,16 +12,8 @@ import pprint
import sys
import unittest
import warnings
from typing import (
Any,
Callable,
Collection,
Iterable,
Mapping,
Optional,
Sequence,
TypeVar,
)
from collections.abc import Collection, Iterable, Mapping, Sequence
from typing import Any, Callable, Optional, TypeVar
import error_reproduction
import numpy as np

View File

@ -25,7 +25,7 @@ errors.
from __future__ import annotations
import os
from typing import Callable, Optional, Sequence, Tuple, TYPE_CHECKING
from typing import Callable, Optional, TYPE_CHECKING
import error_reproduction
import numpy as np
@ -44,6 +44,7 @@ from torch.utils import _pytree as pytree
if TYPE_CHECKING:
import unittest
from collections.abc import Sequence
from torch.testing._internal.opinfo import core as opinfo_core
@ -73,7 +74,7 @@ def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
def _should_skip_xfail_test_sample(
op_name: str, sample, dtype: torch.dtype, device_type: str
) -> Tuple[Optional[str], Optional[str]]:
) -> tuple[Optional[str], Optional[str]]:
"""Returns a reason if a test sample should be skipped."""
if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
return None, None

View File

@ -513,8 +513,9 @@ def verify(
"could mean that your network is numerically unstable. Otherwise\n"
"it indicates a bug in PyTorch/ONNX; please file a bug report."
)
with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt(
result_hint
with (
Errors(msg, rtol=rtol, atol=atol) as errs,
errs.addErrCtxt(result_hint),
):
for i, (x, y) in enumerate(zip(torch_out, backend_out)):
errs.checkAlmostEqual(x.data.cpu().numpy(), y, f"In output {i}")

View File

@ -580,9 +580,12 @@ class TestDecomp(TestCase):
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
func = partial(op.get_op(), **kwargs)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=False
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=False
) as mode,
enable_python_dispatcher(),
):
torch.autograd.gradcheck(func, args)
self.check_decomposed(aten_name, mode)
@ -677,9 +680,12 @@ class TestDecomp(TestCase):
module_input.forward_input.args,
module_input.forward_input.kwargs,
)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=True
), enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=True
),
enable_python_dispatcher(),
):
decomp_out = m(*args, **kwargs)
non_decomp_out = m(*args, **kwargs)
@ -955,9 +961,12 @@ def forward(self, scores_1, mask_1, value_1):
# store the called list on the mode object instance and no
# explicit clearing is necessary as I will create a fresh mode
# for each region
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode,
enable_python_dispatcher(),
):
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
@ -974,9 +983,12 @@ def forward(self, scores_1, mask_1, value_1):
):
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode,
enable_python_dispatcher(),
):
decomp_vjp_fn(cotangents)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
@ -993,9 +1005,12 @@ def forward(self, scores_1, mask_1, value_1):
kwargs = sample_input.kwargs
# A failure here might be because the decomposition for the op is wrong or because a
# decomposition used by the particular op is wrong.
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode,
enable_python_dispatcher(),
):
func(*args, **kwargs)
if run_without_python_dispatcher(mode):

View File

@ -255,9 +255,11 @@ class TestForeach(TestCase):
else inputs
)
try:
with InplaceForeachVersionBumpCheck(
self, inputs[0]
) if op.is_inplace else nullcontext():
with (
InplaceForeachVersionBumpCheck(self, inputs[0])
if op.is_inplace
else nullcontext()
):
actual = op(inputs, self.is_cuda, is_fastpath)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
@ -278,9 +280,11 @@ class TestForeach(TestCase):
try:
op_kwargs = {}
op_kwargs.update(kwargs)
with InplaceForeachVersionBumpCheck(
self, inputs[0]
) if op.is_inplace else nullcontext():
with (
InplaceForeachVersionBumpCheck(self, inputs[0])
if op.is_inplace
else nullcontext()
):
actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):

View File

@ -2782,8 +2782,10 @@ class TestFakeTensor(TestCase):
with torch._subclasses.CrossRefFakeMode(
ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True
):
with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(
False
with (
warnings.catch_warnings(),
context(),
torch.autograd.set_multithreading_enabled(False),
):
composite_compliance.compute_expected_grads(
op.get_op(),

View File

@ -866,9 +866,10 @@ class TracingContext:
@contextlib.contextmanager
def clear_frame():
tc = TracingContext.get()
with unittest.mock.patch.object(
tc, "frame_summary_stack", []
), unittest.mock.patch.object(tc, "loc_in_frame", None):
with (
unittest.mock.patch.object(tc, "frame_summary_stack", []),
unittest.mock.patch.object(tc, "loc_in_frame", None),
):
try:
yield
except Exception as e:

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from typing import Any, Callable, List, Literal, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
from sympy import Expr, symbols
@ -160,8 +160,8 @@ class CUDATemplateKernel(CUDAKernel):
def __init__(
self,
kernel_name: str,
runtime_arg_info: List["ArgInfo"],
runtime_arg_values: List[Any],
runtime_arg_info: list["ArgInfo"],
runtime_arg_values: list[Any],
) -> None:
"""
Initializes a new instance of the CUDATemplateKernel class.

View File

@ -834,7 +834,9 @@ class OpOverload(OperatorBase):
if curr_mode not in self.python_key_table:
if isinstance(self, TorchBindOpOverload):
with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
with (
torch.utils._python_dispatch._pop_mode_temporarily() as mode
):
return torch._library.utils.handle_dispatch_mode(
mode, self, *args, **kwargs
)

View File

@ -569,9 +569,13 @@ class Exporter:
# https://github.com/pytorch/pytorch/issues/103764
from torch.onnx._internal.fx import decomposition_skip
with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
self.options
), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
with (
self.options.diagnostic_context,
decomposition_skip.enable_decomposition_skips(self.options),
torch._dynamo.config.patch(
dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
),
):
graph_module = self.options.fx_tracer.generate_fx(
self.options, self.model, self.model_args, self.model_kwargs
)

View File

@ -8,7 +8,8 @@ from __future__ import annotations
__all__ = ["onnx_impl", "get_torchlib_ops"]
import logging
from typing import Any, Callable, Sequence, TypeVar
from collections.abc import Sequence
from typing import Any, Callable, TypeVar
import onnxscript

View File

@ -68,7 +68,11 @@ class Decompose(_pass.Transform):
# Apply decomposition table to the input graph.
assert fake_mode is not None # for mypy
with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), fake_mode:
with (
fake_tensor.unset_fake_temporarily(),
python_dispatch.enable_python_dispatcher(),
fake_mode,
):
decomposed_module = proxy_tensor.make_fx(
module,
decomposition_table=self.decomposition_table,

View File

@ -179,13 +179,12 @@ def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
.. deprecated:: 2.7
Please set training mode before exporting the model.
"""
with select_model_mode_for_export(
model, mode
) as mode_ctx, disable_apex_o2_state_dict_hook(
model
) as apex_ctx, setup_onnx_logging(
verbose
) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx:
with (
select_model_mode_for_export(model, mode) as mode_ctx,
disable_apex_o2_state_dict_hook(model) as apex_ctx,
setup_onnx_logging(verbose) as log_ctx,
diagnostics.create_export_diagnostic_context() as diagnostic_ctx,
):
yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx)

View File

@ -1606,9 +1606,12 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
return saved_id[0]
return deserialized_objects[int(saved_id)]
with closing(
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
) as tar, mkdtemp() as tmpdir:
with (
closing(
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
) as tar,
mkdtemp() as tmpdir,
):
if pickle_module is _weights_only_unpickler:
raise RuntimeError(
"Cannot use ``weights_only=True`` with files saved in the "

View File

@ -6,7 +6,8 @@ import os
import contextlib
import torch._logging
import torch._logging._internal
from typing import Callable, ContextManager, List, Tuple
from contextlib import AbstractContextManager
from typing import Callable
from torch._dynamo.utils import LazyString
from torch._inductor import config as inductor_config
import logging
@ -214,7 +215,7 @@ def logs_to_string(module, log_option):
return log_stream, ctx_manager
def multiple_logs_to_string(module: str, *log_options: str) -> Tuple[List[io.StringIO], Callable[[], ContextManager[None]]]:
def multiple_logs_to_string(module: str, *log_options: str) -> tuple[list[io.StringIO], Callable[[], AbstractContextManager[None]]]:
"""Example:
multiple_logs_to_string("torch._inductor.compile_fx", "pre_grad_graphs", "post_grad_graphs")
returns the output of TORCH_LOGS="pre_graph_graphs, post_grad_graphs" from the
@ -234,7 +235,7 @@ def multiple_logs_to_string(module: str, *log_options: str) -> Tuple[List[io.Str
for logger, handler in zip(loggers, handlers):
logger.removeHandler(handler)
def ctx_manager() -> ContextManager[None]:
def ctx_manager() -> AbstractContextManager[None]:
exit_stack = log_settings(", ".join(log_options))
exit_stack.enter_context(tmp_redirect_logs())
return exit_stack # type: ignore[return-value]

View File

@ -10,7 +10,7 @@ from collections.abc import Sequence
# targets fail to typecheck with:
# TypeError: Cannot create a consistent method resolution order (MRO) for
# bases Iterable, Generic
from typing import cast, Generic, Iterable, Optional, TypeVar, Union # noqa: UP006
from typing import cast, Generic, Iterable, Optional, TypeVar, Union # noqa: UP035
from typing_extensions import deprecated
# No 'default_generator' in torch/__init__.pyi