PEP585 update - benchmarks tools torchgen (#145101)

This is one of a series of PRs to update us to PEP 585 (changing Dict -> dict, List -> list, etc.). Most of the PRs were completely automated with RUFF as follows:

Since RUFF UP006 is considered an "unsafe" fix, we first need to enable unsafe fixes:

```
--- a/tools/linter/adapters/ruff_linter.py
+++ b/tools/linter/adapters/ruff_linter.py
@@ -313,6 +313,7 @@
                     "ruff",
                     "check",
                     "--fix-only",
+                    "--unsafe-fixes",
                     "--exit-zero",
                     *([f"--config={config}"] if config else []),
                     "--stdin-filename",
```

Then we need to tell RUFF to allow UP006 (this will be made permanent in a final PR once all of these have landed):

```
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@

 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 88
 src = ["caffe2", "torch", "torchgen", "functorch", "test"]

@@ -87,7 +87,6 @@
     "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
     "SIM117",
     "SIM118",
-    "UP006", # keep-runtime-typing
     "UP007", # keep-runtime-typing
 ]
 select = [
```

Finally, running `lintrunner -a --take RUFF` will fix up the deprecated uses.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-17 15:13:25 -08:00
committed by PyTorch MergeBot
parent 2c4281d7da
commit 07669ed960
44 changed files with 227 additions and 240 deletions

View File

@ -3,12 +3,11 @@
import argparse
import os
import sys
from typing import Set
# Note - hf and timm have their own version of this, torchbench does not
# TOOD(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> Set[str]:
def model_names(filename: str) -> set[str]:
names = set()
with open(filename) as fh:
lines = fh.readlines()

View File

@ -23,18 +23,7 @@ import time
import weakref
from contextlib import contextmanager
from pathlib import Path
from typing import (
Any,
Callable,
Generator,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Type,
TYPE_CHECKING,
)
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING
from typing_extensions import Self
from unittest.mock import MagicMock
@ -97,6 +86,8 @@ except ImportError:
if TYPE_CHECKING:
from collections.abc import Generator, Mapping, Sequence
from torch.onnx._internal.fx import diagnostics
@ -1779,7 +1770,7 @@ class OnnxModel(abc.ABC):
for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs)
}
def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[npt.NDArray]) -> Any:
def adapt_onnx_outputs_to_pt(self, onnx_outputs: list[npt.NDArray]) -> Any:
pt_outputs = [
torch.from_numpy(onnx_output).to(current_device)
for onnx_output in onnx_outputs
@ -2217,11 +2208,11 @@ class OnnxExportErrorRow:
)
@property
def headers(self) -> List[str]:
def headers(self) -> list[str]:
return [field.name for field in dataclasses.fields(self)]
@property
def row(self) -> List[str]:
def row(self) -> list[str]:
return [getattr(self, field.name) for field in dataclasses.fields(self)]
@ -2271,7 +2262,7 @@ class OnnxContext:
def optimize_onnx_ctx(
output_directory: str,
onnx_model_cls: Type[OnnxModel],
onnx_model_cls: type[OnnxModel],
run_n_iterations: Callable,
dynamic_shapes: bool = False,
copy_before_export: bool = False,

View File

@ -3,8 +3,9 @@ import logging
import math
import os
from collections import Counter, defaultdict
from collections.abc import Generator, Iterable
from functools import partial
from typing import Any, Dict, Generator, Iterable
from typing import Any
import torch
from torch.testing import make_tensor
@ -263,7 +264,7 @@ class OperatorInputsLoader:
def get_inputs_for_operator(
self, operator, dtype=None, device="cuda"
) -> Generator[tuple[Iterable[Any], Dict[str, Any]], None, None]:
) -> Generator[tuple[Iterable[Any], dict[str, Any]], None, None]:
assert (
str(operator) in self.operator_db
), f"Could not find {operator}, must provide overload"

View File

@ -1,7 +1,6 @@
import numbers
import warnings
from collections import namedtuple
from typing import List
import torch
import torch.jit as jit
@ -115,7 +114,7 @@ def script_lnlstm(
LSTMState = namedtuple("LSTMState", ["hx", "cx"])
def reverse(lst: List[Tensor]) -> List[Tensor]:
def reverse(lst: list[Tensor]) -> list[Tensor]:
return lst[::-1]
@ -228,7 +227,7 @@ class LSTMLayer(jit.ScriptModule):
self, input: Tensor, state: tuple[Tensor, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
inputs = input.unbind(0)
outputs = torch.jit.annotate(List[Tensor], [])
outputs = torch.jit.annotate(list[Tensor], [])
for i in range(len(inputs)):
out, state = self.cell(inputs[i], state)
outputs += [out]
@ -245,7 +244,7 @@ class ReverseLSTMLayer(jit.ScriptModule):
self, input: Tensor, state: tuple[Tensor, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
inputs = reverse(input.unbind(0))
outputs = jit.annotate(List[Tensor], [])
outputs = jit.annotate(list[Tensor], [])
for i in range(len(inputs)):
out, state = self.cell(inputs[i], state)
outputs += [out]
@ -266,11 +265,11 @@ class BidirLSTMLayer(jit.ScriptModule):
@jit.script_method
def forward(
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
self, input: Tensor, states: list[tuple[Tensor, Tensor]]
) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]:
# List[LSTMState]: [forward LSTMState, backward LSTMState]
outputs = jit.annotate(List[Tensor], [])
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
outputs = jit.annotate(list[Tensor], [])
output_states = jit.annotate(list[tuple[Tensor, Tensor]], [])
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0
for direction in self.directions:
@ -300,10 +299,10 @@ class StackedLSTM(jit.ScriptModule):
@jit.script_method
def forward(
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
self, input: Tensor, states: list[tuple[Tensor, Tensor]]
) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]:
# List[LSTMState]: One state per layer
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
output_states = jit.annotate(list[tuple[Tensor, Tensor]], [])
output = input
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0
@ -330,11 +329,11 @@ class StackedLSTM2(jit.ScriptModule):
@jit.script_method
def forward(
self, input: Tensor, states: List[List[tuple[Tensor, Tensor]]]
) -> tuple[Tensor, List[List[tuple[Tensor, Tensor]]]]:
self, input: Tensor, states: list[list[tuple[Tensor, Tensor]]]
) -> tuple[Tensor, list[list[tuple[Tensor, Tensor]]]]:
# List[List[LSTMState]]: The outer list is for layers,
# inner list is for directions.
output_states = jit.annotate(List[List[tuple[Tensor, Tensor]]], [])
output_states = jit.annotate(list[list[tuple[Tensor, Tensor]]], [])
output = input
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0
@ -370,10 +369,10 @@ class StackedLSTMWithDropout(jit.ScriptModule):
@jit.script_method
def forward(
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
self, input: Tensor, states: list[tuple[Tensor, Tensor]]
) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]:
# List[LSTMState]: One state per layer
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
output_states = jit.annotate(list[tuple[Tensor, Tensor]], [])
output = input
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0

