mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - benchmarks tools torchgen (#145101)
This is one of a series of PRs to update us to PEP585 (changing Dict -> dict, List -> list, etc). Most of the PRs were completely automated with RUFF as follows:
Since RUFF UP006 is considered an "unsafe" fix, first we need to enable unsafe fixes:
```
--- a/tools/linter/adapters/ruff_linter.py
+++ b/tools/linter/adapters/ruff_linter.py
@@ -313,6 +313,7 @@
"ruff",
"check",
"--fix-only",
+ "--unsafe-fixes",
"--exit-zero",
*([f"--config={config}"] if config else []),
"--stdin-filename",
```
Then we need to tell RUFF to allow UP006 (as a final PR, once all of these have landed, this will be made permanent):
```
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@
[tool.ruff]
-target-version = "py38"
+target-version = "py39"
line-length = 88
src = ["caffe2", "torch", "torchgen", "functorch", "test"]
@@ -87,7 +87,6 @@
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
"SIM117",
"SIM118",
- "UP006", # keep-runtime-typing
"UP007", # keep-runtime-typing
]
select = [
```
Finally running `lintrunner -a --take RUFF` will fix up the deprecated uses.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101
Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
2c4281d7da
commit
07669ed960
@@ -6,7 +6,7 @@ from collections import defaultdict
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Optional, Union

 import numpy as np
 from tabulate import tabulate
@@ -46,7 +46,7 @@ class ExperimentConfig:
     dtype: torch.dtype
     calculate_bwd_time: bool
     cal_bandwidth: bool
-    backends: List[str]
+    backends: list[str]

     def __post_init__(self):
         assert (
@@ -80,7 +80,7 @@ class ExperimentResults:
 @dataclass(frozen=True)
 class Experiment:
     config: ExperimentConfig
-    results: Dict[str, ExperimentResults]  # backend -> ExperimentResults
+    results: dict[str, ExperimentResults]  # backend -> ExperimentResults

     def asdict(self):
         dict1 = self.config.asdict()
@@ -357,7 +357,7 @@ def run_single_experiment(
     config: ExperimentConfig,
     dynamic=False,
     max_autotune=False,
-) -> Dict[str, ExperimentResults]:
+) -> dict[str, ExperimentResults]:
     device = torch.device("cuda")
     batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
     query, key, value = generate_inputs(
@@ -504,7 +504,7 @@ def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> fl
     return total_flops / results.fwd_time / 1e6  # in TFLOPs/


-def get_average_speedups(results: List[Experiment], type: str, backend: str):
+def get_average_speedups(results: list[Experiment], type: str, backend: str):
     # Calculate speedups
     speedups = [
         calculate_speedup(r.results["compiled"], r.results[backend], type)
@@ -533,7 +533,7 @@ def get_average_speedups(results: List[Experiment], type: str, backend: str):
     return table_data


-def print_results(results: List[Experiment], save_path: Optional[str] = None):
+def print_results(results: list[Experiment], save_path: Optional[str] = None):
     table_data = defaultdict(list)
     for experiment in results:
         backends = experiment.config.backends + ["compiled"]
@@ -1024,16 +1024,16 @@ def generate_eager_sdpa(
 def generate_experiment_configs(
     calculate_bwd: bool,
     dtype: torch.dtype,
-    batch_sizes: List[int],
-    num_heads: List[tuple[int, int]],
-    seq_lens: List[int],
-    head_dims: List[int],
-    score_mods_str: List[str],
+    batch_sizes: list[int],
+    num_heads: list[tuple[int, int]],
+    seq_lens: list[int],
+    head_dims: list[int],
+    score_mods_str: list[str],
     decoding: bool,
-    kv_cache_size: List[int],
+    kv_cache_size: list[int],
     cal_bandwidth: bool,
-    backends: List[str],
-) -> List[ExperimentConfig]:
+    backends: list[str],
+) -> list[ExperimentConfig]:
     assert not (calculate_bwd and decoding), "Decoding does not support backward"

     if decoding:
Reference in New Issue
Block a user