From d1a62c80363cf769552453eed187e935f905737d Mon Sep 17 00:00:00 2001
From: Aaron Gokaslan
Date: Tue, 7 Oct 2025 22:51:13 +0000
Subject: [PATCH] [BE][Ez]: Enable RUF007 Prefer itertools.pairwise over zip slicing (#164856)

Now that our min version is 3.10 we can support this rule. This is more
concise, readable, and efficient than the previous zip slicing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164856
Approved by: https://github.com/williamwen42
---
 .../cuda/mem_eff_attention/kernels/generate_kernels.py | 4 ++--
 pyproject.toml                                          | 1 +
 test/profiler/test_memory_profiler.py                   | 2 +-
 torch/_dynamo/bytecode_analysis.py                      | 3 ++-
 torch/_dynamo/bytecode_transformation.py                | 2 +-
 torch/_dynamo/output_graph.py                           | 2 +-
 torch/_inductor/fx_passes/split_cat.py                  | 4 ++--
 torch/_inductor/lowering.py                             | 2 +-
 torch/_tensor.py                                        | 3 ++-
 torch/testing/_internal/opinfo/definitions/linalg.py    | 2 +-
 torch/utils/model_dump/__init__.py                      | 3 ++-
 11 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
index 2ef59f42140b..7b83617f643d 100644
--- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
+++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
@@ -117,7 +117,7 @@ class FwdKernel:
     def get_all(cls) -> list["FwdKernel"]:
         kernels: list[FwdKernel] = []
         for aligned, dtype, (sm, sm_max) in itertools.product(
-            [True, False], DTYPES.keys(), zip(SM, SM[1:])
+            [True, False], DTYPES.keys(), itertools.pairwise(SM)
         ):
             # Remove some kernels we don't use
             if dtype == "bf16" and sm < 80:
@@ -228,7 +228,7 @@ class BwdKernel:
         for aligned, dtype, (sm, sm_max), apply_dropout, max_k in itertools.product(
             [True, False],
             DTYPES.keys(),
-            zip(SM, SM[1:]),
+            itertools.pairwise(SM),
             [True, False],
             [32, 64, 128, 2**16],
         ):
diff --git a/pyproject.toml b/pyproject.toml
index 152a210d61eb..8a2823258916 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -242,6 +242,7 @@ select = [
     "Q003", # avoidable escaped quote
     "Q004", # unnecessary escaped quote
     "RSE",
+    "RUF007", # pairwise over zip
     "RUF008", # mutable dataclass default
     "RUF013", # ban implicit optional
     "RUF015", # access first ele in constant time
diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py
index 814074d768eb..f9821d1bf3a2 100644
--- a/test/profiler/test_memory_profiler.py
+++ b/test/profiler/test_memory_profiler.py
@@ -1451,7 +1451,7 @@ class TestMemoryProfilerE2E(TestCase):
         memory_profile = prof._memory_profile()
         timeline = memory_profile.timeline
         times = tuple(t for t, _, _, _ in timeline)
-        self.assertTrue(all(t1 >= t0 for t0, t1 in zip(times, times[1:])), times)
+        self.assertTrue(all(t1 >= t0 for t0, t1 in it.pairwise(times)), times)
         self.assertTrue(
             all(
                 (t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0)
diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py
index 8bdf155e0060..3ccbd56bfada 100644
--- a/torch/_dynamo/bytecode_analysis.py
+++ b/torch/_dynamo/bytecode_analysis.py
@@ -15,6 +15,7 @@ for better performance while maintaining correct semantics.
 import bisect
 import dataclasses
 import dis
+import itertools
 import sys
 from typing import Any, TYPE_CHECKING, Union

@@ -110,7 +111,7 @@ def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instructi
     """Eliminate jumps to the next instruction"""
     pointless_jumps = {
         id(a)
-        for a, b in zip(instructions, instructions[1:])
+        for a, b in itertools.pairwise(instructions)
         if a.opname == "JUMP_ABSOLUTE" and a.target is b
     }
     return [inst for inst in instructions if id(inst) not in pointless_jumps]
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index c4ee0c49a1f3..48d667319a11 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -1239,7 +1239,7 @@ def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> Non

 def remove_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
     new_insts = []
-    for inst, next_inst in zip(instructions, instructions[1:]):
+    for inst, next_inst in itertools.pairwise(instructions):
         if (
             inst.opname == "NOP"
             and inst.argval == "GRAPH_BREAK_IF_LEAF"
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index cb8d67582230..9f0e40ffbf9f 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -1872,7 +1872,7 @@ class OutputGraph(OutputGraphCommon):
             node.meta.pop("creation_timestamp", None)

         grad_enabled = torch.is_grad_enabled()
-        for node1, node2 in zip(nodes, nodes[1:]):
+        for node1, node2 in itertools.pairwise(nodes):
             if (
                 node1.target is torch._C._set_grad_enabled
                 and tuple(node1.args) == (not grad_enabled,)
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index af3631dc3288..899960ac435c 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -824,7 +824,7 @@ class SplitCatSimplifier:
         return split_ranges

     def has_non_overlapping_ranges(self, ranges: list[_Range]) -> bool:
-        for range_, next_range in zip(ranges, ranges[1:]):
+        for range_, next_range in itertools.pairwise(ranges):
             if range_[1] > next_range[0]:
                 return False
         return True
@@ -1477,7 +1477,7 @@ def is_sorted_and_consecutive(arr: list[int]) -> bool:
     # check if the array is sorted
     if arr == sorted(arr):
         # check if the differences between adjacent elements are all 1
-        return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:]))
+        return all(x[1] - x[0] == 1 for x in itertools.pairwise(arr))
     else:
         return False

diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 8dd7319e62f3..77f0f32d54e7 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3700,7 +3700,7 @@ def index_output_size_and_inner_fn(
     # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will
     # be pulled to the front.
     non_consecutive_tensors = False
-    for previous, current in zip(tensor_indices, tensor_indices[1:]):
+    for previous, current in itertools.pairwise(tensor_indices):
         if current - previous != 1:
             non_consecutive_tensors = True

diff --git a/torch/_tensor.py b/torch/_tensor.py
index a07fc65aee0a..52e3a2fda8fb 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -2,6 +2,7 @@
 import copyreg
 import enum
 import functools
+import itertools
 import warnings
 from collections import OrderedDict
 from collections.abc import Callable
@@ -1633,7 +1634,7 @@ class Tensor(torch._C.TensorBase):
         # Check if there are any duplicate strides
         has_duplicate_strides = any(
             guard_or_false(earlier == later)
-            for earlier, later in zip(strides, strides[1:])
+            for earlier, later in itertools.pairwise(strides)
         )

         # Check if there are any singleton dimensions
diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py
index 3e6658741fef..ae5a468ddd6a 100644
--- a/torch/testing/_internal/opinfo/definitions/linalg.py
+++ b/torch/testing/_internal/opinfo/definitions/linalg.py
@@ -293,7 +293,7 @@ def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwar

     for sizes in test_cases:
         tensors = []
-        for size in zip(sizes[:-1], sizes[1:]):
+        for size in itertools.pairwise(sizes):
             t = make_tensor(
                 size, dtype=dtype, device=device, requires_grad=requires_grad
             )
diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py
index 7d6a6890e4ce..dd56877c6cb8 100644
--- a/torch/utils/model_dump/__init__.py
+++ b/torch/utils/model_dump/__init__.py
@@ -66,6 +66,7 @@ Possible improvements:

 import argparse
 import io
+import itertools
 import json
 import os
 import pickle
@@ -280,7 +281,7 @@ def get_model_info(
         debug_info.append((len(raw_code), (('', '', 0), 0, 0)))

         code_parts = []
-        for di, di_next in zip(debug_info, debug_info[1:]):
+        for di, di_next in itertools.pairwise(debug_info):
             start, source_range, *_ = di
             end = di_next[0]
             assert end > start
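
Illustration of the rewrite RUF007 asks for. This is a minimal standalone sketch, not part of the applied diff: the sm_values list is a hypothetical stand-in (the real SM list lives in generate_kernels.py), and itertools.pairwise requires Python 3.10+, which is why the rule can only be enabled now.

    import itertools

    # Hypothetical stand-in data; not the real SM list from generate_kernels.py.
    sm_values = [50, 70, 75, 80]

    # zip slicing builds an extra list via sm_values[1:]; pairwise yields the
    # same adjacent pairs lazily and also works on iterators that cannot be sliced.
    assert list(zip(sm_values, sm_values[1:])) == list(itertools.pairwise(sm_values))

    for sm, sm_max in itertools.pairwise(sm_values):
        print(sm, sm_max)

The same call appears as it.pairwise(times) in test_memory_profiler.py because that file imports itertools under the alias it.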