View File

@ -1,5 +1,4 @@
from collections import namedtuple
from typing import List
import torch
from torch import Tensor
@ -265,13 +264,13 @@ def varlen_pytorch_lstm_creator(**kwargs):
def varlen_lstm_factory(cell, script):
def dynamic_rnn(
sequences: List[Tensor],
sequences: list[Tensor],
hiddens: tuple[Tensor, Tensor],
wih: Tensor,
whh: Tensor,
bih: Tensor,
bhh: Tensor,
) -> tuple[List[Tensor], tuple[List[Tensor], List[Tensor]]]:
) -> tuple[list[Tensor], tuple[list[Tensor], list[Tensor]]]:
hx, cx = hiddens
hxs = hx.unbind(1)
cxs = cx.unbind(1)
@ -506,7 +505,7 @@ def lstm_factory_simple(cell, script):
def lstm_factory_multilayer(cell, script):
def dynamic_rnn(
input: Tensor, hidden: tuple[Tensor, Tensor], params: List[Tensor]
input: Tensor, hidden: tuple[Tensor, Tensor], params: list[Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
params_stride = 4 # NB: this assumes that biases are there
hx, cx = hidden

View File

@ -1,7 +1,7 @@
import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import Any, Callable, List, NamedTuple
from typing import Any, Callable, NamedTuple
import torch
from torch.autograd import functional
@ -147,8 +147,8 @@ ALL_TASKS = ALL_TASKS_NON_VECTORIZED + VECTORIZED_TASKS
class ModelDef(NamedTuple):
name: str
getter: GetterType
tasks: List[str]
unsupported: List[str]
tasks: list[str]
unsupported: list[str]
MODELS = [
@ -223,7 +223,7 @@ def run_once_functorch(
def run_model(
model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once
) -> List[float]:
) -> list[float]:
if args.gpu == -1:
device = torch.device("cpu")

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Optional, Union
import torch
from torch import nn, Tensor
@ -14,13 +14,13 @@ GetterType = Callable[[torch.device], GetterReturnType]
VType = Union[None, Tensor, tuple[Tensor, ...]]
# Type used to store timing results. The first key is the model name, the second key
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
TimingResultType = Dict[str, Dict[str, tuple[float, ...]]]
TimingResultType = dict[str, dict[str, tuple[float, ...]]]
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluate the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
def _del_nested_attr(obj: nn.Module, names: list[str]) -> None:
"""
Deletes the attribute specified by the given list of names.
For example, to delete the attribute obj.conv.weight,
@ -32,7 +32,7 @@ def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
_del_nested_attr(getattr(obj, names[0]), names[1:])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
def _set_nested_attr(obj: nn.Module, names: list[str], value: Tensor) -> None:
"""
Set the attribute specified by the given list of names to value.
For example, to set the attribute obj.conv.weight,
@ -44,7 +44,7 @@ def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], List[str]]:
def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], list[str]]:
"""
This function removes all the Parameters from the model and
return them as a tuple as well as their original attribute names.
@ -65,7 +65,7 @@ def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], List[str]]:
return params, names
def load_weights(mod: nn.Module, names: List[str], params: tuple[Tensor, ...]) -> None:
def load_weights(mod: nn.Module, names: list[str], params: tuple[Tensor, ...]) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left

View File

@ -1,8 +1,8 @@
import dataclasses
from typing import Callable, Dict, Optional
from typing import Callable, Optional
all_experiments: Dict[str, Callable] = {}
all_experiments: dict[str, Callable] = {}
@dataclasses.dataclass

View File

@ -6,7 +6,7 @@ import argparse
import hashlib
import json
import time
from typing import Dict, List, Union
from typing import Union
from core.expand import materialize
from definitions.standard import BENCHMARKS
@ -22,7 +22,7 @@ VERSION = 0
MD5 = "4d55e8abf881ad38bb617a96714c1296"
def main(argv: List[str]) -> None:
def main(argv: list[str]) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--destination", type=str, default=None)
parser.add_argument("--subset", action="store_true")
@ -56,7 +56,7 @@ def main(argv: List[str]) -> None:
results = Runner(work_orders, cadence=30.0).run()
# TODO: Annotate with TypedDict when 3.8 is the minimum supported verson.
grouped_results: Dict[str, Dict[str, List[Union[float, int]]]] = {
grouped_results: dict[str, dict[str, list[Union[float, int]]]] = {
key: {"times": [], "counts": []} for key in keys
}

View File

@ -7,7 +7,7 @@ import enum
import itertools as it
import re
import textwrap
from typing import Dict, List, Optional, Set, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
from worker.main import WorkerTimerArgs
@ -49,7 +49,7 @@ class AutoLabels:
language: Language
@property
def as_dict(self) -> Dict[str, str]:
def as_dict(self) -> dict[str, str]:
"""Dict representation for CI reporting."""
return {
"runtime": self.runtime.value,
@ -261,7 +261,7 @@ class GroupedBenchmark:
py_block: str = "",
cpp_block: str = "",
num_threads: Union[int, tuple[int, ...]] = 1,
) -> Dict[Union[tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
) -> dict[Union[tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
py_cases, py_setup, py_global_setup = cls._parse_variants(
py_block, Language.PYTHON
)
@ -279,9 +279,9 @@ class GroupedBenchmark:
# NB: The key is actually `Tuple[str, ...]`, however MyPy gets confused
# and we use the superset `Union[Tuple[str, ...], Optional[str]` to
# match the expected signature.
variants: Dict[Union[tuple[str, ...], Optional[str]], GroupedBenchmark] = {}
variants: dict[Union[tuple[str, ...], Optional[str]], GroupedBenchmark] = {}
seen_labels: Set[str] = set()
seen_labels: set[str] = set()
for label in it.chain(py_cases.keys(), cpp_cases.keys()):
if label in seen_labels:
continue
@ -415,13 +415,13 @@ class GroupedBenchmark:
@staticmethod
def _parse_variants(
block: str, language: Language
) -> tuple[Dict[str, List[str]], str, str]:
) -> tuple[dict[str, list[str]], str, str]:
block = textwrap.dedent(block).strip()
comment = "#" if language == Language.PYTHON else "//"
label_pattern = f"{comment} @(.+)$"
label = ""
lines_by_label: Dict[str, List[str]] = {"SETUP": [], "GLOBAL_SETUP": []}
lines_by_label: dict[str, list[str]] = {"SETUP": [], "GLOBAL_SETUP": []}
for line in block.splitlines(keepends=False):
match = re.search(label_pattern, line.strip())
if match:

View File

@ -12,7 +12,7 @@ import os
import re
import textwrap
import uuid
from typing import List, Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING
import torch
@ -204,7 +204,7 @@ def materialize(benchmarks: FlatIntermediateDefinition) -> FlatDefinition:
GroupedBenchmarks into multiple TimerArgs, and tagging the results with
AutoLabels.
"""
results: List[tuple[Label, AutoLabels, TimerArgs]] = []
results: list[tuple[Label, AutoLabels, TimerArgs]] = []
for label, args in benchmarks.items():
if isinstance(args, TimerArgs):

