mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - benchmarks tools torchgen (#145101)
This is one of a series of PRs to update us to PEP585 (changing Dict -> dict, List -> list, etc). Most of the PRs were completely automated with RUFF as follows: Since RUFF UP006 is considered an "unsafe" fix first we need to enable unsafe fixes: ``` --- a/tools/linter/adapters/ruff_linter.py +++ b/tools/linter/adapters/ruff_linter.py @@ -313,6 +313,7 @@ "ruff", "check", "--fix-only", + "--unsafe-fixes", "--exit-zero", *([f"--config={config}"] if config else []), "--stdin-filename", ``` Then we need to tell RUFF to allow UP006 (as a final PR once all of these have landed this will be made permanent): ``` --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 88 src = ["caffe2", "torch", "torchgen", "functorch", "test"] @@ -87,7 +87,6 @@ "SIM116", # Disable Use a dictionary instead of consecutive `if` statements "SIM117", "SIM118", - "UP006", # keep-runtime-typing "UP007", # keep-runtime-typing ] select = [ ``` Finally running `lintrunner -a --take RUFF` will fix up the deprecated uses. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
2c4281d7da
commit
07669ed960
@ -3,12 +3,11 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from typing import Set
|
||||
|
||||
|
||||
# Note - hf and timm have their own version of this, torchbench does not
|
||||
# TOOD(voz): Someday, consolidate all the files into one runner instead of a shim like this...
|
||||
def model_names(filename: str) -> Set[str]:
|
||||
def model_names(filename: str) -> set[str]:
|
||||
names = set()
|
||||
with open(filename) as fh:
|
||||
lines = fh.readlines()
|
||||
|
@ -23,18 +23,7 @@ import time
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
List,
|
||||
Mapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -97,6 +86,8 @@ except ImportError:
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
|
||||
from torch.onnx._internal.fx import diagnostics
|
||||
|
||||
|
||||
@ -1779,7 +1770,7 @@ class OnnxModel(abc.ABC):
|
||||
for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs)
|
||||
}
|
||||
|
||||
def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[npt.NDArray]) -> Any:
|
||||
def adapt_onnx_outputs_to_pt(self, onnx_outputs: list[npt.NDArray]) -> Any:
|
||||
pt_outputs = [
|
||||
torch.from_numpy(onnx_output).to(current_device)
|
||||
for onnx_output in onnx_outputs
|
||||
@ -2217,11 +2208,11 @@ class OnnxExportErrorRow:
|
||||
)
|
||||
|
||||
@property
|
||||
def headers(self) -> List[str]:
|
||||
def headers(self) -> list[str]:
|
||||
return [field.name for field in dataclasses.fields(self)]
|
||||
|
||||
@property
|
||||
def row(self) -> List[str]:
|
||||
def row(self) -> list[str]:
|
||||
return [getattr(self, field.name) for field in dataclasses.fields(self)]
|
||||
|
||||
|
||||
@ -2271,7 +2262,7 @@ class OnnxContext:
|
||||
|
||||
def optimize_onnx_ctx(
|
||||
output_directory: str,
|
||||
onnx_model_cls: Type[OnnxModel],
|
||||
onnx_model_cls: type[OnnxModel],
|
||||
run_n_iterations: Callable,
|
||||
dynamic_shapes: bool = False,
|
||||
copy_before_export: bool = False,
|
||||
|
@ -3,8 +3,9 @@ import logging
|
||||
import math
|
||||
import os
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Generator, Iterable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.testing import make_tensor
|
||||
@ -263,7 +264,7 @@ class OperatorInputsLoader:
|
||||
|
||||
def get_inputs_for_operator(
|
||||
self, operator, dtype=None, device="cuda"
|
||||
) -> Generator[tuple[Iterable[Any], Dict[str, Any]], None, None]:
|
||||
) -> Generator[tuple[Iterable[Any], dict[str, Any]], None, None]:
|
||||
assert (
|
||||
str(operator) in self.operator_db
|
||||
), f"Could not find {operator}, must provide overload"
|
||||
|
@ -1,7 +1,6 @@
|
||||
import numbers
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.jit as jit
|
||||
@ -115,7 +114,7 @@ def script_lnlstm(
|
||||
LSTMState = namedtuple("LSTMState", ["hx", "cx"])
|
||||
|
||||
|
||||
def reverse(lst: List[Tensor]) -> List[Tensor]:
|
||||
def reverse(lst: list[Tensor]) -> list[Tensor]:
|
||||
return lst[::-1]
|
||||
|
||||
|
||||
@ -228,7 +227,7 @@ class LSTMLayer(jit.ScriptModule):
|
||||
self, input: Tensor, state: tuple[Tensor, Tensor]
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
|
||||
inputs = input.unbind(0)
|
||||
outputs = torch.jit.annotate(List[Tensor], [])
|
||||
outputs = torch.jit.annotate(list[Tensor], [])
|
||||
for i in range(len(inputs)):
|
||||
out, state = self.cell(inputs[i], state)
|
||||
outputs += [out]
|
||||
@ -245,7 +244,7 @@ class ReverseLSTMLayer(jit.ScriptModule):
|
||||
self, input: Tensor, state: tuple[Tensor, Tensor]
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
|
||||
inputs = reverse(input.unbind(0))
|
||||
outputs = jit.annotate(List[Tensor], [])
|
||||
outputs = jit.annotate(list[Tensor], [])
|
||||
for i in range(len(inputs)):
|
||||
out, state = self.cell(inputs[i], state)
|
||||
outputs += [out]
|
||||
@ -266,11 +265,11 @@ class BidirLSTMLayer(jit.ScriptModule):
|
||||
|
||||
@jit.script_method
|
||||
def forward(
|
||||
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
|
||||
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
|
||||
self, input: Tensor, states: list[tuple[Tensor, Tensor]]
|
||||
) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]:
|
||||
# List[LSTMState]: [forward LSTMState, backward LSTMState]
|
||||
outputs = jit.annotate(List[Tensor], [])
|
||||
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
|
||||
outputs = jit.annotate(list[Tensor], [])
|
||||
output_states = jit.annotate(list[tuple[Tensor, Tensor]], [])
|
||||
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
|
||||
i = 0
|
||||
for direction in self.directions:
|
||||
@ -300,10 +299,10 @@ class StackedLSTM(jit.ScriptModule):
|
||||
|
||||
@jit.script_method
|
||||
def forward(
|
||||
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
|
||||
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
|
||||
self, input: Tensor, states: list[tuple[Tensor, Tensor]]
|
||||
) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]:
|
||||
# List[LSTMState]: One state per layer
|
||||
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
|
||||
output_states = jit.annotate(list[tuple[Tensor, Tensor]], [])
|
||||
output = input
|
||||
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
|
||||
i = 0
|
||||
@ -330,11 +329,11 @@ class StackedLSTM2(jit.ScriptModule):
|
||||
|
||||
@jit.script_method
|
||||
def forward(
|
||||
self, input: Tensor, states: List[List[tuple[Tensor, Tensor]]]
|
||||
) -> tuple[Tensor, List[List[tuple[Tensor, Tensor]]]]:
|
||||
self, input: Tensor, states: list[list[tuple[Tensor, Tensor]]]
|
||||
) -> tuple[Tensor, list[list[tuple[Tensor, Tensor]]]]:
|
||||
# List[List[LSTMState]]: The outer list is for layers,
|
||||
# inner list is for directions.
|
||||
output_states = jit.annotate(List[List[tuple[Tensor, Tensor]]], [])
|
||||
output_states = jit.annotate(list[list[tuple[Tensor, Tensor]]], [])
|
||||
output = input
|
||||
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
|
||||
i = 0
|
||||
@ -370,10 +369,10 @@ class StackedLSTMWithDropout(jit.ScriptModule):
|
||||
|
||||
@jit.script_method
|
||||
def forward(
|
||||
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
|
||||
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
|
||||
self, input: Tensor, states: list[tuple[Tensor, Tensor]]
|
||||
) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]:
|
||||
# List[LSTMState]: One state per layer
|
||||
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
|
||||
output_states = jit.annotate(list[tuple[Tensor, Tensor]], [])
|
||||
output = input
|
||||
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
|
||||
i = 0
|
||||
|
@ -1,5 +1,4 @@
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -265,13 +264,13 @@ def varlen_pytorch_lstm_creator(**kwargs):
|
||||
|
||||
def varlen_lstm_factory(cell, script):
|
||||
def dynamic_rnn(
|
||||
sequences: List[Tensor],
|
||||
sequences: list[Tensor],
|
||||
hiddens: tuple[Tensor, Tensor],
|
||||
wih: Tensor,
|
||||
whh: Tensor,
|
||||
bih: Tensor,
|
||||
bhh: Tensor,
|
||||
) -> tuple[List[Tensor], tuple[List[Tensor], List[Tensor]]]:
|
||||
) -> tuple[list[Tensor], tuple[list[Tensor], list[Tensor]]]:
|
||||
hx, cx = hiddens
|
||||
hxs = hx.unbind(1)
|
||||
cxs = cx.unbind(1)
|
||||
@ -506,7 +505,7 @@ def lstm_factory_simple(cell, script):
|
||||
|
||||
def lstm_factory_multilayer(cell, script):
|
||||
def dynamic_rnn(
|
||||
input: Tensor, hidden: tuple[Tensor, Tensor], params: List[Tensor]
|
||||
input: Tensor, hidden: tuple[Tensor, Tensor], params: list[Tensor]
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
|
||||
params_stride = 4 # NB: this assumes that biases are there
|
||||
hx, cx = hidden
|
||||
|
@ -1,7 +1,7 @@
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, List, NamedTuple
|
||||
from typing import Any, Callable, NamedTuple
|
||||
|
||||
import torch
|
||||
from torch.autograd import functional
|
||||
@ -147,8 +147,8 @@ ALL_TASKS = ALL_TASKS_NON_VECTORIZED + VECTORIZED_TASKS
|
||||
class ModelDef(NamedTuple):
|
||||
name: str
|
||||
getter: GetterType
|
||||
tasks: List[str]
|
||||
unsupported: List[str]
|
||||
tasks: list[str]
|
||||
unsupported: list[str]
|
||||
|
||||
|
||||
MODELS = [
|
||||
@ -223,7 +223,7 @@ def run_once_functorch(
|
||||
|
||||
def run_model(
|
||||
model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once
|
||||
) -> List[float]:
|
||||
) -> list[float]:
|
||||
if args.gpu == -1:
|
||||
device = torch.device("cpu")
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
@ -14,13 +14,13 @@ GetterType = Callable[[torch.device], GetterReturnType]
|
||||
VType = Union[None, Tensor, tuple[Tensor, ...]]
|
||||
# Type used to store timing results. The first key is the model name, the second key
|
||||
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
|
||||
TimingResultType = Dict[str, Dict[str, tuple[float, ...]]]
|
||||
TimingResultType = dict[str, dict[str, tuple[float, ...]]]
|
||||
|
||||
|
||||
# Utilities to make nn.Module "functional"
|
||||
# In particular the goal is to be able to provide a function that takes as input
|
||||
# the parameters and evaluate the nn.Module using fixed inputs.
|
||||
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
|
||||
def _del_nested_attr(obj: nn.Module, names: list[str]) -> None:
|
||||
"""
|
||||
Deletes the attribute specified by the given list of names.
|
||||
For example, to delete the attribute obj.conv.weight,
|
||||
@ -32,7 +32,7 @@ def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
|
||||
_del_nested_attr(getattr(obj, names[0]), names[1:])
|
||||
|
||||
|
||||
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
|
||||
def _set_nested_attr(obj: nn.Module, names: list[str], value: Tensor) -> None:
|
||||
"""
|
||||
Set the attribute specified by the given list of names to value.
|
||||
For example, to set the attribute obj.conv.weight,
|
||||
@ -44,7 +44,7 @@ def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
|
||||
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
|
||||
|
||||
|
||||
def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], List[str]]:
|
||||
def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], list[str]]:
|
||||
"""
|
||||
This function removes all the Parameters from the model and
|
||||
return them as a tuple as well as their original attribute names.
|
||||
@ -65,7 +65,7 @@ def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], List[str]]:
|
||||
return params, names
|
||||
|
||||
|
||||
def load_weights(mod: nn.Module, names: List[str], params: tuple[Tensor, ...]) -> None:
|
||||
def load_weights(mod: nn.Module, names: list[str], params: tuple[Tensor, ...]) -> None:
|
||||
"""
|
||||
Reload a set of weights so that `mod` can be used again to perform a forward pass.
|
||||
Note that the `params` are regular Tensors (that can have history) and so are left
|
||||
|
@ -1,8 +1,8 @@
|
||||
import dataclasses
|
||||
from typing import Callable, Dict, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
all_experiments: Dict[str, Callable] = {}
|
||||
all_experiments: dict[str, Callable] = {}
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List, Union
|
||||
from typing import Union
|
||||
|
||||
from core.expand import materialize
|
||||
from definitions.standard import BENCHMARKS
|
||||
@ -22,7 +22,7 @@ VERSION = 0
|
||||
MD5 = "4d55e8abf881ad38bb617a96714c1296"
|
||||
|
||||
|
||||
def main(argv: List[str]) -> None:
|
||||
def main(argv: list[str]) -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--destination", type=str, default=None)
|
||||
parser.add_argument("--subset", action="store_true")
|
||||
@ -56,7 +56,7 @@ def main(argv: List[str]) -> None:
|
||||
results = Runner(work_orders, cadence=30.0).run()
|
||||
|
||||
# TODO: Annotate with TypedDict when 3.8 is the minimum supported verson.
|
||||
grouped_results: Dict[str, Dict[str, List[Union[float, int]]]] = {
|
||||
grouped_results: dict[str, dict[str, list[Union[float, int]]]] = {
|
||||
key: {"times": [], "counts": []} for key in keys
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,7 @@ import enum
|
||||
import itertools as it
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
from worker.main import WorkerTimerArgs
|
||||
|
||||
@ -49,7 +49,7 @@ class AutoLabels:
|
||||
language: Language
|
||||
|
||||
@property
|
||||
def as_dict(self) -> Dict[str, str]:
|
||||
def as_dict(self) -> dict[str, str]:
|
||||
"""Dict representation for CI reporting."""
|
||||
return {
|
||||
"runtime": self.runtime.value,
|
||||
@ -261,7 +261,7 @@ class GroupedBenchmark:
|
||||
py_block: str = "",
|
||||
cpp_block: str = "",
|
||||
num_threads: Union[int, tuple[int, ...]] = 1,
|
||||
) -> Dict[Union[tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
|
||||
) -> dict[Union[tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
|
||||
py_cases, py_setup, py_global_setup = cls._parse_variants(
|
||||
py_block, Language.PYTHON
|
||||
)
|
||||
@ -279,9 +279,9 @@ class GroupedBenchmark:
|
||||
# NB: The key is actually `Tuple[str, ...]`, however MyPy gets confused
|
||||
# and we use the superset `Union[Tuple[str, ...], Optional[str]` to
|
||||
# match the expected signature.
|
||||
variants: Dict[Union[tuple[str, ...], Optional[str]], GroupedBenchmark] = {}
|
||||
variants: dict[Union[tuple[str, ...], Optional[str]], GroupedBenchmark] = {}
|
||||
|
||||
seen_labels: Set[str] = set()
|
||||
seen_labels: set[str] = set()
|
||||
for label in it.chain(py_cases.keys(), cpp_cases.keys()):
|
||||
if label in seen_labels:
|
||||
continue
|
||||
@ -415,13 +415,13 @@ class GroupedBenchmark:
|
||||
@staticmethod
|
||||
def _parse_variants(
|
||||
block: str, language: Language
|
||||
) -> tuple[Dict[str, List[str]], str, str]:
|
||||
) -> tuple[dict[str, list[str]], str, str]:
|
||||
block = textwrap.dedent(block).strip()
|
||||
comment = "#" if language == Language.PYTHON else "//"
|
||||
label_pattern = f"{comment} @(.+)$"
|
||||
label = ""
|
||||
|
||||
lines_by_label: Dict[str, List[str]] = {"SETUP": [], "GLOBAL_SETUP": []}
|
||||
lines_by_label: dict[str, list[str]] = {"SETUP": [], "GLOBAL_SETUP": []}
|
||||
for line in block.splitlines(keepends=False):
|
||||
match = re.search(label_pattern, line.strip())
|
||||
if match:
|
||||
|
@ -12,7 +12,7 @@ import os
|
||||
import re
|
||||
import textwrap
|
||||
import uuid
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@ -204,7 +204,7 @@ def materialize(benchmarks: FlatIntermediateDefinition) -> FlatDefinition:
|
||||
GroupedBenchmarks into multiple TimerArgs, and tagging the results with
|
||||
AutoLabels.
|
||||
"""
|
||||
results: List[tuple[Label, AutoLabels, TimerArgs]] = []
|
||||
results: list[tuple[Label, AutoLabels, TimerArgs]] = []
|
||||
|
||||
for label, args in benchmarks.items():
|
||||
if isinstance(args, TimerArgs):
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.api import AutoLabels, GroupedBenchmark, TimerArgs
|
||||
|
||||
@ -71,15 +71,15 @@ _Label = Union[Label, Optional[str]]
|
||||
|
||||
_Value = Union[
|
||||
Union[TimerArgs, GroupedBenchmark],
|
||||
Dict[_Label, "_Value"],
|
||||
dict[_Label, "_Value"],
|
||||
]
|
||||
|
||||
Definition = Dict[_Label, _Value]
|
||||
Definition = dict[_Label, _Value]
|
||||
|
||||
# We initially have to parse (flatten) to an intermediate state in order to
|
||||
# build TorchScript models since multiple entries will share the same model
|
||||
# artifact.
|
||||
FlatIntermediateDefinition = Dict[Label, Union[TimerArgs, GroupedBenchmark]]
|
||||
FlatIntermediateDefinition = dict[Label, Union[TimerArgs, GroupedBenchmark]]
|
||||
|
||||
# Final parsed schema.
|
||||
FlatDefinition = tuple[tuple[Label, AutoLabels, TimerArgs], ...]
|
||||
|
@ -3,7 +3,7 @@ import atexit
|
||||
import re
|
||||
import shutil
|
||||
import textwrap
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from core.api import GroupedBenchmark, TimerArgs
|
||||
from core.types import Definition, FlatIntermediateDefinition, Label
|
||||
@ -70,7 +70,7 @@ def parse_stmts(stmts: str) -> tuple[str, str]:
|
||||
- The column separator is " | ", not "|". Whitespace matters.
|
||||
"""
|
||||
stmts = textwrap.dedent(stmts).strip()
|
||||
lines: List[str] = stmts.splitlines(keepends=False)
|
||||
lines: list[str] = stmts.splitlines(keepends=False)
|
||||
assert len(lines) >= 3, f"Invalid string:\n{stmts}"
|
||||
|
||||
column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
|
||||
@ -87,8 +87,8 @@ def parse_stmts(stmts: str) -> tuple[str, str]:
|
||||
|
||||
assert re.search(separation_pattern, lines[1])
|
||||
|
||||
py_lines: List[str] = []
|
||||
cpp_lines: List[str] = []
|
||||
py_lines: list[str] = []
|
||||
cpp_lines: list[str] = []
|
||||
for l in lines[2:]:
|
||||
l_match = re.search(code_pattern, l)
|
||||
if l_match is None:
|
||||
|
@ -8,7 +8,7 @@ import subprocess
|
||||
import textwrap
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from worker.main import WorkerFailure, WorkerOutput
|
||||
|
||||
@ -51,11 +51,11 @@ class CorePool:
|
||||
self._num_cores = max_core_id - min_core_id + 1
|
||||
print(f"Core pool created: cores {self._min_core_id}-{self._max_core_id}")
|
||||
|
||||
self._available: List[bool] = [
|
||||
self._available: list[bool] = [
|
||||
True for _ in range(min_core_id, min_core_id + self._num_cores)
|
||||
]
|
||||
|
||||
self._reservations: Dict[str, tuple[int, ...]] = {}
|
||||
self._reservations: dict[str, tuple[int, ...]] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def reserve(self, n: int) -> Optional[str]:
|
||||
@ -96,19 +96,19 @@ class Runner:
|
||||
self._cadence: float = cadence
|
||||
|
||||
# Working state.
|
||||
self._work_queue: List[WorkOrder] = list(work_items)
|
||||
self._active_jobs: List[InProgress] = []
|
||||
self._results: Dict[WorkOrder, WorkerOutput] = {}
|
||||
self._work_queue: list[WorkOrder] = list(work_items)
|
||||
self._active_jobs: list[InProgress] = []
|
||||
self._results: dict[WorkOrder, WorkerOutput] = {}
|
||||
|
||||
# Debug information for ETA and error messages.
|
||||
self._start_time: float = -1
|
||||
self._durations: Dict[WorkOrder, float] = {}
|
||||
self._durations: dict[WorkOrder, float] = {}
|
||||
self._currently_processed: Optional[WorkOrder] = None
|
||||
|
||||
if len(work_items) != len(set(work_items)):
|
||||
raise ValueError("Duplicate work items.")
|
||||
|
||||
def run(self) -> Dict[WorkOrder, WorkerOutput]:
|
||||
def run(self) -> dict[WorkOrder, WorkerOutput]:
|
||||
try:
|
||||
return self._run()
|
||||
|
||||
@ -137,7 +137,7 @@ class Runner:
|
||||
self._force_shutdown(verbose=True)
|
||||
raise
|
||||
|
||||
def _run(self) -> Dict[WorkOrder, WorkerOutput]:
|
||||
def _run(self) -> dict[WorkOrder, WorkerOutput]:
|
||||
self._start_time = time.time()
|
||||
self._canary_import()
|
||||
while self._work_queue or self._active_jobs:
|
||||
@ -150,7 +150,7 @@ class Runner:
|
||||
return self._results.copy()
|
||||
|
||||
def _update_active_jobs(self) -> None:
|
||||
active_jobs: List[InProgress] = []
|
||||
active_jobs: list[InProgress] = []
|
||||
for job in self._active_jobs:
|
||||
self._currently_processed = job.work_order
|
||||
if not job.check_finished():
|
||||
@ -172,7 +172,7 @@ class Runner:
|
||||
self._active_jobs.extend(active_jobs)
|
||||
|
||||
def _enqueue_new_jobs(self) -> None:
|
||||
work_queue: List[WorkOrder] = []
|
||||
work_queue: list[WorkOrder] = []
|
||||
for i, work_order in enumerate(self._work_queue):
|
||||
self._currently_processed = work_order
|
||||
cpu_list = self._core_pool.reserve(work_order.timer_args.num_threads)
|
||||
@ -249,7 +249,7 @@ class Runner:
|
||||
|
||||
def _canary_import(self) -> None:
|
||||
"""Make sure we can import torch before launching a slew of workers."""
|
||||
source_cmds: Set[str] = set()
|
||||
source_cmds: set[str] = set()
|
||||
for w in self._work_items:
|
||||
if w.source_cmd is not None:
|
||||
source_cmds.add(f"{w.source_cmd} && ")
|
||||
|
@ -10,7 +10,7 @@ import signal
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
from core.api import AutoLabels
|
||||
from core.types import Label
|
||||
@ -98,7 +98,7 @@ class _BenchmarkProcess:
|
||||
|
||||
@property
|
||||
def cmd(self) -> str:
|
||||
cmd: List[str] = []
|
||||
cmd: list[str] = []
|
||||
if self._work_order.source_cmd is not None:
|
||||
cmd.extend([self._work_order.source_cmd, "&&"])
|
||||
|
||||
|
@ -10,7 +10,6 @@ underlying benchmark generation infrastructure in the mean time.
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from applications import ci
|
||||
from core.expand import materialize
|
||||
@ -19,7 +18,7 @@ from execution.runner import Runner
|
||||
from execution.work import WorkOrder
|
||||
|
||||
|
||||
def main(argv: List[str]) -> None:
|
||||
def main(argv: list[str]) -> None:
|
||||
work_orders = tuple(
|
||||
WorkOrder(label, autolabels, timer_args, timeout=600, retries=2)
|
||||
for label, autolabels, timer_args in materialize(BENCHMARKS)
|
||||
|
@ -1,5 +1,3 @@
|
||||
from typing import List
|
||||
|
||||
import operator_benchmark as op_bench
|
||||
|
||||
import torch
|
||||
@ -44,7 +42,7 @@ class As_stridedBenchmark(op_bench.TorchBenchmarkBase):
|
||||
self.set_module_name("as_strided")
|
||||
|
||||
def forward(
|
||||
self, input_one, size: List[int], stride: List[int], storage_offset: int
|
||||
self, input_one, size: list[int], stride: list[int], storage_offset: int
|
||||
):
|
||||
return torch.as_strided(input_one, size, stride, storage_offset)
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import operator_benchmark as op_bench
|
||||
|
||||
@ -143,7 +142,7 @@ class CatBenchmark(op_bench.TorchBenchmarkBase):
|
||||
self.inputs = {"result": result, "inputs": inputs, "dim": dim}
|
||||
self.set_module_name("cat")
|
||||
|
||||
def forward(self, result: torch.Tensor, inputs: List[torch.Tensor], dim: int):
|
||||
def forward(self, result: torch.Tensor, inputs: list[torch.Tensor], dim: int):
|
||||
return torch.cat(inputs, dim=dim, out=result)
|
||||
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
from typing import List
|
||||
|
||||
import operator_benchmark as op_bench
|
||||
|
||||
import torch
|
||||
@ -58,7 +56,7 @@ class QCatBenchmark(op_bench.TorchBenchmarkBase):
|
||||
self.inputs = {"input": self.input, "dim": dim}
|
||||
self.set_module_name("qcat")
|
||||
|
||||
def forward(self, input: List[torch.Tensor], dim: int):
|
||||
def forward(self, input: list[torch.Tensor], dim: int):
|
||||
return self.qf.cat(input, dim=dim)
|
||||
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import operator_benchmark as op_bench
|
||||
|
||||
@ -79,7 +78,7 @@ class StackBenchmark(op_bench.TorchBenchmarkBase):
|
||||
self.inputs = {"result": result, "inputs": inputs, "dim": dim}
|
||||
self.set_module_name("stack")
|
||||
|
||||
def forward(self, result: torch.Tensor, inputs: List[torch.Tensor], dim: int):
|
||||
def forward(self, result: torch.Tensor, inputs: list[torch.Tensor], dim: int):
|
||||
return torch.stack(inputs, dim=dim, out=result)
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import itertools
|
||||
from dataclasses import asdict, dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, List, Union
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
from tabulate import tabulate
|
||||
@ -50,7 +50,7 @@ class ExperimentResults:
|
||||
materialized_mask_time: float
|
||||
attn_mask_subclass_time: float
|
||||
|
||||
def get_entries(self) -> List:
|
||||
def get_entries(self) -> list:
|
||||
return [
|
||||
f"{self.materialized_mask_time:2f}",
|
||||
f"{self.attn_mask_subclass_time:2f}",
|
||||
@ -62,7 +62,7 @@ class Experiment:
|
||||
config: ExperimentConfig
|
||||
results: ExperimentResults
|
||||
|
||||
def get_entries(self) -> List:
|
||||
def get_entries(self) -> list:
|
||||
return self.config.get_entries() + self.results.get_entries()
|
||||
|
||||
|
||||
@ -176,7 +176,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
|
||||
)
|
||||
|
||||
|
||||
def generate_experiment_configs() -> List[ExperimentConfig]:
|
||||
def generate_experiment_configs() -> list[ExperimentConfig]:
|
||||
batch_sizes = [1, 8, 16, 128]
|
||||
num_heads = [16, 32]
|
||||
q_kv_seq_lens = [(128, 256), (256, 416), (512, 4097), (1024, 2048), (1, 2048)]
|
||||
@ -206,7 +206,7 @@ def calculate_speedup(results: ExperimentResults) -> float:
|
||||
return results.materialized_mask_time / results.attn_mask_subclass_time
|
||||
|
||||
|
||||
def print_results(results: List[Experiment]):
|
||||
def print_results(results: list[Experiment]):
|
||||
# Calculate speedups
|
||||
speedups = [calculate_speedup(r.results) for r in results]
|
||||
|
||||
|
@ -6,7 +6,7 @@ from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import asdict, dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from tabulate import tabulate
|
||||
@ -46,7 +46,7 @@ class ExperimentConfig:
|
||||
dtype: torch.dtype
|
||||
calculate_bwd_time: bool
|
||||
cal_bandwidth: bool
|
||||
backends: List[str]
|
||||
backends: list[str]
|
||||
|
||||
def __post_init__(self):
|
||||
assert (
|
||||
@ -80,7 +80,7 @@ class ExperimentResults:
|
||||
@dataclass(frozen=True)
|
||||
class Experiment:
|
||||
config: ExperimentConfig
|
||||
results: Dict[str, ExperimentResults] # backend -> ExperimentResults
|
||||
results: dict[str, ExperimentResults] # backend -> ExperimentResults
|
||||
|
||||
def asdict(self):
|
||||
dict1 = self.config.asdict()
|
||||
@ -357,7 +357,7 @@ def run_single_experiment(
|
||||
config: ExperimentConfig,
|
||||
dynamic=False,
|
||||
max_autotune=False,
|
||||
) -> Dict[str, ExperimentResults]:
|
||||
) -> dict[str, ExperimentResults]:
|
||||
device = torch.device("cuda")
|
||||
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
|
||||
query, key, value = generate_inputs(
|
||||
@ -504,7 +504,7 @@ def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> fl
|
||||
return total_flops / results.fwd_time / 1e6 # in TFLOPs/
|
||||
|
||||
|
||||
def get_average_speedups(results: List[Experiment], type: str, backend: str):
|
||||
def get_average_speedups(results: list[Experiment], type: str, backend: str):
|
||||
# Calculate speedups
|
||||
speedups = [
|
||||
calculate_speedup(r.results["compiled"], r.results[backend], type)
|
||||
@ -533,7 +533,7 @@ def get_average_speedups(results: List[Experiment], type: str, backend: str):
|
||||
return table_data
|
||||
|
||||
|
||||
def print_results(results: List[Experiment], save_path: Optional[str] = None):
|
||||
def print_results(results: list[Experiment], save_path: Optional[str] = None):
|
||||
table_data = defaultdict(list)
|
||||
for experiment in results:
|
||||
backends = experiment.config.backends + ["compiled"]
|
||||
@ -1024,16 +1024,16 @@ def generate_eager_sdpa(
|
||||
def generate_experiment_configs(
|
||||
calculate_bwd: bool,
|
||||
dtype: torch.dtype,
|
||||
batch_sizes: List[int],
|
||||
num_heads: List[tuple[int, int]],
|
||||
seq_lens: List[int],
|
||||
head_dims: List[int],
|
||||
score_mods_str: List[str],
|
||||
batch_sizes: list[int],
|
||||
num_heads: list[tuple[int, int]],
|
||||
seq_lens: list[int],
|
||||
head_dims: list[int],
|
||||
score_mods_str: list[str],
|
||||
decoding: bool,
|
||||
kv_cache_size: List[int],
|
||||
kv_cache_size: list[int],
|
||||
cal_bandwidth: bool,
|
||||
backends: List[str],
|
||||
) -> List[ExperimentConfig]:
|
||||
backends: list[str],
|
||||
) -> list[ExperimentConfig]:
|
||||
assert not (calculate_bwd and decoding), "Decoding does not support backward"
|
||||
|
||||
if decoding:
|
||||
|
@ -5,7 +5,7 @@ import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from prettytable import PrettyTable
|
||||
@ -32,7 +32,7 @@ class ExperimentConfig:
|
||||
enable_mem_efficient: bool
|
||||
enable_cudnn: bool
|
||||
|
||||
def get_entries(self) -> List:
|
||||
def get_entries(self) -> list:
|
||||
return [
|
||||
self.batch_size,
|
||||
self.num_heads,
|
||||
@ -47,7 +47,7 @@ class ExperimentConfig:
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_entry_names(cls) -> List[str]:
|
||||
def get_entry_names(cls) -> list[str]:
|
||||
return [
|
||||
"batch_size",
|
||||
"num_heads",
|
||||
@ -69,7 +69,7 @@ class ExperimentResults:
|
||||
composite_mha_time: float
|
||||
compiled_composite_mha_time: Optional[float]
|
||||
|
||||
def get_entries(self) -> List:
|
||||
def get_entries(self) -> list:
|
||||
return [
|
||||
f"{self.nn_mha_time:2f}",
|
||||
f"{self.compiled_nn_mha_time:2f}" if self.compiled_nn_mha_time else None,
|
||||
@ -80,7 +80,7 @@ class ExperimentResults:
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_entry_names(cls) -> List[str]:
|
||||
def get_entry_names(cls) -> list[str]:
|
||||
return [
|
||||
"nn_mha_time (\u00b5s)",
|
||||
"compiled_nn_mha_time (\u00b5s)",
|
||||
@ -94,7 +94,7 @@ class Experiment:
|
||||
config: ExperimentConfig
|
||||
results: ExperimentResults
|
||||
|
||||
def get_entries(self) -> List:
|
||||
def get_entries(self) -> list:
|
||||
return self.config.get_entries() + self.results.get_entries()
|
||||
|
||||
|
||||
@ -275,7 +275,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
|
||||
# Could return generator
|
||||
def generate_experiments(
|
||||
batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
|
||||
) -> List[ExperimentConfig]:
|
||||
) -> list[ExperimentConfig]:
|
||||
configs = []
|
||||
for bsz, n_heads, seq_len, embed_dim, dtype, padding in itertools.product(
|
||||
batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
|
||||
@ -337,7 +337,7 @@ def main(save_path: Optional[Path]):
|
||||
batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
|
||||
)
|
||||
|
||||
experiments: List[Experiment] = []
|
||||
experiments: list[Experiment] = []
|
||||
for experiment_config in tqdm(experiment_configs):
|
||||
experiment = run_single_experiment(experiment_config)
|
||||
experiments.append(experiment)
|
||||
|
@ -2,7 +2,7 @@ import itertools
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, List
|
||||
from typing import Callable
|
||||
|
||||
from tabulate import tabulate
|
||||
from tqdm import tqdm
|
||||
@ -119,7 +119,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
|
||||
)
|
||||
|
||||
|
||||
def generate_experiment_configs() -> List[ExperimentConfig]:
|
||||
def generate_experiment_configs() -> list[ExperimentConfig]:
|
||||
batch_sizes = [
|
||||
1,
|
||||
8,
|
||||
@ -160,7 +160,7 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
|
||||
return all_configs
|
||||
|
||||
|
||||
def print_results(experiments: List[Experiment]):
|
||||
def print_results(experiments: list[Experiment]):
|
||||
table_data = defaultdict(list)
|
||||
for experiment in experiments:
|
||||
for key, value in experiment.asdict().items():
|
||||
|
@ -8,7 +8,7 @@ import argparse
|
||||
import ast
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Set, Tuple # type: ignore[attr-defined]
|
||||
from typing import Any # type: ignore[attr-defined]
|
||||
|
||||
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
|
||||
from tools.flight_recorder.components.types import (
|
||||
@ -57,12 +57,12 @@ Flat DB builder
|
||||
|
||||
def build_groups_memberships(
|
||||
pg_config: Any,
|
||||
) -> Tuple[
|
||||
List[Group],
|
||||
Dict[Any, Group],
|
||||
List[Membership],
|
||||
Dict[str, Set[Any]],
|
||||
Dict[Tuple[str, int], str],
|
||||
) -> tuple[
|
||||
list[Group],
|
||||
dict[Any, Group],
|
||||
list[Membership],
|
||||
dict[str, set[Any]],
|
||||
dict[tuple[str, int], str],
|
||||
]:
|
||||
"""
|
||||
pg_config: {
|
||||
@ -126,12 +126,12 @@ def build_groups_memberships(
|
||||
|
||||
|
||||
def build_collectives(
|
||||
all_entries: Dict[int, List[Dict[str, Any]]],
|
||||
_groups: Dict[str, Group],
|
||||
_memberships: Dict[str, Set[Any]],
|
||||
_pg_guids: Dict[Tuple[str, int], str],
|
||||
all_entries: dict[int, list[dict[str, Any]]],
|
||||
_groups: dict[str, Group],
|
||||
_memberships: dict[str, set[Any]],
|
||||
_pg_guids: dict[tuple[str, int], str],
|
||||
version: str,
|
||||
) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
|
||||
) -> tuple[list[Traceback], list[Collective], list[NCCLCall]]:
|
||||
"""
|
||||
groups, memberships are the non-flat dicts that are indexable
|
||||
all_entries is a raw dict from the original dumps:
|
||||
@ -161,10 +161,10 @@ def build_collectives(
|
||||
}
|
||||
"""
|
||||
major_v, minor_v = get_version_detail(version)
|
||||
tracebacks: List[Traceback] = []
|
||||
tracebacks: list[Traceback] = []
|
||||
|
||||
collectives: List[Collective] = []
|
||||
nccl_calls: List[NCCLCall] = []
|
||||
collectives: list[Collective] = []
|
||||
nccl_calls: list[NCCLCall] = []
|
||||
|
||||
# once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
|
||||
# instead, just record the remaining ops as NCCLCalls
|
||||
@ -420,7 +420,7 @@ def build_collectives(
|
||||
|
||||
|
||||
def build_db(
|
||||
details: Dict[str, Dict[str, Any]], args: argparse.Namespace, version: str
|
||||
details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str
|
||||
) -> Database:
|
||||
if args.verbose:
|
||||
os.environ["FR_TRACE_VERBOSE_OUTPUT"] = "1"
|
||||
|
@ -6,7 +6,8 @@
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Optional, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
|
||||
|
||||
|
@ -10,9 +10,8 @@ import os
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Set, Tuple, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
|
||||
|
||||
@ -20,7 +19,7 @@ from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
|
||||
logger: FlightRecorderLogger = FlightRecorderLogger()
|
||||
|
||||
|
||||
def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]:
|
||||
def read_dump(prefix: str, filename: str) -> dict[str, Union[str, int, list[Any]]]:
|
||||
basename = os.path.basename(filename)
|
||||
|
||||
rank = int(basename[len(prefix) :])
|
||||
@ -45,12 +44,12 @@ def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]
|
||||
exp = re.compile(r"([\w\-\_]*?)(\d+)$")
|
||||
|
||||
|
||||
def _determine_prefix(files: List[str]) -> str:
|
||||
def _determine_prefix(files: list[str]) -> str:
|
||||
"""If the user doesn't specify a prefix, but does pass a dir full of similarly-prefixed files, we should be able to
|
||||
infer the common prefix most of the time. But if we can't confidently infer, just fall back to requring the user
|
||||
to specify it
|
||||
"""
|
||||
possible_prefixes: typing.DefaultDict[str, Set[int]] = defaultdict(set)
|
||||
possible_prefixes: defaultdict[str, set[int]] = defaultdict(set)
|
||||
for f in files:
|
||||
m = exp.search(f)
|
||||
if m:
|
||||
@ -67,7 +66,7 @@ def _determine_prefix(files: List[str]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def read_dir(args: argparse.Namespace) -> Tuple[Dict[str, Dict[str, Any]], str]:
|
||||
def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]:
|
||||
gc.disable()
|
||||
prefix = args.prefix
|
||||
details = {}
|
||||
|
@ -10,14 +10,9 @@ from enum import auto, Enum
|
||||
from typing import ( # type: ignore[attr-defined]
|
||||
_eval_type,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
@ -33,7 +28,7 @@ class Ref(Generic[T]):
|
||||
|
||||
class TypeInfo(NamedTuple):
|
||||
name: str
|
||||
fields: List[Tuple[str, Type]] # type: ignore[type-arg]
|
||||
fields: list[tuple[str, type]] # type: ignore[type-arg]
|
||||
|
||||
@classmethod
|
||||
def from_type(cls, c: T) -> "TypeInfo":
|
||||
@ -126,15 +121,15 @@ class Collective(NamedTuple):
|
||||
record_id: int
|
||||
pg_desc: str
|
||||
collective_name: str
|
||||
input_sizes: List[List[int]]
|
||||
output_sizes: List[List[int]]
|
||||
expected_ranks: Set[int]
|
||||
input_sizes: list[list[int]]
|
||||
output_sizes: list[list[int]]
|
||||
expected_ranks: set[int]
|
||||
collective_state: str
|
||||
collective_frames: List[Dict[str, str]]
|
||||
collective_frames: list[dict[str, str]]
|
||||
input_numel: Optional[int] = None
|
||||
output_numel: Optional[int] = None
|
||||
missing_ranks: Optional[Set[int]] = None
|
||||
mismatch_collectives: Optional[Dict[int, "Collective"]] = None
|
||||
missing_ranks: Optional[set[int]] = None
|
||||
mismatch_collectives: Optional[dict[int, "Collective"]] = None
|
||||
type_of_mismatch: Optional[MatchState] = None
|
||||
|
||||
|
||||
@ -145,15 +140,15 @@ class NCCLCall(NamedTuple):
|
||||
global_rank: int # technically Ref[Process] once we have it
|
||||
traceback_id: Ref[Traceback]
|
||||
collective_type: str
|
||||
sizes: List[List[int]]
|
||||
sizes: list[list[int]]
|
||||
|
||||
|
||||
class Database(NamedTuple):
|
||||
groups: List[Group]
|
||||
memberships: List[Membership]
|
||||
tracebacks: List[Traceback]
|
||||
collectives: List[Collective]
|
||||
ncclcalls: List[NCCLCall]
|
||||
groups: list[Group]
|
||||
memberships: list[Membership]
|
||||
tracebacks: list[Traceback]
|
||||
collectives: list[Collective]
|
||||
ncclcalls: list[NCCLCall]
|
||||
|
||||
|
||||
# TODO: We need to add a schema for the following
|
||||
@ -206,7 +201,7 @@ class EntryState:
|
||||
log the error info during analysis.
|
||||
"""
|
||||
|
||||
def __init__(self, entry: Dict[str, Any], expected_ranks: Set[int]) -> None:
|
||||
def __init__(self, entry: dict[str, Any], expected_ranks: set[int]) -> None:
|
||||
self.pg_name = entry["process_group"][0]
|
||||
self.desc = entry["process_group"][1]
|
||||
self.pg_desc = (
|
||||
@ -221,19 +216,19 @@ class EntryState:
|
||||
self.collective_state = entry["state"]
|
||||
self.collective_frames = entry["frames"]
|
||||
self.expected_ranks = expected_ranks
|
||||
self.missing_ranks: Set[int]
|
||||
self.missing_ranks: set[int]
|
||||
self.input_numel: int
|
||||
self.output_numel: int
|
||||
self.errors: Set[Tuple[int, MatchState]]
|
||||
self.errors: set[tuple[int, MatchState]]
|
||||
|
||||
def log(
|
||||
self,
|
||||
logger: FlightRecorderLogger,
|
||||
logger_msg: str,
|
||||
frame_formatter: Any,
|
||||
total_numel: Optional[Tuple[int, int]] = None,
|
||||
errors: Optional[Set[Tuple[int, MatchState]]] = None,
|
||||
missing_ranks: Optional[Set[int]] = None,
|
||||
total_numel: Optional[tuple[int, int]] = None,
|
||||
errors: Optional[set[tuple[int, MatchState]]] = None,
|
||||
missing_ranks: Optional[set[int]] = None,
|
||||
) -> None:
|
||||
logger.info(
|
||||
logger_msg,
|
||||
@ -268,9 +263,9 @@ class EntryState:
|
||||
def to_collective(
|
||||
self,
|
||||
id: int,
|
||||
errors: Optional[Set[Tuple[int, MatchState]]] = None,
|
||||
idx_map: Optional[Dict[int, int]] = None,
|
||||
all_entries: Optional[Dict[int, List[Dict[str, Any]]]] = None,
|
||||
errors: Optional[set[tuple[int, MatchState]]] = None,
|
||||
idx_map: Optional[dict[int, int]] = None,
|
||||
all_entries: Optional[dict[int, list[dict[str, Any]]]] = None,
|
||||
) -> Collective:
|
||||
if not errors:
|
||||
return Collective(
|
||||
@ -340,11 +335,11 @@ class EntryState:
|
||||
|
||||
def to_nccl_call(
|
||||
self,
|
||||
all_entries: Dict[int, List[Dict[str, Any]]],
|
||||
idx_map: Dict[int, int],
|
||||
all_entries: dict[int, list[dict[str, Any]]],
|
||||
idx_map: dict[int, int],
|
||||
nccl_call_id: int,
|
||||
collective_id: Any,
|
||||
) -> List[NCCLCall]:
|
||||
) -> list[NCCLCall]:
|
||||
result = []
|
||||
for i, k in idx_map.items():
|
||||
all_entries[i].pop(k)
|
||||
@ -373,7 +368,7 @@ class Op:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]], pg_name: str
|
||||
self, event: dict[Any, Any], memberships: dict[str, set[Any]], pg_name: str
|
||||
):
|
||||
self.profiling_name = event["profiling_name"]
|
||||
nccl, name = self.profiling_name.split(":")
|
||||
@ -412,7 +407,7 @@ class Op:
|
||||
self.collective_frames = event["frames"]
|
||||
self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1"
|
||||
|
||||
def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None:
|
||||
def _init_global_src_dst(self, pg_ranks: set[Any]) -> None:
|
||||
pg_ranks = sorted(pg_ranks)
|
||||
self._src_g = pg_ranks[self._src] if self._src is not None else None
|
||||
self._dst_g = pg_ranks[self._dst] if self._dst is not None else None
|
||||
|
@ -6,7 +6,7 @@
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
from typing import Any
|
||||
|
||||
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
|
||||
from tools.flight_recorder.components.types import (
|
||||
@ -27,14 +27,14 @@ except ModuleNotFoundError:
|
||||
logger.debug("tabulate is not installed. Proceeding without it.")
|
||||
|
||||
|
||||
def format_frame(frame: Dict[str, str]) -> str:
|
||||
def format_frame(frame: dict[str, str]) -> str:
|
||||
name = frame["name"]
|
||||
filename = frame["filename"]
|
||||
line = frame["line"]
|
||||
return f"{name} at {filename}:{line}"
|
||||
|
||||
|
||||
def format_frames(frames: List[Dict[str, str]]) -> str:
|
||||
def format_frames(frames: list[dict[str, str]]) -> str:
|
||||
formatted_frames = []
|
||||
for frame in frames:
|
||||
formatted_frames.append(format_frame(frame))
|
||||
@ -42,9 +42,9 @@ def format_frames(frames: List[Dict[str, str]]) -> str:
|
||||
|
||||
|
||||
def match_one_event(
|
||||
event_a: Dict[Any, Any],
|
||||
event_b: Dict[Any, Any],
|
||||
memberships: Dict[str, Set[Any]],
|
||||
event_a: dict[Any, Any],
|
||||
event_b: dict[Any, Any],
|
||||
memberships: dict[str, set[Any]],
|
||||
pg_name: str,
|
||||
) -> MatchState:
|
||||
op_a = Op(event_a, memberships, pg_name)
|
||||
@ -53,11 +53,11 @@ def match_one_event(
|
||||
|
||||
|
||||
def match_coalesced_groups(
|
||||
all_rank_events: Dict[Any, Any],
|
||||
all_rank_events: dict[Any, Any],
|
||||
group_size: int,
|
||||
groups: Dict[str, Group],
|
||||
memberships: Dict[str, Set[Any]],
|
||||
_pg_guids: Dict[Tuple[str, int], str],
|
||||
groups: dict[str, Group],
|
||||
memberships: dict[str, set[Any]],
|
||||
_pg_guids: dict[tuple[str, int], str],
|
||||
) -> bool:
|
||||
"""
|
||||
all_rank_events: {
|
||||
@ -92,7 +92,7 @@ def match_coalesced_groups(
|
||||
|
||||
def visualize_ops(
|
||||
match: bool,
|
||||
_pg_guids: Dict[Tuple[str, int], str],
|
||||
_pg_guids: dict[tuple[str, int], str],
|
||||
) -> None:
|
||||
all_ops = {
|
||||
rank: [
|
||||
@ -174,7 +174,7 @@ def match_coalesced_groups(
|
||||
return True
|
||||
|
||||
|
||||
def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int, int]:
|
||||
def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int, int]:
|
||||
input_numel = 0
|
||||
output_numel = 0
|
||||
for e in alltoall_cases:
|
||||
@ -185,10 +185,10 @@ def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int
|
||||
|
||||
def find_coalesced_group(
|
||||
pg_name: str,
|
||||
entries: List[Dict[str, Any]],
|
||||
_pg_guids: Dict[Tuple[str, int], str],
|
||||
entries: list[dict[str, Any]],
|
||||
_pg_guids: dict[tuple[str, int], str],
|
||||
rank: int,
|
||||
) -> List[Tuple[int, Dict[str, Any]]]:
|
||||
) -> list[tuple[int, dict[str, Any]]]:
|
||||
"""Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
|
||||
build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id
|
||||
"""
|
||||
@ -216,10 +216,10 @@ def find_coalesced_group(
|
||||
|
||||
|
||||
def just_print_entries(
|
||||
all_entries: Dict[int, List[Dict[str, Any]]],
|
||||
_groups: Dict[str, Group],
|
||||
_memberships: Dict[str, Set[Any]],
|
||||
_pg_guids: Dict[Tuple[str, int], str],
|
||||
all_entries: dict[int, list[dict[str, Any]]],
|
||||
_groups: dict[str, Group],
|
||||
_memberships: dict[str, set[Any]],
|
||||
_pg_guids: dict[tuple[str, int], str],
|
||||
args: argparse.Namespace,
|
||||
) -> None:
|
||||
rows = []
|
||||
@ -257,7 +257,7 @@ def just_print_entries(
|
||||
|
||||
|
||||
def check_no_missing_dump_files(
|
||||
entries: Dict[int, Any], memberships: List[Membership]
|
||||
entries: dict[int, Any], memberships: list[Membership]
|
||||
) -> None:
|
||||
all_ranks = set()
|
||||
for membership in memberships:
|
||||
@ -268,14 +268,14 @@ def check_no_missing_dump_files(
|
||||
), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
|
||||
|
||||
|
||||
def check_version(version_by_ranks: Dict[str, str], version: str) -> None:
|
||||
def check_version(version_by_ranks: dict[str, str], version: str) -> None:
|
||||
for rank, v in version_by_ranks.items():
|
||||
assert (
|
||||
v == version
|
||||
), f"Rank {rank} has different version {v} from the given version {version}"
|
||||
|
||||
|
||||
def get_version_detail(version: str) -> Tuple[int, int]:
|
||||
def get_version_detail(version: str) -> tuple[int, int]:
|
||||
version = version.split(".")
|
||||
assert len(version) == 2, f"Invalid version {version}"
|
||||
major, minor = map(int, version)
|
||||
@ -283,8 +283,8 @@ def get_version_detail(version: str) -> Tuple[int, int]:
|
||||
|
||||
|
||||
def align_trace_from_beginning(
|
||||
entries: Dict[int, List[Dict[str, Any]]],
|
||||
) -> Dict[int, List[Dict[str, Any]]]:
|
||||
entries: dict[int, list[dict[str, Any]]],
|
||||
) -> dict[int, list[dict[str, Any]]]:
|
||||
"""
|
||||
Align the trace entries by record ID for entries.
|
||||
This function takes a dictionary of rank names to lists of trace entries as input.
|
||||
|
@ -29,7 +29,8 @@ python fr_trace.py <dump dir containing trace files> [-o <output file>]
|
||||
"""
|
||||
|
||||
import pickle
|
||||
from typing import Optional, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from tools.flight_recorder.components.builder import build_db
|
||||
from tools.flight_recorder.components.config_manager import JobConfig
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, cast, Dict
|
||||
from typing import Any, Callable, cast
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
@ -72,7 +72,7 @@ def gh_fetch_json_dict(
|
||||
params: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
|
||||
return cast(dict[str, Any], _gh_fetch_json_any(url, params, data))
|
||||
|
||||
|
||||
def gh_fetch_commit(org: str, repo: str, sha: str) -> dict[str, Any]:
|
||||
|
@ -12,10 +12,14 @@ from enum import Enum
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from tokenize import generate_tokens, TokenInfo
|
||||
from typing import Any, Iterator, Sequence
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing_extensions import Never
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
|
||||
|
||||
FSTRING_START = getattr(token, "FSTRING_START", None) # py3.12+
|
||||
FSTRING_END = getattr(token, "FSTRING_END", None)
|
||||
EMPTY_TOKENS = dict.fromkeys(
|
||||
|
@ -4,7 +4,7 @@ import sys
|
||||
import token
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Sequence, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
_PARENT = Path(__file__).parent.absolute()
|
||||
@ -16,6 +16,7 @@ else:
|
||||
import _linter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
from tokenize import TokenInfo
|
||||
|
||||
|
||||
|
@ -12,7 +12,7 @@ import sys
|
||||
import token
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, NamedTuple, Set, TYPE_CHECKING
|
||||
from typing import NamedTuple, TYPE_CHECKING
|
||||
|
||||
|
||||
_PARENT = Path(__file__).parent.absolute()
|
||||
@ -45,7 +45,7 @@ class LintMessage(NamedTuple):
|
||||
|
||||
LINTER_CODE = "NEWLINE"
|
||||
CURRENT_FILE_NAME = os.path.basename(__file__)
|
||||
_MODULE_NAME_ALLOW_LIST: Set[str] = set()
|
||||
_MODULE_NAME_ALLOW_LIST: set[str] = set()
|
||||
|
||||
# Add builtin modules.
|
||||
if sys.version_info >= (3, 10):
|
||||
@ -352,7 +352,7 @@ use sys.modules.get("torchrec") or the like.
|
||||
"""
|
||||
|
||||
|
||||
def check_file(filepath: str) -> List[LintMessage]:
|
||||
def check_file(filepath: str) -> list[LintMessage]:
|
||||
path = Path(filepath)
|
||||
file = _linter.PythonFile("import_linter", path)
|
||||
lint_messages = []
|
||||
|
@ -5,7 +5,7 @@ import sys
|
||||
import token
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Sequence, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
_PARENT = Path(__file__).parent.absolute()
|
||||
@ -17,6 +17,7 @@ else:
|
||||
import _linter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
from tokenize import TokenInfo
|
||||
|
||||
|
||||
|
@ -10,7 +10,6 @@ import tempfile
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@ -25,7 +24,7 @@ REQUIREMENTS_PATH = ROOT_PATH / "requirements.txt"
|
||||
|
||||
|
||||
def run_cmd(
|
||||
cmd: List[str], capture_output: bool = False
|
||||
cmd: list[str], capture_output: bool = False
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logger.debug("Running command: %s", " ".join(cmd))
|
||||
return subprocess.run(
|
||||
@ -69,7 +68,7 @@ class Builder:
|
||||
def __init__(self, interpreter: str) -> None:
|
||||
self.interpreter = interpreter
|
||||
|
||||
def setup_py(self, cmd_args: List[str]) -> bool:
|
||||
def setup_py(self, cmd_args: list[str]) -> bool:
|
||||
return (
|
||||
run_cmd([self.interpreter, str(SETUP_PY_PATH), *cmd_args]).returncode == 0
|
||||
)
|
||||
@ -114,7 +113,7 @@ def parse_args() -> argparse.Namespace:
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
pythons = args.python or [sys.executable]
|
||||
build_times: Dict[str, float] = dict()
|
||||
build_times: dict[str, float] = dict()
|
||||
|
||||
if len(pythons) > 1 and args.destination == "dist/":
|
||||
logger.warning(
|
||||
|
@ -6,7 +6,7 @@ import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Generator
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from tools.stats.upload_stats_lib import (
|
||||
download_s3_artifacts,
|
||||
@ -17,6 +17,10 @@ from tools.stats.upload_stats_lib import (
|
||||
from tools.stats.upload_test_stats import process_xml_element
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
TESTCASE_TAG = "testcase"
|
||||
SEPARATOR = ";"
|
||||
|
||||
|
@ -7,7 +7,7 @@ import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, Dict
|
||||
from typing import Any, Callable, cast
|
||||
from urllib.request import urlopen
|
||||
|
||||
|
||||
@ -61,7 +61,7 @@ def fetch_and_cache(
|
||||
if os.path.exists(path) and is_cached_file_valid():
|
||||
# Another test process already download the file, so don't re-do it
|
||||
with open(path) as f:
|
||||
return cast(Dict[str, Any], json.load(f))
|
||||
return cast(dict[str, Any], json.load(f))
|
||||
|
||||
for _ in range(3):
|
||||
try:
|
||||
|
@ -2,13 +2,13 @@ import glob
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
|
||||
def flatten_data(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def flatten_data(d: dict[str, Any]) -> dict[str, Any]:
|
||||
# Flatten the sccache stats data from a possibly nested dictionary to a flat
|
||||
# dictionary. For example, the input:
|
||||
# {
|
||||
|
@ -8,7 +8,7 @@ import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from tools.stats.upload_stats_lib import (
|
||||
download_s3_artifacts,
|
||||
@ -86,7 +86,7 @@ def get_perf_stats(
|
||||
return perf_stats
|
||||
|
||||
|
||||
def generate_partition_key(repo: str, doc: Dict[str, Any]) -> str:
|
||||
def generate_partition_key(repo: str, doc: dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate an unique partition key for the document on DynamoDB
|
||||
"""
|
||||
|
@ -6,7 +6,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Callable, cast, Dict, List
|
||||
from typing import Any, Callable, cast
|
||||
from urllib.error import HTTPError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
@ -60,7 +60,7 @@ def fetch_json(
|
||||
f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items()
|
||||
)
|
||||
return cast(
|
||||
List[Dict[str, Any]],
|
||||
list[dict[str, Any]],
|
||||
_fetch_url(url, headers=headers, data=data, reader=json.load),
|
||||
)
|
||||
|
||||
@ -79,7 +79,7 @@ def get_external_pr_data(
|
||||
responses: list[dict[str, Any]] = []
|
||||
while len(responses) > 0 or page == 1:
|
||||
response = cast(
|
||||
Dict[str, Any],
|
||||
dict[str, Any],
|
||||
fetch_json(
|
||||
"https://api.github.com/search/issues",
|
||||
params={
|
||||
|
@ -9,7 +9,7 @@ import time
|
||||
import zipfile
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import boto3 # type: ignore[import]
|
||||
import requests
|
||||
@ -122,8 +122,8 @@ def download_gha_artifacts(
|
||||
def upload_to_dynamodb(
|
||||
dynamodb_table: str,
|
||||
repo: str,
|
||||
docs: List[Any],
|
||||
generate_partition_key: Optional[Callable[[str, Dict[str, Any]], str]],
|
||||
docs: list[Any],
|
||||
generate_partition_key: Optional[Callable[[str, dict[str, Any]], str]],
|
||||
) -> None:
|
||||
print(f"Writing {len(docs)} documents to DynamoDB {dynamodb_table}")
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/dynamodb.html#batch-writing
|
||||
|
@ -1,7 +1,7 @@
|
||||
import sys
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from typing import Any, List
|
||||
from functools import cache
|
||||
from typing import Any
|
||||
|
||||
from tools.stats.test_dashboard import upload_additional_info
|
||||
from tools.stats.upload_stats_lib import get_s3_resource
|
||||
@ -11,7 +11,7 @@ from tools.stats.upload_test_stats import get_tests
|
||||
BUCKET_PREFIX = "workflows_failing_pending_upload"
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@cache
|
||||
def get_bucket() -> Any:
|
||||
return get_s3_resource().Bucket("gha-artifacts")
|
||||
|
||||
@ -41,7 +41,7 @@ def do_upload(workflow_id: int) -> None:
|
||||
upload_additional_info(workflow_id, workflow_attempt, test_cases)
|
||||
|
||||
|
||||
def get_workflow_ids(pending: bool = False) -> List[int]:
|
||||
def get_workflow_ids(pending: bool = False) -> list[int]:
|
||||
prefix = f"{BUCKET_PREFIX}/{'pending/' if pending else ''}"
|
||||
objs = get_bucket().objects.filter(Prefix=prefix)
|
||||
return [int(obj.key.split("/")[-1].split(".")[0]) for obj in objs]
|
||||
|
@ -8,7 +8,7 @@ import os
|
||||
from collections import defaultdict, namedtuple, OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Literal, TYPE_CHECKING, TypeVar
|
||||
from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
|
||||
|
||||
import yaml
|
||||
|
||||
@ -2324,7 +2324,7 @@ def gen_source_files(
|
||||
|
||||
def register_dispatch_key_env_callable(
|
||||
gnf: NativeFunction | NativeFunctionsGroup,
|
||||
) -> Dict[str, list[str]]:
|
||||
) -> dict[str, list[str]]:
|
||||
return {
|
||||
"dispatch_definitions": get_native_function_definitions(
|
||||
fm=fm, # noqa: F821
|
||||
|
Reference in New Issue
Block a user