mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE][Ez]: Enable RUF007 Prefer itertools.pairwise over zip slicing (#164856)
Now that our min version is 3.10 we can support this rule. This is more concise, readable, and efficient than the previous zip slicing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164856 Approved by: https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
6861a27062
commit
d1a62c8036
@ -117,7 +117,7 @@ class FwdKernel:
|
||||
def get_all(cls) -> list["FwdKernel"]:
|
||||
kernels: list[FwdKernel] = []
|
||||
for aligned, dtype, (sm, sm_max) in itertools.product(
|
||||
[True, False], DTYPES.keys(), zip(SM, SM[1:])
|
||||
[True, False], DTYPES.keys(), itertools.pairwise(SM)
|
||||
):
|
||||
# Remove some kernels we don't use
|
||||
if dtype == "bf16" and sm < 80:
|
||||
@ -228,7 +228,7 @@ class BwdKernel:
|
||||
for aligned, dtype, (sm, sm_max), apply_dropout, max_k in itertools.product(
|
||||
[True, False],
|
||||
DTYPES.keys(),
|
||||
zip(SM, SM[1:]),
|
||||
itertools.pairwise(SM),
|
||||
[True, False],
|
||||
[32, 64, 128, 2**16],
|
||||
):
|
||||
|
@ -242,6 +242,7 @@ select = [
|
||||
"Q003", # avoidable escaped quote
|
||||
"Q004", # unnecessary escaped quote
|
||||
"RSE",
|
||||
"RUF007", # pairwise over zip
|
||||
"RUF008", # mutable dataclass default
|
||||
"RUF013", # ban implicit optional
|
||||
"RUF015", # access first ele in constant time
|
||||
|
@ -1451,7 +1451,7 @@ class TestMemoryProfilerE2E(TestCase):
|
||||
memory_profile = prof._memory_profile()
|
||||
timeline = memory_profile.timeline
|
||||
times = tuple(t for t, _, _, _ in timeline)
|
||||
self.assertTrue(all(t1 >= t0 for t0, t1 in zip(times, times[1:])), times)
|
||||
self.assertTrue(all(t1 >= t0 for t0, t1 in it.pairwise(times)), times)
|
||||
self.assertTrue(
|
||||
all(
|
||||
(t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0)
|
||||
|
@ -15,6 +15,7 @@ for better performance while maintaining correct semantics.
|
||||
import bisect
|
||||
import dataclasses
|
||||
import dis
|
||||
import itertools
|
||||
import sys
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
|
||||
@ -110,7 +111,7 @@ def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instructi
|
||||
"""Eliminate jumps to the next instruction"""
|
||||
pointless_jumps = {
|
||||
id(a)
|
||||
for a, b in zip(instructions, instructions[1:])
|
||||
for a, b in itertools.pairwise(instructions)
|
||||
if a.opname == "JUMP_ABSOLUTE" and a.target is b
|
||||
}
|
||||
return [inst for inst in instructions if id(inst) not in pointless_jumps]
|
||||
|
@ -1239,7 +1239,7 @@ def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> Non
|
||||
|
||||
def remove_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
|
||||
new_insts = []
|
||||
for inst, next_inst in zip(instructions, instructions[1:]):
|
||||
for inst, next_inst in itertools.pairwise(instructions):
|
||||
if (
|
||||
inst.opname == "NOP"
|
||||
and inst.argval == "GRAPH_BREAK_IF_LEAF"
|
||||
|
@ -1872,7 +1872,7 @@ class OutputGraph(OutputGraphCommon):
|
||||
node.meta.pop("creation_timestamp", None)
|
||||
|
||||
grad_enabled = torch.is_grad_enabled()
|
||||
for node1, node2 in zip(nodes, nodes[1:]):
|
||||
for node1, node2 in itertools.pairwise(nodes):
|
||||
if (
|
||||
node1.target is torch._C._set_grad_enabled
|
||||
and tuple(node1.args) == (not grad_enabled,)
|
||||
|
@ -824,7 +824,7 @@ class SplitCatSimplifier:
|
||||
return split_ranges
|
||||
|
||||
def has_non_overlapping_ranges(self, ranges: list[_Range]) -> bool:
|
||||
for range_, next_range in zip(ranges, ranges[1:]):
|
||||
for range_, next_range in itertools.pairwise(ranges):
|
||||
if range_[1] > next_range[0]:
|
||||
return False
|
||||
return True
|
||||
@ -1477,7 +1477,7 @@ def is_sorted_and_consecutive(arr: list[int]) -> bool:
|
||||
# check if the array is sorted
|
||||
if arr == sorted(arr):
|
||||
# check if the differences between adjacent elements are all 1
|
||||
return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:]))
|
||||
return all(x[1] - x[0] == 1 for x in itertools.pairwise(arr))
|
||||
else:
|
||||
return False
|
||||
|
||||
|
@ -3700,7 +3700,7 @@ def index_output_size_and_inner_fn(
|
||||
# Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will
|
||||
# be pulled to the front.
|
||||
non_consecutive_tensors = False
|
||||
for previous, current in zip(tensor_indices, tensor_indices[1:]):
|
||||
for previous, current in itertools.pairwise(tensor_indices):
|
||||
if current - previous != 1:
|
||||
non_consecutive_tensors = True
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
import copyreg
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
@ -1633,7 +1634,7 @@ class Tensor(torch._C.TensorBase):
|
||||
# Check if there are any duplicate strides
|
||||
has_duplicate_strides = any(
|
||||
guard_or_false(earlier == later)
|
||||
for earlier, later in zip(strides, strides[1:])
|
||||
for earlier, later in itertools.pairwise(strides)
|
||||
)
|
||||
|
||||
# Check if there are any singleton dimensions
|
||||
|
@ -293,7 +293,7 @@ def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwar
|
||||
|
||||
for sizes in test_cases:
|
||||
tensors = []
|
||||
for size in zip(sizes[:-1], sizes[1:]):
|
||||
for size in itertools.pairwise(sizes):
|
||||
t = make_tensor(
|
||||
size, dtype=dtype, device=device, requires_grad=requires_grad
|
||||
)
|
||||
|
@ -66,6 +66,7 @@ Possible improvements:
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
@ -280,7 +281,7 @@ def get_model_info(
|
||||
debug_info.append((len(raw_code), (('', '', 0), 0, 0)))
|
||||
|
||||
code_parts = []
|
||||
for di, di_next in zip(debug_info, debug_info[1:]):
|
||||
for di, di_next in itertools.pairwise(debug_info):
|
||||
start, source_range, *_ = di
|
||||
end = di_next[0]
|
||||
assert end > start
|
||||
|
Reference in New Issue
Block a user