View File

@ -2,7 +2,7 @@
# mypy: ignore-errors
from typing import Dict, Optional, Union
from typing import Optional, Union
from core.api import AutoLabels, GroupedBenchmark, TimerArgs
@ -71,15 +71,15 @@ _Label = Union[Label, Optional[str]]
_Value = Union[
Union[TimerArgs, GroupedBenchmark],
Dict[_Label, "_Value"],
dict[_Label, "_Value"],
]
Definition = Dict[_Label, _Value]
Definition = dict[_Label, _Value]
# We initially have to parse (flatten) to an intermediate state in order to
# build TorchScript models since multiple entries will share the same model
# artifact.
FlatIntermediateDefinition = Dict[Label, Union[TimerArgs, GroupedBenchmark]]
FlatIntermediateDefinition = dict[Label, Union[TimerArgs, GroupedBenchmark]]
# Final parsed schema.
FlatDefinition = tuple[tuple[Label, AutoLabels, TimerArgs], ...]

View File

@ -3,7 +3,7 @@ import atexit
import re
import shutil
import textwrap
from typing import List, Optional
from typing import Optional
from core.api import GroupedBenchmark, TimerArgs
from core.types import Definition, FlatIntermediateDefinition, Label
@ -70,7 +70,7 @@ def parse_stmts(stmts: str) -> tuple[str, str]:
- The column separator is " | ", not "|". Whitespace matters.
"""
stmts = textwrap.dedent(stmts).strip()
lines: List[str] = stmts.splitlines(keepends=False)
lines: list[str] = stmts.splitlines(keepends=False)
assert len(lines) >= 3, f"Invalid string:\n{stmts}"
column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
@ -87,8 +87,8 @@ def parse_stmts(stmts: str) -> tuple[str, str]:
assert re.search(separation_pattern, lines[1])
py_lines: List[str] = []
cpp_lines: List[str] = []
py_lines: list[str] = []
cpp_lines: list[str] = []
for l in lines[2:]:
l_match = re.search(code_pattern, l)
if l_match is None:

View File

@ -8,7 +8,7 @@ import subprocess
import textwrap
import threading
import time
from typing import Dict, List, Optional, Set, Union
from typing import Optional, Union
from worker.main import WorkerFailure, WorkerOutput
@ -51,11 +51,11 @@ class CorePool:
self._num_cores = max_core_id - min_core_id + 1
print(f"Core pool created: cores {self._min_core_id}-{self._max_core_id}")
self._available: List[bool] = [
self._available: list[bool] = [
True for _ in range(min_core_id, min_core_id + self._num_cores)
]
self._reservations: Dict[str, tuple[int, ...]] = {}
self._reservations: dict[str, tuple[int, ...]] = {}
self._lock = threading.Lock()
def reserve(self, n: int) -> Optional[str]:
@ -96,19 +96,19 @@ class Runner:
self._cadence: float = cadence
# Working state.
self._work_queue: List[WorkOrder] = list(work_items)
self._active_jobs: List[InProgress] = []
self._results: Dict[WorkOrder, WorkerOutput] = {}
self._work_queue: list[WorkOrder] = list(work_items)
self._active_jobs: list[InProgress] = []
self._results: dict[WorkOrder, WorkerOutput] = {}
# Debug information for ETA and error messages.
self._start_time: float = -1
self._durations: Dict[WorkOrder, float] = {}
self._durations: dict[WorkOrder, float] = {}
self._currently_processed: Optional[WorkOrder] = None
if len(work_items) != len(set(work_items)):
raise ValueError("Duplicate work items.")
def run(self) -> Dict[WorkOrder, WorkerOutput]:
def run(self) -> dict[WorkOrder, WorkerOutput]:
try:
return self._run()
@ -137,7 +137,7 @@ class Runner:
self._force_shutdown(verbose=True)
raise
def _run(self) -> Dict[WorkOrder, WorkerOutput]:
def _run(self) -> dict[WorkOrder, WorkerOutput]:
self._start_time = time.time()
self._canary_import()
while self._work_queue or self._active_jobs:
@ -150,7 +150,7 @@ class Runner:
return self._results.copy()
def _update_active_jobs(self) -> None:
active_jobs: List[InProgress] = []
active_jobs: list[InProgress] = []
for job in self._active_jobs:
self._currently_processed = job.work_order
if not job.check_finished():
@ -172,7 +172,7 @@ class Runner:
self._active_jobs.extend(active_jobs)
def _enqueue_new_jobs(self) -> None:
work_queue: List[WorkOrder] = []
work_queue: list[WorkOrder] = []
for i, work_order in enumerate(self._work_queue):
self._currently_processed = work_order
cpu_list = self._core_pool.reserve(work_order.timer_args.num_threads)
@ -249,7 +249,7 @@ class Runner:
def _canary_import(self) -> None:
"""Make sure we can import torch before launching a slew of workers."""
source_cmds: Set[str] = set()
source_cmds: set[str] = set()
for w in self._work_items:
if w.source_cmd is not None:
source_cmds.add(f"{w.source_cmd} && ")

View File

@ -10,7 +10,7 @@ import signal
import subprocess
import time
import uuid
from typing import List, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
from core.api import AutoLabels
from core.types import Label
@ -98,7 +98,7 @@ class _BenchmarkProcess:
@property
def cmd(self) -> str:
cmd: List[str] = []
cmd: list[str] = []
if self._work_order.source_cmd is not None:
cmd.extend([self._work_order.source_cmd, "&&"])

View File

@ -10,7 +10,6 @@ underlying benchmark generation infrastructure in the mean time.
import argparse
import sys
from typing import List
from applications import ci
from core.expand import materialize
@ -19,7 +18,7 @@ from execution.runner import Runner
from execution.work import WorkOrder
def main(argv: List[str]) -> None:
def main(argv: list[str]) -> None:
work_orders = tuple(
WorkOrder(label, autolabels, timer_args, timeout=600, retries=2)
for label, autolabels, timer_args in materialize(BENCHMARKS)

View File

@ -1,5 +1,3 @@
from typing import List
import operator_benchmark as op_bench
import torch
@ -44,7 +42,7 @@ class As_stridedBenchmark(op_bench.TorchBenchmarkBase):
self.set_module_name("as_strided")
def forward(
self, input_one, size: List[int], stride: List[int], storage_offset: int
self, input_one, size: list[int], stride: list[int], storage_offset: int
):
return torch.as_strided(input_one, size, stride, storage_offset)

View File

@ -1,5 +1,4 @@
import random
from typing import List
import operator_benchmark as op_bench
@ -143,7 +142,7 @@ class CatBenchmark(op_bench.TorchBenchmarkBase):
self.inputs = {"result": result, "inputs": inputs, "dim": dim}
self.set_module_name("cat")
def forward(self, result: torch.Tensor, inputs: List[torch.Tensor], dim: int):
def forward(self, result: torch.Tensor, inputs: list[torch.Tensor], dim: int):
return torch.cat(inputs, dim=dim, out=result)

View File

@ -1,5 +1,3 @@
from typing import List
import operator_benchmark as op_bench
import torch
@ -58,7 +56,7 @@ class QCatBenchmark(op_bench.TorchBenchmarkBase):
self.inputs = {"input": self.input, "dim": dim}
self.set_module_name("qcat")
def forward(self, input: List[torch.Tensor], dim: int):
def forward(self, input: list[torch.Tensor], dim: int):
return self.qf.cat(input, dim=dim)

View File

@ -1,5 +1,4 @@
import random
from typing import List
import operator_benchmark as op_bench
@ -79,7 +78,7 @@ class StackBenchmark(op_bench.TorchBenchmarkBase):
self.inputs = {"result": result, "inputs": inputs, "dim": dim}
self.set_module_name("stack")
def forward(self, result: torch.Tensor, inputs: List[torch.Tensor], dim: int):
def forward(self, result: torch.Tensor, inputs: list[torch.Tensor], dim: int):
return torch.stack(inputs, dim=dim, out=result)

View File

@ -1,7 +1,7 @@
import itertools
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List, Union
from typing import Callable, Union
import numpy as np
from tabulate import tabulate
@ -50,7 +50,7 @@ class ExperimentResults:
materialized_mask_time: float
attn_mask_subclass_time: float
def get_entries(self) -> List:
def get_entries(self) -> list:
return [
f"{self.materialized_mask_time:2f}",
f"{self.attn_mask_subclass_time:2f}",
@ -62,7 +62,7 @@ class Experiment:
config: ExperimentConfig
results: ExperimentResults
def get_entries(self) -> List:
def get_entries(self) -> list:
return self.config.get_entries() + self.results.get_entries()
@ -176,7 +176,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
)
def generate_experiment_configs() -> List[ExperimentConfig]:
def generate_experiment_configs() -> list[ExperimentConfig]:
batch_sizes = [1, 8, 16, 128]
num_heads = [16, 32]
q_kv_seq_lens = [(128, 256), (256, 416), (512, 4097), (1024, 2048), (1, 2048)]
@ -206,7 +206,7 @@ def calculate_speedup(results: ExperimentResults) -> float:
return results.materialized_mask_time / results.attn_mask_subclass_time
def print_results(results: List[Experiment]):
def print_results(results: list[Experiment]):
# Calculate speedups
speedups = [calculate_speedup(r.results) for r in results]

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Optional, Union
import numpy as np
from tabulate import tabulate
@ -46,7 +46,7 @@ class ExperimentConfig:
dtype: torch.dtype
calculate_bwd_time: bool
cal_bandwidth: bool
backends: List[str]
backends: list[str]
def __post_init__(self):
assert (
@ -80,7 +80,7 @@ class ExperimentResults:
@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
results: Dict[str, ExperimentResults] # backend -> ExperimentResults
results: dict[str, ExperimentResults] # backend -> ExperimentResults
def asdict(self):
dict1 = self.config.asdict()
@ -357,7 +357,7 @@ def run_single_experiment(
config: ExperimentConfig,
dynamic=False,
max_autotune=False,
) -> Dict[str, ExperimentResults]:
) -> dict[str, ExperimentResults]:
device = torch.device("cuda")
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
query, key, value = generate_inputs(
@ -504,7 +504,7 @@ def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> fl
return total_flops / results.fwd_time / 1e6 # in TFLOPs/
def get_average_speedups(results: List[Experiment], type: str, backend: str):
def get_average_speedups(results: list[Experiment], type: str, backend: str):
# Calculate speedups
speedups = [
calculate_speedup(r.results["compiled"], r.results[backend], type)
@ -533,7 +533,7 @@ def get_average_speedups(results: List[Experiment], type: str, backend: str):
return table_data
def print_results(results: List[Experiment], save_path: Optional[str] = None):
def print_results(results: list[Experiment], save_path: Optional[str] = None):
table_data = defaultdict(list)
for experiment in results:
backends = experiment.config.backends + ["compiled"]
@ -1024,16 +1024,16 @@ def generate_eager_sdpa(
def generate_experiment_configs(
calculate_bwd: bool,
dtype: torch.dtype,
batch_sizes: List[int],
num_heads: List[tuple[int, int]],
seq_lens: List[int],
head_dims: List[int],
score_mods_str: List[str],
batch_sizes: list[int],
num_heads: list[tuple[int, int]],
seq_lens: list[int],
head_dims: list[int],
score_mods_str: list[str],
decoding: bool,
kv_cache_size: List[int],
kv_cache_size: list[int],
cal_bandwidth: bool,
backends: List[str],
) -> List[ExperimentConfig]:
backends: list[str],
) -> list[ExperimentConfig]:
assert not (calculate_bwd and decoding), "Decoding does not support backward"
if decoding:

View File

@ -5,7 +5,7 @@ import warnings
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import List, Optional
from typing import Optional
import numpy as np
from prettytable import PrettyTable
@ -32,7 +32,7 @@ class ExperimentConfig:
enable_mem_efficient: bool
enable_cudnn: bool
def get_entries(self) -> List:
def get_entries(self) -> list:
return [
self.batch_size,
self.num_heads,
@ -47,7 +47,7 @@ class ExperimentConfig:
]
@classmethod
def get_entry_names(cls) -> List[str]:
def get_entry_names(cls) -> list[str]:
return [
"batch_size",
"num_heads",
@ -69,7 +69,7 @@ class ExperimentResults:
composite_mha_time: float
compiled_composite_mha_time: Optional[float]
def get_entries(self) -> List:
def get_entries(self) -> list:
return [
f"{self.nn_mha_time:2f}",
f"{self.compiled_nn_mha_time:2f}" if self.compiled_nn_mha_time else None,
@ -80,7 +80,7 @@ class ExperimentResults:
]
@classmethod
def get_entry_names(cls) -> List[str]:
def get_entry_names(cls) -> list[str]:
return [
"nn_mha_time (\u00b5s)",
"compiled_nn_mha_time (\u00b5s)",
@ -94,7 +94,7 @@ class Experiment:
config: ExperimentConfig
results: ExperimentResults
def get_entries(self) -> List:
def get_entries(self) -> list:
return self.config.get_entries() + self.results.get_entries()
@ -275,7 +275,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
# Could return generator
def generate_experiments(
batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
) -> List[ExperimentConfig]:
) -> list[ExperimentConfig]:
configs = []
for bsz, n_heads, seq_len, embed_dim, dtype, padding in itertools.product(
batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
@ -337,7 +337,7 @@ def main(save_path: Optional[Path]):
batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
)
experiments: List[Experiment] = []
experiments: list[Experiment] = []
for experiment_config in tqdm(experiment_configs):
experiment = run_single_experiment(experiment_config)
experiments.append(experiment)

View File

@ -2,7 +2,7 @@ import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List
from typing import Callable
from tabulate import tabulate
from tqdm import tqdm
@ -119,7 +119,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
)
def generate_experiment_configs() -> List[ExperimentConfig]:
def generate_experiment_configs() -> list[ExperimentConfig]:
batch_sizes = [
1,
8,
@ -160,7 +160,7 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
return all_configs
def print_results(experiments: List[Experiment]):
def print_results(experiments: list[Experiment]):
table_data = defaultdict(list)
for experiment in experiments:
for key, value in experiment.asdict().items():

View File

@ -8,7 +8,7 @@ import argparse
import ast
import os
import sys
from typing import Any, Dict, List, Set, Tuple # type: ignore[attr-defined]
from typing import Any # type: ignore[attr-defined]
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
from tools.flight_recorder.components.types import (
@ -57,12 +57,12 @@ Flat DB builder
def build_groups_memberships(
pg_config: Any,
) -> Tuple[
List[Group],
Dict[Any, Group],
List[Membership],
Dict[str, Set[Any]],
Dict[Tuple[str, int], str],
) -> tuple[
list[Group],
dict[Any, Group],
list[Membership],
dict[str, set[Any]],
dict[tuple[str, int], str],
]:
"""
pg_config: {
@ -126,12 +126,12 @@ def build_groups_memberships(
def build_collectives(
all_entries: Dict[int, List[Dict[str, Any]]],
_groups: Dict[str, Group],
_memberships: Dict[str, Set[Any]],
_pg_guids: Dict[Tuple[str, int], str],
all_entries: dict[int, list[dict[str, Any]]],
_groups: dict[str, Group],
_memberships: dict[str, set[Any]],
_pg_guids: dict[tuple[str, int], str],
version: str,
) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
) -> tuple[list[Traceback], list[Collective], list[NCCLCall]]:
"""
groups, memberships are the non-flat dicts that are indexable
all_entries is a raw dict from the original dumps:
@ -161,10 +161,10 @@ def build_collectives(
}
"""
major_v, minor_v = get_version_detail(version)
tracebacks: List[Traceback] = []
tracebacks: list[Traceback] = []
collectives: List[Collective] = []
nccl_calls: List[NCCLCall] = []
collectives: list[Collective] = []
nccl_calls: list[NCCLCall] = []
# once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
# instead, just record the remaining ops as NCCLCalls
@ -420,7 +420,7 @@ def build_collectives(
def build_db(
details: Dict[str, Dict[str, Any]], args: argparse.Namespace, version: str
details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str
) -> Database:
if args.verbose:
os.environ["FR_TRACE_VERBOSE_OUTPUT"] = "1"

View File

@ -6,7 +6,8 @@
import argparse
import logging
from typing import Optional, Sequence
from collections.abc import Sequence
from typing import Optional
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger

View File

@ -10,9 +10,8 @@ import os
import pickle
import re
import time
import typing
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple, Union
from typing import Any, Union
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
@ -20,7 +19,7 @@ from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
logger: FlightRecorderLogger = FlightRecorderLogger()
def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]:
def read_dump(prefix: str, filename: str) -> dict[str, Union[str, int, list[Any]]]:
basename = os.path.basename(filename)
rank = int(basename[len(prefix) :])
@ -45,12 +44,12 @@ def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]
exp = re.compile(r"([\w\-\_]*?)(\d+)$")
def _determine_prefix(files: List[str]) -> str:
def _determine_prefix(files: list[str]) -> str:
"""If the user doesn't specify a prefix, but does pass a dir full of similarly-prefixed files, we should be able to
infer the common prefix most of the time. But if we can't confidently infer, just fall back to requiring the user
to specify it
"""
possible_prefixes: typing.DefaultDict[str, Set[int]] = defaultdict(set)
possible_prefixes: defaultdict[str, set[int]] = defaultdict(set)
for f in files:
m = exp.search(f)
if m:
@ -67,7 +66,7 @@ def _determine_prefix(files: List[str]) -> str:
)
def read_dir(args: argparse.Namespace) -> Tuple[Dict[str, Dict[str, Any]], str]:
def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]:
gc.disable()
prefix = args.prefix
details = {}

