PEP585 update - benchmarks tools torchgen (#145101)

This is one of a series of PRs to update the codebase to PEP 585 (changing Dict -> dict, List -> list, etc.). Most of the PRs were completely automated with RUFF as follows:

Since RUFF UP006 is considered an "unsafe" fix, we first need to enable unsafe fixes:

```
--- a/tools/linter/adapters/ruff_linter.py
+++ b/tools/linter/adapters/ruff_linter.py
@@ -313,6 +313,7 @@
                     "ruff",
                     "check",
                     "--fix-only",
+                    "--unsafe-fixes",
                     "--exit-zero",
                     *([f"--config={config}"] if config else []),
                     "--stdin-filename",
```

Then we need to tell RUFF to allow UP006 (this will be made permanent in a final PR once all of these have landed):

```
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@

 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 88
 src = ["caffe2", "torch", "torchgen", "functorch", "test"]

@@ -87,7 +87,6 @@
     "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
     "SIM117",
     "SIM118",
-    "UP006", # keep-runtime-typing
     "UP007", # keep-runtime-typing
 ]
 select = [
```

Finally, running `lintrunner -a --take RUFF` will fix up the deprecated uses.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-17 15:13:25 -08:00
committed by PyTorch MergeBot
parent 2c4281d7da
commit 07669ed960
44 changed files with 227 additions and 240 deletions

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Optional, Union
import numpy as np
from tabulate import tabulate
@ -46,7 +46,7 @@ class ExperimentConfig:
dtype: torch.dtype
calculate_bwd_time: bool
cal_bandwidth: bool
backends: List[str]
backends: list[str]
def __post_init__(self):
assert (
@ -80,7 +80,7 @@ class ExperimentResults:
@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
results: Dict[str, ExperimentResults] # backend -> ExperimentResults
results: dict[str, ExperimentResults] # backend -> ExperimentResults
def asdict(self):
dict1 = self.config.asdict()
@ -357,7 +357,7 @@ def run_single_experiment(
config: ExperimentConfig,
dynamic=False,
max_autotune=False,
) -> Dict[str, ExperimentResults]:
) -> dict[str, ExperimentResults]:
device = torch.device("cuda")
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
query, key, value = generate_inputs(
@ -504,7 +504,7 @@ def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> fl
return total_flops / results.fwd_time / 1e6 # in TFLOPs/
def get_average_speedups(results: List[Experiment], type: str, backend: str):
def get_average_speedups(results: list[Experiment], type: str, backend: str):
# Calculate speedups
speedups = [
calculate_speedup(r.results["compiled"], r.results[backend], type)
@ -533,7 +533,7 @@ def get_average_speedups(results: List[Experiment], type: str, backend: str):
return table_data
def print_results(results: List[Experiment], save_path: Optional[str] = None):
def print_results(results: list[Experiment], save_path: Optional[str] = None):
table_data = defaultdict(list)
for experiment in results:
backends = experiment.config.backends + ["compiled"]
@ -1024,16 +1024,16 @@ def generate_eager_sdpa(
def generate_experiment_configs(
calculate_bwd: bool,
dtype: torch.dtype,
batch_sizes: List[int],
num_heads: List[tuple[int, int]],
seq_lens: List[int],
head_dims: List[int],
score_mods_str: List[str],
batch_sizes: list[int],
num_heads: list[tuple[int, int]],
seq_lens: list[int],
head_dims: list[int],
score_mods_str: list[str],
decoding: bool,
kv_cache_size: List[int],
kv_cache_size: list[int],
cal_bandwidth: bool,
backends: List[str],
) -> List[ExperimentConfig]:
backends: list[str],
) -> list[ExperimentConfig]:
assert not (calculate_bwd and decoding), "Decoding does not support backward"
if decoding: