Update ruff linter for PEP585 (#147540)

This turns on PEP585 enforcement in RUFF.

- Updates the target python version
- Stops ignoring UP006 warnings (PEP585)
- Fixes a few issues which crept into the tree in the last day

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147540
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein
2025-02-21 07:57:30 -08:00
committed by PyTorch MergeBot
parent 77d2780657
commit 086d146f6f
19 changed files with 131 additions and 88 deletions

View File

@ -580,9 +580,12 @@ class TestDecomp(TestCase):
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
func = partial(op.get_op(), **kwargs)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=False
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=False
) as mode,
enable_python_dispatcher(),
):
torch.autograd.gradcheck(func, args)
self.check_decomposed(aten_name, mode)
@ -677,9 +680,12 @@ class TestDecomp(TestCase):
module_input.forward_input.args,
module_input.forward_input.kwargs,
)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=True
), enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=True
),
enable_python_dispatcher(),
):
decomp_out = m(*args, **kwargs)
non_decomp_out = m(*args, **kwargs)
@ -955,9 +961,12 @@ def forward(self, scores_1, mask_1, value_1):
# store the called list on the mode object instance and no
# explicit clearing is necessary as I will create a fresh mode
# for each region
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode,
enable_python_dispatcher(),
):
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
@ -974,9 +983,12 @@ def forward(self, scores_1, mask_1, value_1):
):
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode,
enable_python_dispatcher(),
):
decomp_vjp_fn(cotangents)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
@ -993,9 +1005,12 @@ def forward(self, scores_1, mask_1, value_1):
kwargs = sample_input.kwargs
# A failure here might be because the decomposition for the op is wrong or because a
# decomposition used by the particular op is wrong.
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode,
enable_python_dispatcher(),
):
func(*args, **kwargs)
if run_without_python_dispatcher(mode):