View File

@ -10,14 +10,9 @@ from enum import auto, Enum
from typing import ( # type: ignore[attr-defined]
_eval_type,
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Set,
Tuple,
Type,
TypeVar,
)
@ -33,7 +28,7 @@ class Ref(Generic[T]):
class TypeInfo(NamedTuple):
name: str
fields: List[Tuple[str, Type]] # type: ignore[type-arg]
fields: list[tuple[str, type]] # type: ignore[type-arg]
@classmethod
def from_type(cls, c: T) -> "TypeInfo":
@ -126,15 +121,15 @@ class Collective(NamedTuple):
record_id: int
pg_desc: str
collective_name: str
input_sizes: List[List[int]]
output_sizes: List[List[int]]
expected_ranks: Set[int]
input_sizes: list[list[int]]
output_sizes: list[list[int]]
expected_ranks: set[int]
collective_state: str
collective_frames: List[Dict[str, str]]
collective_frames: list[dict[str, str]]
input_numel: Optional[int] = None
output_numel: Optional[int] = None
missing_ranks: Optional[Set[int]] = None
mismatch_collectives: Optional[Dict[int, "Collective"]] = None
missing_ranks: Optional[set[int]] = None
mismatch_collectives: Optional[dict[int, "Collective"]] = None
type_of_mismatch: Optional[MatchState] = None
@ -145,15 +140,15 @@ class NCCLCall(NamedTuple):
global_rank: int # technically Ref[Process] once we have it
traceback_id: Ref[Traceback]
collective_type: str
sizes: List[List[int]]
sizes: list[list[int]]
class Database(NamedTuple):
groups: List[Group]
memberships: List[Membership]
tracebacks: List[Traceback]
collectives: List[Collective]
ncclcalls: List[NCCLCall]
groups: list[Group]
memberships: list[Membership]
tracebacks: list[Traceback]
collectives: list[Collective]
ncclcalls: list[NCCLCall]
# TODO: We need to add a schema for the following
@ -206,7 +201,7 @@ class EntryState:
log the error info during analysis.
"""
def __init__(self, entry: Dict[str, Any], expected_ranks: Set[int]) -> None:
def __init__(self, entry: dict[str, Any], expected_ranks: set[int]) -> None:
self.pg_name = entry["process_group"][0]
self.desc = entry["process_group"][1]
self.pg_desc = (
@ -221,19 +216,19 @@ class EntryState:
self.collective_state = entry["state"]
self.collective_frames = entry["frames"]
self.expected_ranks = expected_ranks
self.missing_ranks: Set[int]
self.missing_ranks: set[int]
self.input_numel: int
self.output_numel: int
self.errors: Set[Tuple[int, MatchState]]
self.errors: set[tuple[int, MatchState]]
def log(
self,
logger: FlightRecorderLogger,
logger_msg: str,
frame_formatter: Any,
total_numel: Optional[Tuple[int, int]] = None,
errors: Optional[Set[Tuple[int, MatchState]]] = None,
missing_ranks: Optional[Set[int]] = None,
total_numel: Optional[tuple[int, int]] = None,
errors: Optional[set[tuple[int, MatchState]]] = None,
missing_ranks: Optional[set[int]] = None,
) -> None:
logger.info(
logger_msg,
@ -268,9 +263,9 @@ class EntryState:
def to_collective(
self,
id: int,
errors: Optional[Set[Tuple[int, MatchState]]] = None,
idx_map: Optional[Dict[int, int]] = None,
all_entries: Optional[Dict[int, List[Dict[str, Any]]]] = None,
errors: Optional[set[tuple[int, MatchState]]] = None,
idx_map: Optional[dict[int, int]] = None,
all_entries: Optional[dict[int, list[dict[str, Any]]]] = None,
) -> Collective:
if not errors:
return Collective(
@ -340,11 +335,11 @@ class EntryState:
def to_nccl_call(
self,
all_entries: Dict[int, List[Dict[str, Any]]],
idx_map: Dict[int, int],
all_entries: dict[int, list[dict[str, Any]]],
idx_map: dict[int, int],
nccl_call_id: int,
collective_id: Any,
) -> List[NCCLCall]:
) -> list[NCCLCall]:
result = []
for i, k in idx_map.items():
all_entries[i].pop(k)
@ -373,7 +368,7 @@ class Op:
"""
def __init__(
self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]], pg_name: str
self, event: dict[Any, Any], memberships: dict[str, set[Any]], pg_name: str
):
self.profiling_name = event["profiling_name"]
nccl, name = self.profiling_name.split(":")
@ -412,7 +407,7 @@ class Op:
self.collective_frames = event["frames"]
self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1"
def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None:
def _init_global_src_dst(self, pg_ranks: set[Any]) -> None:
pg_ranks = sorted(pg_ranks)
self._src_g = pg_ranks[self._src] if self._src is not None else None
self._dst_g = pg_ranks[self._dst] if self._dst is not None else None

