[BE][Ez]: Enable RUF007 Prefer itertools.pairwise over zip slicing (#164856)

Now that our minimum supported Python version is 3.10, we can enable this rule. `itertools.pairwise` is more concise, readable, and efficient than the previous `zip` slicing.
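For reference, a minimal sketch (not part of this diff) of the equivalence the rule relies on; the `SM` values below are made up for illustration:

```python
import itertools

SM = [50, 70, 75, 80]  # illustrative values only, not the real kernel list

# Old idiom: slicing copies the sequence and only works on sequences.
old_pairs = list(zip(SM, SM[1:]))

# New idiom (Python >= 3.10): lazy, no intermediate copy, works on any iterable.
new_pairs = list(itertools.pairwise(SM))

assert old_pairs == new_pairs == [(50, 70), (70, 75), (75, 80)]

# pairwise also accepts one-shot iterators, where the zip-slicing idiom cannot be used:
squares = (n * n for n in range(4))
print(list(itertools.pairwise(squares)))  # [(0, 1), (1, 4), (4, 9)]
```

Both forms yield an empty result for inputs of length 0 or 1, so the replacement is behavior-preserving.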

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164856
Approved by: https://github.com/williamwen42
Author: Aaron Gokaslan
Date: 2025-10-07 22:51:13 +00:00
Committed by: PyTorch MergeBot
Parent: 6861a27062
Commit: d1a62c8036
11 changed files with 16 additions and 12 deletions


@@ -117,7 +117,7 @@ class FwdKernel:
def get_all(cls) -> list["FwdKernel"]:
kernels: list[FwdKernel] = []
for aligned, dtype, (sm, sm_max) in itertools.product(
- [True, False], DTYPES.keys(), zip(SM, SM[1:])
+ [True, False], DTYPES.keys(), itertools.pairwise(SM)
):
# Remove some kernels we don't use
if dtype == "bf16" and sm < 80:
@@ -228,7 +228,7 @@ class BwdKernel:
for aligned, dtype, (sm, sm_max), apply_dropout, max_k in itertools.product(
[True, False],
DTYPES.keys(),
- zip(SM, SM[1:]),
+ itertools.pairwise(SM),
[True, False],
[32, 64, 128, 2**16],
):


@@ -242,6 +242,7 @@ select = [
"Q003", # avoidable escaped quote
"Q004", # unnecessary escaped quote
"RSE",
"RUF007", # pairwise over zip
"RUF008", # mutable dataclass default
"RUF013", # ban implicit optional
"RUF015", # access first ele in constant time


@@ -1451,7 +1451,7 @@ class TestMemoryProfilerE2E(TestCase):
memory_profile = prof._memory_profile()
timeline = memory_profile.timeline
times = tuple(t for t, _, _, _ in timeline)
- self.assertTrue(all(t1 >= t0 for t0, t1 in zip(times, times[1:])), times)
+ self.assertTrue(all(t1 >= t0 for t0, t1 in it.pairwise(times)), times)
self.assertTrue(
all(
(t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0)


@@ -15,6 +15,7 @@ for better performance while maintaining correct semantics.
import bisect
import dataclasses
import dis
+ import itertools
import sys
from typing import Any, TYPE_CHECKING, Union
@@ -110,7 +111,7 @@ def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instruction"]:
"""Eliminate jumps to the next instruction"""
pointless_jumps = {
id(a)
- for a, b in zip(instructions, instructions[1:])
+ for a, b in itertools.pairwise(instructions)
if a.opname == "JUMP_ABSOLUTE" and a.target is b
}
return [inst for inst in instructions if id(inst) not in pointless_jumps]


@@ -1239,7 +1239,7 @@ def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
def remove_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
new_insts = []
- for inst, next_inst in zip(instructions, instructions[1:]):
+ for inst, next_inst in itertools.pairwise(instructions):
if (
inst.opname == "NOP"
and inst.argval == "GRAPH_BREAK_IF_LEAF"


@@ -1872,7 +1872,7 @@ class OutputGraph(OutputGraphCommon):
node.meta.pop("creation_timestamp", None)
grad_enabled = torch.is_grad_enabled()
- for node1, node2 in zip(nodes, nodes[1:]):
+ for node1, node2 in itertools.pairwise(nodes):
if (
node1.target is torch._C._set_grad_enabled
and tuple(node1.args) == (not grad_enabled,)


@@ -824,7 +824,7 @@ class SplitCatSimplifier:
return split_ranges
def has_non_overlapping_ranges(self, ranges: list[_Range]) -> bool:
- for range_, next_range in zip(ranges, ranges[1:]):
+ for range_, next_range in itertools.pairwise(ranges):
if range_[1] > next_range[0]:
return False
return True
@@ -1477,7 +1477,7 @@ def is_sorted_and_consecutive(arr: list[int]) -> bool:
# check if the array is sorted
if arr == sorted(arr):
# check if the differences between adjacent elements are all 1
- return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:]))
+ return all(x[1] - x[0] == 1 for x in itertools.pairwise(arr))
else:
return False


@@ -3700,7 +3700,7 @@ def index_output_size_and_inner_fn(
# Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will
# be pulled to the front.
non_consecutive_tensors = False
- for previous, current in zip(tensor_indices, tensor_indices[1:]):
+ for previous, current in itertools.pairwise(tensor_indices):
if current - previous != 1:
non_consecutive_tensors = True


@@ -2,6 +2,7 @@
import copyreg
import enum
import functools
+ import itertools
import warnings
from collections import OrderedDict
from collections.abc import Callable
@@ -1633,7 +1634,7 @@ class Tensor(torch._C.TensorBase):
# Check if there are any duplicate strides
has_duplicate_strides = any(
guard_or_false(earlier == later)
- for earlier, later in zip(strides, strides[1:])
+ for earlier, later in itertools.pairwise(strides)
)
# Check if there are any singleton dimensions


@@ -293,7 +293,7 @@ def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
for sizes in test_cases:
tensors = []
- for size in zip(sizes[:-1], sizes[1:]):
+ for size in itertools.pairwise(sizes):
t = make_tensor(
size, dtype=dtype, device=device, requires_grad=requires_grad
)


@@ -66,6 +66,7 @@ Possible improvements:
import argparse
import io
+ import itertools
import json
import os
import pickle
@@ -280,7 +281,7 @@ def get_model_info(
debug_info.append((len(raw_code), (('', '', 0), 0, 0)))
code_parts = []
- for di, di_next in zip(debug_info, debug_info[1:]):
+ for di, di_next in itertools.pairwise(debug_info):
start, source_range, *_ = di
end = di_next[0]
assert end > start