View File

@ -6,7 +6,7 @@
import argparse
import math
from typing import Any, Dict, List, Set, Tuple
from typing import Any
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
from tools.flight_recorder.components.types import (
@ -27,14 +27,14 @@ except ModuleNotFoundError:
logger.debug("tabulate is not installed. Proceeding without it.")
def format_frame(frame: Dict[str, str]) -> str:
def format_frame(frame: dict[str, str]) -> str:
name = frame["name"]
filename = frame["filename"]
line = frame["line"]
return f"{name} at {filename}:{line}"
def format_frames(frames: List[Dict[str, str]]) -> str:
def format_frames(frames: list[dict[str, str]]) -> str:
formatted_frames = []
for frame in frames:
formatted_frames.append(format_frame(frame))
@ -42,9 +42,9 @@ def format_frames(frames: List[Dict[str, str]]) -> str:
def match_one_event(
event_a: Dict[Any, Any],
event_b: Dict[Any, Any],
memberships: Dict[str, Set[Any]],
event_a: dict[Any, Any],
event_b: dict[Any, Any],
memberships: dict[str, set[Any]],
pg_name: str,
) -> MatchState:
op_a = Op(event_a, memberships, pg_name)
@ -53,11 +53,11 @@ def match_one_event(
def match_coalesced_groups(
all_rank_events: Dict[Any, Any],
all_rank_events: dict[Any, Any],
group_size: int,
groups: Dict[str, Group],
memberships: Dict[str, Set[Any]],
_pg_guids: Dict[Tuple[str, int], str],
groups: dict[str, Group],
memberships: dict[str, set[Any]],
_pg_guids: dict[tuple[str, int], str],
) -> bool:
"""
all_rank_events: {
@ -92,7 +92,7 @@ def match_coalesced_groups(
def visualize_ops(
match: bool,
_pg_guids: Dict[Tuple[str, int], str],
_pg_guids: dict[tuple[str, int], str],
) -> None:
all_ops = {
rank: [
@ -174,7 +174,7 @@ def match_coalesced_groups(
return True
def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int, int]:
def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int, int]:
input_numel = 0
output_numel = 0
for e in alltoall_cases:
@ -185,10 +185,10 @@ def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int
def find_coalesced_group(
pg_name: str,
entries: List[Dict[str, Any]],
_pg_guids: Dict[Tuple[str, int], str],
entries: list[dict[str, Any]],
_pg_guids: dict[tuple[str, int], str],
rank: int,
) -> List[Tuple[int, Dict[str, Any]]]:
) -> list[tuple[int, dict[str, Any]]]:
"""Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
build and return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id
"""
@ -216,10 +216,10 @@ def find_coalesced_group(
def just_print_entries(
all_entries: Dict[int, List[Dict[str, Any]]],
_groups: Dict[str, Group],
_memberships: Dict[str, Set[Any]],
_pg_guids: Dict[Tuple[str, int], str],
all_entries: dict[int, list[dict[str, Any]]],
_groups: dict[str, Group],
_memberships: dict[str, set[Any]],
_pg_guids: dict[tuple[str, int], str],
args: argparse.Namespace,
) -> None:
rows = []
@ -257,7 +257,7 @@ def just_print_entries(
def check_no_missing_dump_files(
entries: Dict[int, Any], memberships: List[Membership]
entries: dict[int, Any], memberships: list[Membership]
) -> None:
all_ranks = set()
for membership in memberships:
@ -268,14 +268,14 @@ def check_no_missing_dump_files(
), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
def check_version(version_by_ranks: Dict[str, str], version: str) -> None:
def check_version(version_by_ranks: dict[str, str], version: str) -> None:
for rank, v in version_by_ranks.items():
assert (
v == version
), f"Rank {rank} has different version {v} from the given version {version}"
def get_version_detail(version: str) -> Tuple[int, int]:
def get_version_detail(version: str) -> tuple[int, int]:
version = version.split(".")
assert len(version) == 2, f"Invalid version {version}"
major, minor = map(int, version)
@ -283,8 +283,8 @@ def get_version_detail(version: str) -> Tuple[int, int]:
def align_trace_from_beginning(
entries: Dict[int, List[Dict[str, Any]]],
) -> Dict[int, List[Dict[str, Any]]]:
entries: dict[int, list[dict[str, Any]]],
) -> dict[int, list[dict[str, Any]]]:
"""
Align the trace entries by record ID for entries.
This function takes a dictionary of rank names to lists of trace entries as input.

View File

@ -29,7 +29,8 @@ python fr_trace.py <dump dir containing trace files> [-o <output file>]
"""
import pickle
from typing import Optional, Sequence
from collections.abc import Sequence
from typing import Optional
from tools.flight_recorder.components.builder import build_db
from tools.flight_recorder.components.config_manager import JobConfig

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import json
import os
from typing import Any, Callable, cast, Dict
from typing import Any, Callable, cast
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen
@ -72,7 +72,7 @@ def gh_fetch_json_dict(
params: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> dict[str, Any]:
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
return cast(dict[str, Any], _gh_fetch_json_any(url, params, data))
def gh_fetch_commit(org: str, repo: str, sha: str) -> dict[str, Any]:

View File

@ -12,10 +12,14 @@ from enum import Enum
from functools import cached_property
from pathlib import Path
from tokenize import generate_tokens, TokenInfo
from typing import Any, Iterator, Sequence
from typing import Any, TYPE_CHECKING
from typing_extensions import Never
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
FSTRING_START = getattr(token, "FSTRING_START", None) # py3.12+
FSTRING_END = getattr(token, "FSTRING_END", None)
EMPTY_TOKENS = dict.fromkeys(

View File

@ -4,7 +4,7 @@ import sys
import token
from functools import cached_property
from pathlib import Path
from typing import Iterator, Sequence, TYPE_CHECKING
from typing import TYPE_CHECKING
_PARENT = Path(__file__).parent.absolute()
@ -16,6 +16,7 @@ else:
import _linter
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from tokenize import TokenInfo

View File

@ -12,7 +12,7 @@ import sys
import token
from enum import Enum
from pathlib import Path
from typing import List, NamedTuple, Set, TYPE_CHECKING
from typing import NamedTuple, TYPE_CHECKING
_PARENT = Path(__file__).parent.absolute()
@ -45,7 +45,7 @@ class LintMessage(NamedTuple):
LINTER_CODE = "NEWLINE"
CURRENT_FILE_NAME = os.path.basename(__file__)
_MODULE_NAME_ALLOW_LIST: Set[str] = set()
_MODULE_NAME_ALLOW_LIST: set[str] = set()
# Add builtin modules.
if sys.version_info >= (3, 10):
@ -352,7 +352,7 @@ use sys.modules.get("torchrec") or the like.
"""
def check_file(filepath: str) -> List[LintMessage]:
def check_file(filepath: str) -> list[LintMessage]:
path = Path(filepath)
file = _linter.PythonFile("import_linter", path)
lint_messages = []

View File

@ -5,7 +5,7 @@ import sys
import token
from functools import cached_property
from pathlib import Path
from typing import Iterator, Sequence, TYPE_CHECKING
from typing import TYPE_CHECKING
_PARENT = Path(__file__).parent.absolute()
@ -17,6 +17,7 @@ else:
import _linter
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from tokenize import TokenInfo

View File

@ -10,7 +10,6 @@ import tempfile
import time
from collections.abc import Iterator
from pathlib import Path
from typing import Dict, List
logging.basicConfig(
@ -25,7 +24,7 @@ REQUIREMENTS_PATH = ROOT_PATH / "requirements.txt"
def run_cmd(
cmd: List[str], capture_output: bool = False
cmd: list[str], capture_output: bool = False
) -> subprocess.CompletedProcess[bytes]:
logger.debug("Running command: %s", " ".join(cmd))
return subprocess.run(
@ -69,7 +68,7 @@ class Builder:
def __init__(self, interpreter: str) -> None:
self.interpreter = interpreter
def setup_py(self, cmd_args: List[str]) -> bool:
def setup_py(self, cmd_args: list[str]) -> bool:
return (
run_cmd([self.interpreter, str(SETUP_PY_PATH), *cmd_args]).returncode == 0
)
@ -114,7 +113,7 @@ def parse_args() -> argparse.Namespace:
def main() -> None:
args = parse_args()
pythons = args.python or [sys.executable]
build_times: Dict[str, float] = dict()
build_times: dict[str, float] = dict()
if len(pythons) > 1 and args.destination == "dist/":
logger.warning(

View File

@ -6,7 +6,7 @@ import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Generator
from typing import Any, TYPE_CHECKING
from tools.stats.upload_stats_lib import (
download_s3_artifacts,
@ -17,6 +17,10 @@ from tools.stats.upload_stats_lib import (
from tools.stats.upload_test_stats import process_xml_element
if TYPE_CHECKING:
from collections.abc import Generator
TESTCASE_TAG = "testcase"
SEPARATOR = ";"

View File

@ -7,7 +7,7 @@ import json
import os
import shutil
from pathlib import Path
from typing import Any, Callable, cast, Dict
from typing import Any, Callable, cast
from urllib.request import urlopen
@ -61,7 +61,7 @@ def fetch_and_cache(
if os.path.exists(path) and is_cached_file_valid():
# Another test process already downloaded the file, so don't re-do it
with open(path) as f:
return cast(Dict[str, Any], json.load(f))
return cast(dict[str, Any], json.load(f))
for _ in range(3):
try:

View File

@ -2,13 +2,13 @@ import glob
import json
import os
from pathlib import Path
from typing import Any, Dict
from typing import Any
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
def flatten_data(d: Dict[str, Any]) -> Dict[str, Any]:
def flatten_data(d: dict[str, Any]) -> dict[str, Any]:
# Flatten the sccache stats data from a possibly nested dictionary to a flat
# dictionary. For example, the input:
# {

View File

@ -8,7 +8,7 @@ import os
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict
from typing import Any
from tools.stats.upload_stats_lib import (
download_s3_artifacts,
@ -86,7 +86,7 @@ def get_perf_stats(
return perf_stats
def generate_partition_key(repo: str, doc: Dict[str, Any]) -> str:
def generate_partition_key(repo: str, doc: dict[str, Any]) -> str:
"""
Generate a unique partition key for the document on DynamoDB
"""

View File

@ -6,7 +6,7 @@ import json
import os
import time
import urllib.parse
from typing import Any, Callable, cast, Dict, List
from typing import Any, Callable, cast
from urllib.error import HTTPError
from urllib.request import Request, urlopen
@ -60,7 +60,7 @@ def fetch_json(
f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items()
)
return cast(
List[Dict[str, Any]],
list[dict[str, Any]],
_fetch_url(url, headers=headers, data=data, reader=json.load),
)
@ -79,7 +79,7 @@ def get_external_pr_data(
responses: list[dict[str, Any]] = []
while len(responses) > 0 or page == 1:
response = cast(
Dict[str, Any],
dict[str, Any],
fetch_json(
"https://api.github.com/search/issues",
params={

View File

@ -9,7 +9,7 @@ import time
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import boto3 # type: ignore[import]
import requests
@ -122,8 +122,8 @@ def download_gha_artifacts(
def upload_to_dynamodb(
dynamodb_table: str,
repo: str,
docs: List[Any],
generate_partition_key: Optional[Callable[[str, Dict[str, Any]], str]],
docs: list[Any],
generate_partition_key: Optional[Callable[[str, dict[str, Any]], str]],
) -> None:
print(f"Writing {len(docs)} documents to DynamoDB {dynamodb_table}")
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/dynamodb.html#batch-writing

View File

@ -1,7 +1,7 @@
import sys
import time
from functools import lru_cache
from typing import Any, List
from functools import cache
from typing import Any
from tools.stats.test_dashboard import upload_additional_info
from tools.stats.upload_stats_lib import get_s3_resource
@ -11,7 +11,7 @@ from tools.stats.upload_test_stats import get_tests
BUCKET_PREFIX = "workflows_failing_pending_upload"
@lru_cache(maxsize=None)
@cache
def get_bucket() -> Any:
return get_s3_resource().Bucket("gha-artifacts")
@ -41,7 +41,7 @@ def do_upload(workflow_id: int) -> None:
upload_additional_info(workflow_id, workflow_attempt, test_cases)
def get_workflow_ids(pending: bool = False) -> List[int]:
def get_workflow_ids(pending: bool = False) -> list[int]:
prefix = f"{BUCKET_PREFIX}/{'pending/' if pending else ''}"
objs = get_bucket().objects.filter(Prefix=prefix)
return [int(obj.key.split("/")[-1].split(".")[0]) for obj in objs]

View File

@ -8,7 +8,7 @@ import os
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Literal, TYPE_CHECKING, TypeVar
from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
import yaml
@ -2324,7 +2324,7 @@ def gen_source_files(
def register_dispatch_key_env_callable(
gnf: NativeFunction | NativeFunctionsGroup,
) -> Dict[str, list[str]]:
) -> dict[str, list[str]]:
return {
"dispatch_definitions": get_native_function_definitions(
fm=fm, # noqa: F821