[BE] Format .ci/ / .github/ / benchmarks/ / functorch/ / tools/ / torchgen/ with ruff format (#132577)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132577
Approved by: https://github.com/malfet
Author: Xuehai Pan
Date: 2024-10-11 22:07:20 +08:00
Committed by: PyTorch MergeBot
Parent: 04adb74d08
Commit: 267f82b860
64 changed files with 210 additions and 233 deletions
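
Scope of the change: this PR moves the listed directories from black to ruff format (the USE_BLACK_FILELIST hunk further down drops .ci/**, .github/**, benchmarks/**, functorch/**, tools/**, and torchgen/** from the black filelist), so nearly every hunk below is a mechanical re-wrap rather than a behavioral change. The recurring patterns are: subscripted assignment targets kept on one line with the assigned value parenthesized instead, body-less abstract methods collapsed onto a single line ("def f(...) -> T: ..."), a trailing comma added to single-parameter signatures that wrap, multi-line .split("\n") calls on triple-quoted strings collapsed onto one line, and doctest snippets inside docstrings re-quoted with double quotes. A minimal before/after sketch of the dominant pattern, with placeholder values so it runs on its own (the names mirror the BenchmarkRunner hunk below):

experiment_kwargs = {}  # placeholder dict standing in for the real experiment kwargs
dynamo_cache_lookup_latency = 0.0  # placeholder measurement

# Before (black): the subscript key is pushed onto its own lines
experiment_kwargs[
    "cache_lookup_latency"
] = dynamo_cache_lookup_latency

# After (ruff format): the target stays on one line and the value is parenthesized;
# ruff only wraps like this when the statement is too long at its original indentation
experiment_kwargs["cache_lookup_latency"] = (
    dynamo_cache_lookup_latency
)

The invocation implied by the title is presumably along the lines of ruff format .ci/ .github/ benchmarks/ functorch/ tools/ torchgen/; the exact command, flags, and ruff version used are not recorded in this diff.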

View File

@ -45,8 +45,7 @@ def create_cert(path, C, ST, L, O, key):
.not_valid_before(datetime.now(timezone.utc))
.not_valid_after(
# Our certificate will be valid for 10 days
datetime.now(timezone.utc)
+ timedelta(days=10)
datetime.now(timezone.utc) + timedelta(days=10)
)
.add_extension(
x509.BasicConstraints(ca=True, path_length=None),
@ -91,8 +90,7 @@ def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key):
.not_valid_before(datetime.now(timezone.utc))
.not_valid_after(
# Our certificate will be valid for 10 days
datetime.now(timezone.utc)
+ timedelta(days=10)
datetime.now(timezone.utc) + timedelta(days=10)
# Sign our certificate with our private key
)
.sign(private_ca_key, hashes.SHA256())

View File

@ -409,7 +409,7 @@ def generate_wheels_matrix(
"container_image": WHEEL_CONTAINER_IMAGES[arch_version],
"package_type": package_type,
"pytorch_extra_install_requirements": (
PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version] # fmt: skip
PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version]
if os != "linux-aarch64"
else ""
),
@ -457,7 +457,7 @@ def generate_wheels_matrix(
".", "_"
),
"pytorch_extra_install_requirements": (
PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.1"] # fmt: skip
PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.1"]
if os != "linux" and gpu_arch_type != "xpu"
else ""
),

View File

@ -1506,7 +1506,7 @@ def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
def checks_to_markdown_bullets(
checks: List[Tuple[str, Optional[str], Optional[int]]]
checks: List[Tuple[str, Optional[str], Optional[int]]],
) -> List[str]:
return [
f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]

View File

@ -51,9 +51,7 @@ def main():
print()
print(f"{'':>10s}", end="") # noqa: E999
for _ in [75, 95]:
print(
f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end=""
) # noqa: E999
print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="") # noqa: E999
print()
# Print measurements

View File

@ -209,9 +209,8 @@ def main():
x_axis_variables
): # run benchmark for every x axis variable
if len(x_axis_variables) > 1:
args[
args["x_axis_name"]
] = x_axis_variable # set x axis variable for this benchmark iteration
# set x axis variable for this benchmark iteration
args[args["x_axis_name"]] = x_axis_variable
processes = []
start_time = time.time()
for rank in range(args["world_size"]):

View File

@ -1391,9 +1391,7 @@ class AOTInductorModelCache:
strict=False,
).module()
with torch.no_grad():
so_path = torch._inductor.aot_compile(
gm, example_args, example_kwargs
) # type: ignore[arg-type]
so_path = torch._inductor.aot_compile(gm, example_args, example_kwargs) # type: ignore[arg-type]
cls.cache[key] = torch._export.aot_load(so_path, device)
@ -1559,12 +1557,10 @@ class OnnxModel(abc.ABC):
return model_path
@abc.abstractmethod
def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]:
...
def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]: ...
@abc.abstractmethod
def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]:
...
def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]: ...
def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, npt.NDArray]:
pt_inputs = self.format_pt_inputs(pt_inputs)
@ -3134,9 +3130,9 @@ class BenchmarkRunner:
experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
experiment_kwargs["dynamo_stats"] = dynamo_stats
if self.args.profile_dynamo_cache_lookup:
experiment_kwargs[
"cache_lookup_latency"
] = dynamo_cache_lookup_latency
experiment_kwargs["cache_lookup_latency"] = (
dynamo_cache_lookup_latency
)
if experiment.func is speedup_experiment_onnx:
experiment = functools.partial(
@ -3290,9 +3286,9 @@ class BenchmarkRunner:
experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
experiment_kwargs["dynamo_stats"] = dynamo_stats
if self.args.profile_dynamo_cache_lookup:
experiment_kwargs[
"cache_lookup_latency"
] = dynamo_cache_lookup_latency
experiment_kwargs["cache_lookup_latency"] = (
dynamo_cache_lookup_latency
)
if experiment.func is coverage_experiment:
ok, total = Stats.reset_counters()
@ -4324,7 +4320,14 @@ def run(runner, args, original_dir=None):
runner.skip_models.clear()
experiment = null_experiment
global current_name, current_device, current_batch_size, output_filename, disable_output, optimize_ctx, current_onnx_compiler
global \
current_name, \
current_device, \
current_batch_size, \
output_filename, \
disable_output, \
optimize_ctx, \
current_onnx_compiler
optimize_ctx = contextlib.nullcontext()
if args.disable_output:

View File

@ -2,6 +2,7 @@
A tool to merge multiple csv files (generated by torchbench.py/etc) into a single csv file.
Performs an outer join based on the benchmark name, filling in any missing data with zeros.
"""
import argparse
import functools
import operator

View File

@ -4,6 +4,7 @@ To generate output that can be fed into this script set the env varTORCHINDUCTOR
That file can be fed into this script to generate the minimizes total, weighted matmul time as a function of allowed templates.
"""
import json
import click

View File

@ -214,8 +214,7 @@ def bench(rnn_runners, group_name, print_json=False, sep=" ", **params):
k: {"avg": v.avg_fwd, "std": v.std_fwd, "info": v.info_fwd}
for k, v in results.items()
},
group_name
+ "-backward": {
f"{group_name}-backward": {
k: {"avg": v.avg_bwd, "std": v.std_bwd, "info": v.info_bwd}
for k, v in results.items()
},

View File

@ -184,7 +184,5 @@ class ConditionalFeedForwardInt8(nn.Module):
].to(x.dtype)
expert_outs = torch.einsum(
"tao, taio -> tai", (x1 * x3), w2_weights
) * self.scales2[expert_indices].to(
x.dtype
) # [T, A, D, D]
) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D]
return expert_outs

View File

@ -1,5 +1,7 @@
"""Collect instruction counts for continuous integration."""
# mypy: ignore-errors
import argparse
import hashlib
import json

View File

@ -1,5 +1,7 @@
"""Key enums and structs used to handle data flow within the benchmark."""
# mypy: ignore-errors
import dataclasses
import enum
import itertools as it

View File

@ -2,7 +2,9 @@
This is mostly string manipulation, with just a bit of importlib magic.
"""
# mypy: ignore-errors
import importlib.abc
import importlib.util
import itertools as it

View File

@ -1,5 +1,7 @@
"""Type annotations for various benchmark objects."""
# mypy: ignore-errors
from typing import Any, Dict, Optional, Tuple, Union
from core.api import AutoLabels, GroupedBenchmark, TimerArgs

View File

@ -1,5 +1,7 @@
"""Define some common setup blocks which benchmarks can reuse."""
# mypy: ignore-errors
import enum
from core.api import GroupedSetup

View File

@ -1,5 +1,7 @@
"""Run benchmarks while handling parallelism, isolation, and fault tolerance."""
# mypy: ignore-errors
import math
import multiprocessing
import subprocess

View File

@ -1,5 +1,7 @@
"""Handle the details of subprocess calls and retries for a given benchmark run."""
# mypy: ignore-errors
import dataclasses
import json
import os

View File

@ -5,7 +5,9 @@ expressive and robust components (e.g. better runner and result display
components) in future iterations. However this allows us to excercise the
underlying benchmark generation infrastructure in the mean time.
"""
# mypy: ignore-errors
import argparse
import sys
from typing import List

View File

@ -15,6 +15,7 @@ The life of a worker is very simple:
Because this file only expects to run in a child context, error handling means
plumbing failures up to the caller, not raising in this process.
"""
import argparse
import dataclasses
import io

View File

@ -48,14 +48,20 @@ class LSTMBenchmark(op_bench.TorchBenchmarkBase):
)[0]
x = torch.randn(
sequence_len, batch_size, I # sequence length # batch size
) # Number of features in X
sequence_len, # sequence length
batch_size, # batch size
I, # Number of features in X
)
h = torch.randn(
NL * (D + 1), batch_size, H # layer_num * dir_num # batch size
) # hidden size
NL * (D + 1), # layer_num * dir_num
batch_size, # batch size
H, # hidden size
)
c = torch.randn(
NL * (D + 1), batch_size, H # layer_num * dir_num # batch size
) # hidden size
NL * (D + 1), # layer_num * dir_num
batch_size, # batch size
H, # hidden size
)
self.inputs = {"x": x, "h": h, "c": c}
self.set_module_name("QLSTM")

View File

@ -152,8 +152,8 @@ def run(
result_entry["sequence_length"] = sequence_length
result_entry["n_heads"] = num_heads
result_entry["embed_dim"] = embed_dim
result_entry["time_native_mha_slow(\u00B5s)"] = f"{time_native_mha_slow:.3f}"
result_entry["time_native_mha_fast (\u00B5s)"] = f"{time_native_mha_fast:.3f}"
result_entry["time_native_mha_slow(\u00b5s)"] = f"{time_native_mha_slow:.3f}"
result_entry["time_native_mha_fast (\u00b5s)"] = f"{time_native_mha_fast:.3f}"
result_entry["speedup flash_mha v native_mha"] = f"{speedup_fast_internal:.3f}"
result_entry["padding"] = f"{padding:.3f}"
return result_entry

View File

@ -82,10 +82,10 @@ class ExperimentResults:
@classmethod
def get_entry_names(cls) -> List[str]:
return [
"nn_mha_time (\u00B5s)",
"compiled_nn_mha_time (\u00B5s)",
"composite_mha_time (\u00B5s)",
"compiled_composite_mha_time (\u00B5s)",
"nn_mha_time (\u00b5s)",
"compiled_nn_mha_time (\u00b5s)",
"composite_mha_time (\u00b5s)",
"compiled_composite_mha_time (\u00b5s)",
]

View File

@ -32,8 +32,7 @@ class Dim:
if self._vmap_level is not None:
_vmap_active_levels[self._vmap_stack].alive = False # noqa: F821
while (
not _vmap_levels[-1].alive
and current_level() == _vmap_levels[-1].level # noqa: F821
not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level # noqa: F821
):
_vmap_decrement_nesting() # noqa: F821
_vmap_levels.pop()

View File

@ -22,6 +22,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import annotations
import keyword
@ -283,16 +284,16 @@ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
str: the comma-separated string
Examples:
>>> comma_separate(('d0',))
>>> comma_separate(("d0",))
'd0'
>>> comma_separate(('d0', 'd1', 'd2', 'd3'))
>>> comma_separate(("d0", "d1", "d2", "d3"))
'd0, d1, d2, d3'
>>> comma_separate([('d1', 'd4')])
>>> comma_separate([("d1", "d4")])
'(d1, d4)'
>>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
>>> comma_separate([("d0",), (), ("d1",), ("d2",), ("d3", "d4")])
'(d0,), (), (d1,), (d2,), (d3, d4)'
"""
return ", ".join(

View File

@ -95,7 +95,7 @@ def _create_rearrange_callable(
raise ValueError(f"Unexpected dimension: {dimension}")
def composition_to_dims(
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
) -> List[Union[str, Tuple[str, ...]]]:
"""Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
class dims."""
@ -171,31 +171,31 @@ def rearrange(
>>> images = torch.randn((32, 30, 40, 3))
>>> # stack along first (batch) axis, output is a single array
>>> rearrange(images, 'b h w c -> b h w c').shape
>>> rearrange(images, "b h w c -> b h w c").shape
torch.Size([32, 30, 40, 3])
>>> # concatenate images along height (vertical axis), 960 = 32 * 30
>>> rearrange(images, 'b h w c -> (b h) w c').shape
>>> rearrange(images, "b h w c -> (b h) w c").shape
torch.Size([960, 40, 3])
>>> # concatenated images along horizontal axis, 1280 = 32 * 40
>>> rearrange(images, 'b h w c -> h (b w) c').shape
>>> rearrange(images, "b h w c -> h (b w) c").shape
torch.Size([30, 1280, 3])
>>> # reordered axes to "b c h w" format for deep learning
>>> rearrange(images, 'b h w c -> b c h w').shape
>>> rearrange(images, "b h w c -> b c h w").shape
torch.Size([32, 3, 30, 40])
>>> # flattened each image into a vector, 3600 = 30 * 40 * 3
>>> rearrange(images, 'b h w c -> b (c h w)').shape
>>> rearrange(images, "b h w c -> b (c h w)").shape
torch.Size([32, 3600])
>>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
>>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
>>> rearrange(images, "b (h1 h) (w1 w) c -> (b h1 w1) h w c", h1=2, w1=2).shape
torch.Size([128, 15, 20, 3])
>>> # space-to-depth operation
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
>>> rearrange(images, "b (h h1) (w w1) c -> b h w (c h1 w1)", h1=2, w1=2).shape
torch.Size([32, 15, 20, 12])
"""
if not isinstance(tensor, torch.Tensor):

View File

@ -152,7 +152,7 @@ def train(db, net, device, meta_opt, epoch, log):
spt_logits = fnet(new_params, buffers, x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
grads = torch.autograd.grad(spt_loss, new_params, create_graph=True)
new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]
new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
@ -215,7 +215,7 @@ def test(db, net, device, epoch, log):
spt_logits = fnet(new_params, buffers, x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
grads = torch.autograd.grad(spt_loss, new_params)
new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]
new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
# The query loss and acc induced by these parameters.
qry_logits = fnet(new_params, buffers, x_qry[i]).detach()

View File

@ -132,7 +132,7 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
new_params = params
for _ in range(n_inner_iter):
grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
new_params = {k: new_params[k] - g * 1e-1 for k, g, in grads.items()}
new_params = {k: new_params[k] - g * 1e-1 for k, g in grads.items()}
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
@ -216,7 +216,7 @@ def test(db, net, device, epoch, log):
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
grads = torch.autograd.grad(spt_loss, new_params.values())
new_params = {
k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)
k: new_params[k] - g * 1e-1 for k, g in zip(new_params, grads)
}
# The query loss and acc induced by these parameters.

View File

@ -169,9 +169,7 @@ class OmniglotNShot:
),
)
temp = (
{}
) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
temp = {} # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
for img, label in self.x:
if label in temp.keys():
temp[label].append(img)

View File

@ -16,6 +16,7 @@ for-loops and speeding them up through vectorization.
Let's demonstrate how to do this using an ensemble of simple CNNs.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

View File

@ -8,6 +8,7 @@ deep learning models. It is difficult (or annoying) to compute these quantities
efficiently using a standard autodiff system like PyTorch Autograd; functorch
provides ways of computing various higher-order autodiff quantities efficiently.
"""
from functools import partial
import torch

View File

@ -9,6 +9,7 @@ Per-sample-gradient computation is computing the gradient for each and every
sample in a batch of data. It is a useful quantity in differential privacy
and optimization research.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

View File

@ -9,7 +9,7 @@ from torchgen.utils import T
# Like tools.api.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(
func: Callable[[NFWDI], T]
func: Callable[[NFWDI], T],
) -> Callable[[NFWDI], T]:
@functools.wraps(func)
def wrapper(f: NFWDI) -> T:
@ -21,7 +21,7 @@ def with_native_function_with_differentiability_info(
# Like the above but with an additional dispatch key string argument
def with_native_function_with_differentiability_info_and_key(
func: Callable[[NFWDI, str], T]
func: Callable[[NFWDI, str], T],
) -> Callable[[NFWDI, str], T]:
@functools.wraps(func)
def wrapper(f: NFWDI, key: str) -> T:

View File

@ -70,9 +70,9 @@ def gen_autograd(
),
key=lambda f: cpp.name(f.func),
)
fns_with_diff_infos: list[
NativeFunctionWithDifferentiabilityInfo
] = match_differentiability_info(fns, differentiability_infos)
fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo] = (
match_differentiability_info(fns, differentiability_infos)
)
# Generate VariableType.h/cpp
if not disable_autograd:

View File

@ -447,7 +447,7 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def get_infos_with_derivatives_list(
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[DifferentiabilityInfo]:
diff_info_list = [
info

View File

@ -590,8 +590,7 @@ def inplace_or_view_method_definition(
# For functions that modify their inputs but don't return them,
# we can't give them autograd support.
# See https://github.com/pytorch/pytorch/issues/53796
not modifies_arguments(f)
or len(f.func.returns) == 0
not modifies_arguments(f) or len(f.func.returns) == 0
):
return None
return METHOD_DEFINITION.substitute(

View File

@ -386,9 +386,9 @@ def group_filter_overloads(
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
grouped: dict[
BaseOperatorName, list[PythonSignatureNativeFunctionPair]
] = defaultdict(list)
grouped: dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] = (
defaultdict(list)
)
for pair in pairs:
if pred(pair.function):
grouped[pair.function.func.name.name].append(pair)
@ -522,12 +522,12 @@ def create_python_bindings_sharded(
grouped = group_filter_overloads(pairs, pred)
def key_func(
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]],
) -> str:
return kv[0].base
def env_func(
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]],
) -> dict[str, list[str]]:
name, fn_pairs = kv
return {
@ -679,9 +679,7 @@ def load_deprecated_signatures(
function=pair.function,
)
)
assert (
any_schema_found
), f"No native function with name {aten_name} matched signature:\n {str(schema)}"
assert any_schema_found, f"No native function with name {aten_name} matched signature:\n {str(schema)}"
return results

View File

@ -128,9 +128,9 @@ def load_derivatives(
# function schema is the complete declaration including mutability annotation / default value and etc.
# signature is the canonical schema for a group of functions (in-place/out/functional variants)
# that are semantically related.
functions_by_signature: dict[
FunctionSchema, list[NativeFunction]
] = defaultdict(list)
functions_by_signature: dict[FunctionSchema, list[NativeFunction]] = (
defaultdict(list)
)
functions_by_schema: dict[str, NativeFunction] = {}
for function in native_functions:
functions_by_signature[function.func.signature()].append(function)
@ -991,7 +991,7 @@ def _create_op_prefix(name: str) -> str:
OP names correspond to classes, hence the change to title case.
Example::
>>> _create_op_prefix('add')
>>> _create_op_prefix("add")
'AddBackward'
"""
camel_case = "".join([p.title() for p in name.split("_")])

View File

@ -112,8 +112,8 @@ def build_groups_memberships(
assert (
_groups[pg_guid].desc == desc
), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}"
assert _memberships[pg_guid] == set(
ranks
assert (
_memberships[pg_guid] == set(ranks)
), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}"
return groups, _groups, memberships, _memberships, _pg_guids

View File

@ -278,7 +278,7 @@ def get_version_detail(version: str) -> Tuple[int, int]:
def align_trace_from_beginning(
entries: Dict[int, List[Dict[str, Any]]]
entries: Dict[int, List[Dict[str, Any]]],
) -> Dict[int, List[Dict[str, Any]]]:
"""
Align the trace entries by record ID for entries.

View File

@ -115,7 +115,8 @@ RESULTS_RE: re.Pattern[str] = re.compile(
def _test_results_re() -> None:
"""
>>> def t(s): return RESULTS_RE.search(s).groupdict()
>>> def t(s):
... return RESULTS_RE.search(s).groupdict()
>>> t(r"file.py:80:1: E302 expected 2 blank lines, found 1")
... # doctest: +NORMALIZE_WHITESPACE

View File

@ -31,17 +31,11 @@ USE_BLACK_FILELIST = re.compile(
[
# **
# .ci/**
".ci/**",
# .github/**
".github/**",
# benchmarks/**
"benchmarks/**",
# functorch/**
"functorch/**",
# tools/**
"tools/**",
# torchgen/**
"torchgen/**",
# test/**
# test/[a-h]*/**
"test/[a-h]*/**",

View File

@ -383,9 +383,9 @@ TORCH_API bool kernel_1();
class TestNativeFunctionGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.native_functions: list[NativeFunction] = []
self.backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
self.backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
defaultdict(dict)
)
yaml_entry = """
- func: op(Tensor self) -> Tensor
dispatch:
@ -442,9 +442,9 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
# Test for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
self.backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
defaultdict(dict)
)
yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:

View File

@ -474,7 +474,8 @@ class TestCalculateShards(unittest.TestCase):
else:
# x.time is not None because of the above check
self.assertAlmostEqual(
random_times[test], sum(x.time for x in sharded_tests) # type: ignore[misc]
random_times[test],
sum(x.time for x in sharded_tests), # type: ignore[misc]
)
self.assertListEqual(
list(range(sharded_tests[0].num_shards)),

View File

@ -52,13 +52,11 @@ class TestPrioritizations:
files[test.test_file] |= test
for test in files.values():
assert (
test.is_full_file()
), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"
assert test.is_full_file(), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that" # noqa: B950
# Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in
assert self._original_tests == set(
files.keys()
assert (
self._original_tests == set(files.keys())
), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in"
def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]:
@ -279,9 +277,9 @@ class AggregatedHeuristics:
stats["heuristics"] = heuristics
stats[
"aggregated"
] = self.get_aggregated_priorities().get_priority_info_for_test(test)
stats["aggregated"] = (
self.get_aggregated_priorities().get_priority_info_for_test(test)
)
stats["aggregated_trial"] = self.get_aggregated_priorities(
include_trial=True

View File

@ -68,12 +68,10 @@ class BenchmarkRunner:
self.main(args.num_samples, args.num_reps)
@abstractmethod
def run_benchmark(self, *args: Any) -> None:
...
def run_benchmark(self, *args: Any) -> None: ...
@abstractmethod
def create_input(self) -> Tuple[Any, ...]:
...
def create_input(self) -> Tuple[Any, ...]: ...
def main(self, num_samples: int, num_reps: int) -> None:
for _ in tqdm(range(num_samples)):

View File

@ -449,8 +449,8 @@ class AHTrainDecisionTree(AHTrain):
for row in group.itertuples():
choice2time[row.choice] = row.median_execution_time
assert len(unique_choices) == len(
group
assert (
len(unique_choices) == len(group)
), f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}"
return pd.Series(

View File

@ -253,9 +253,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
elif t.name == BaseTy.Scalar:
return BaseCType(scalarT)
elif isinstance(t, ListType):
assert (
not mutable
), "Native functions should never return a mutable tensor list. They should return void."
assert not mutable, "Native functions should never return a mutable tensor list. They should return void."
elem = returntype_type(t.elem, mutable=False)
assert t.size is None, f"fixed size list returns not supported: {t}"
return VectorCType(elem)

View File

@ -378,7 +378,8 @@ class LazyIrSchema:
self.generator_arg is None
), "We expect there is only one generator arg"
self.generator_arg = NamedCType(
arg.name, arg.type # type:ignore[arg-type]
arg.name,
arg.type, # type:ignore[arg-type]
)
keyword_args.extend(
LazyArgument(arg, self.properties, symint=symint)

View File

@ -551,9 +551,9 @@ class PythonSignatureGroup:
# Out overloads in C++ don't have TensorOptions arguments,
# so take these from the functional variant
signature_kwargs[
"tensor_options_args"
] = functional.signature.tensor_options_args
signature_kwargs["tensor_options_args"] = (
functional.signature.tensor_options_args
)
return PythonSignatureGroup(
signature=type(out.signature)(**signature_kwargs),

View File

@ -164,42 +164,42 @@ def translate(
and isinstance(t.elem.elem, BaseCType)
and str(t.elem.elem.type) == "at::Tensor"
):
ctx[
NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = (
f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
)
if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
ctx[
NamedCType(t.name, BaseCType(optionalTensorRefT))
] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = (
f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
)
if t.type == ConstRefCType(BaseCType(scalarT)):
ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"
if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
ctx[
NamedCType(t.name, BaseCType(optionalScalarRefT))
] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = (
f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
)
if t.type == BaseCType(scalar_t):
ctx[
NamedCType(t.name, BaseCType(opmath_t))
] = f"static_cast<opmath_t>({b.expr})"
ctx[NamedCType(t.name, BaseCType(opmath_t))] = (
f"static_cast<opmath_t>({b.expr})"
)
# [Note: IOptTensorListRef]
if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
ctx[
NamedCType(t.name, BaseCType(iOptTensorListRefT))
] = f"at::IOptTensorListRef({b.expr})"
ctx[NamedCType(t.name, BaseCType(iOptTensorListRefT))] = (
f"at::IOptTensorListRef({b.expr})"
)
# Add implicit bindings if the generated code is inside a Tensor method
if method:
ctx[
NamedCType("self", MutRefCType(BaseCType(tensorT)))
] = "const_cast<Tensor&>(*this)"
ctx[
NamedCType("self", ConstRefCType(BaseCType(tensorT)))
] = "const_cast<Tensor&>(*this)"
ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = (
"const_cast<Tensor&>(*this)"
)
ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = (
"const_cast<Tensor&>(*this)"
)
# This is better! Byte-for-byte compat
# ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

View File

@ -406,9 +406,7 @@ def kernel_signature(
meta = backend_index.get_kernel(f)
symint = meta is not None and meta.supports_symint()
if symint:
assert (
f.func.has_symint()
), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
assert f.func.has_symint(), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
if backend_index.external:
return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
else:

View File

@ -194,9 +194,7 @@ if ({arg_name}_opt.has_value()) {{
}} else {{
{out_name} = {ctype.cpp_type(strip_ref=True)}();
}}
""".split(
"\n"
),
""".split("\n"),
decl,
)
@ -213,9 +211,7 @@ def _gen_code_list_type(
code.extend(
f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name});
""".split(
"\n"
)
""".split("\n")
)
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
elif isinstance(t.elem, OptionalType):
@ -226,9 +222,7 @@ for (c10::IValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{out_name}.push_back({res_name});
}}
""".split(
"\n"
)
""".split("\n")
)
else:
# use ArrayRef as default.
@ -242,8 +236,6 @@ for (c10::IValue {elem_name}: {in_name}) {{
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split(
"\n"
)
""".split("\n")
)
return code, decl

View File

@ -95,7 +95,7 @@ def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T
def method_with_nested_native_function(
func: Callable[[S, F3], T]
func: Callable[[S, F3], T],
) -> Callable[[S, F3], T]:
@functools.wraps(func)
def wrapper(slf: S, f: F3) -> T:
@ -108,7 +108,7 @@ def method_with_nested_native_function(
# Convenience decorator for functions that explicitly take in a BackendIndex,
# instead of indirectly taking one in as a closure
def with_native_function_and_index(
func: Callable[[F, BackendIndex], T]
func: Callable[[F, BackendIndex], T],
) -> Callable[[F, BackendIndex], T]:
@functools.wraps(func)
def wrapper(f: F, backend_index: BackendIndex) -> T:
@ -120,7 +120,7 @@ def with_native_function_and_index(
# Convenience decorator for functions that explicitly take in a Dict of BackendIndices
def with_native_function_and_indices(
func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
func: Callable[[F, dict[DispatchKey, BackendIndex]], T],
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
@functools.wraps(func)
def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:

View File

@ -184,9 +184,7 @@ def returntype_type(t: Type, *, mutable: bool) -> CType:
elif t.name == BaseTy.Scalar:
return BaseCType(scalarT)
elif isinstance(t, ListType):
assert (
not mutable
), "Native functions should never return a mutable tensor list. They should return void."
assert not mutable, "Native functions should never return a mutable tensor list. They should return void."
elem = returntype_type(t.elem, mutable=False)
assert t.size is None, f"fixed size list returns not supported: {t}"
return VectorCType(elem)

View File

@ -127,9 +127,7 @@ class Unboxing:
return (
f"""
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
""".split(
"\n"
),
""".split("\n"),
decl,
)
@ -147,9 +145,7 @@ class Unboxing:
code.extend(
f"""
auto {out_name} = {arg_name}.toTensorList();
""".split(
"\n"
)
""".split("\n")
)
elif isinstance(t.elem, BaseType) and (
t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
@ -157,17 +153,13 @@ class Unboxing:
code.extend(
f"""
auto {out_name} = {arg_name}.toIntList();
""".split(
"\n"
)
""".split("\n")
)
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
code.extend(
f"""
auto {out_name} = {arg_name}.toDoubleList();
""".split(
"\n"
)
""".split("\n")
)
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
# handle list type with size, e.g., bool[4]
@ -183,9 +175,7 @@ for (auto {elem_name}: {in_name}) {{
#else
auto {out_name} = {arg_name}.toBoolList();
#endif
""".split(
"\n"
)
""".split("\n")
)
# pytorch codegen:
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
@ -205,9 +195,7 @@ for (auto {elem_name}: {in_name}) {{
#else
auto {out_name} = {arg_name}.toListOptionalTensor();
#endif
""".split(
"\n"
)
""".split("\n")
)
else:
# use ArrayRef as default.
@ -223,8 +211,6 @@ auto {out_name} = {arg_name}.toListOptionalTensor();
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split(
"\n"
)
""".split("\n")
)
return code, decl

View File

@ -96,7 +96,7 @@ class ETKernelKey:
)
assert (
dim_order in dim_order_alias_map
), "Undefined dim_order alias: " + str(dim_order)
), f"Undefined dim_order alias: {dim_order}"
dtype_alias_used.add(type_alias)
# Generate all permutations of dtype alias values
@ -172,11 +172,11 @@ class ETKernelIndex:
@staticmethod
def from_backend_indices(
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> ETKernelIndex:
kernel_index: dict[
OperatorName, dict[ETKernelKey, BackendMetadata]
] = defaultdict(dict)
kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = (
defaultdict(dict)
)
ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
return ETKernelIndex(kernel_index)

View File

@ -1362,7 +1362,7 @@ def get_grouped_by_view_native_functions(
native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
def maybe_create_view_group(
d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
d: dict[ViewSchemaKind | SchemaKind, NativeFunction],
) -> list[NativeFunction | NativeFunctionsViewGroup]:
funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
if ViewSchemaKind.aliasing in d:
@ -1409,7 +1409,7 @@ def get_grouped_native_functions(
native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
def flatten_pre_group(
d: dict[SchemaKind, NativeFunction]
d: dict[SchemaKind, NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
r = NativeFunctionsGroup.from_dict(d)
if r is None:
@ -1476,9 +1476,7 @@ def get_native_function_declarations_from_ns_grouped_kernels(
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
""".split(
newline
)
""".split(newline)
)
return declarations
@ -1671,9 +1669,7 @@ def get_namespaced_declaration(
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
""".split(
newline
)
""".split(newline)
)
return declarations
@ -2386,9 +2382,7 @@ def gen_source_files(
os.path.join(aoti_fm.install_dir, header_file_name)
) as old_file:
old_header = old_file.read()
assert (
old_header == new_header
), """
assert old_header == new_header, """
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!

View File

@ -71,7 +71,10 @@ base_type_to_callsite_expr = {
# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]: # type: ignore[return]
def convert_arg_type_and_name( # type: ignore[return]
typ: Type,
name: str,
) -> tuple[list[str], list[str], list[str], list[str]]:
if isinstance(typ, BaseType):
if typ.name in base_type_to_c_type:
return (

View File

@ -295,7 +295,7 @@ def gen_unboxing(
) -> None:
# Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
def key_func(
item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
) -> str:
return item[0].root_name + ":" + item[1][0].to_native_string()
@ -739,7 +739,7 @@ def parse_yaml(
# (2) Return BackendIndices if kernel index is absent
def map_index(
m: dict[OperatorName, BackendMetadata]
m: dict[OperatorName, BackendMetadata],
) -> dict[OperatorName, BackendMetadata]:
return {op: m[op] for op in m if op in op_names}

View File

@ -278,13 +278,13 @@ def assert_view_op_properties(func: FunctionSchema) -> None:
args = func.arguments.flat_non_out
# The first argument is a tensor with an alias semantics (annotations)
assert len(args) > 0 and args[0].type == BaseType(
BaseTy.Tensor
assert (
len(args) > 0 and args[0].type == BaseType(BaseTy.Tensor)
), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
# No other arguments have aliasing semantics
assert is_alias(args[0]) and not any(
is_alias(a) for a in args[1:]
assert (
is_alias(args[0]) and not any(is_alias(a) for a in args[1:])
), """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint"""

View File

@ -176,9 +176,9 @@ class default_args:
tensor_class: str = "torch::lazy::LazyTensor"
tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
lazy_ir_generator: type[GenLazyIR] = GenLazyIR
native_func_definition_generator: type[
native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
GenLazyNativeFuncDefinition
] = GenLazyNativeFuncDefinition
)
backend_name: str = "TorchScript"
@ -257,9 +257,9 @@ def main() -> None:
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
if options.gen_ts_lowerings:
lazy_ir_generator = GenTSLazyIR
native_func_definition_generator: type[
GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator
native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
default_args.native_func_definition_generator
)
run_gen_lazy_tensor(
aten_path,

View File

@ -1484,14 +1484,15 @@ class FunctionSchema:
else:
# mutable keyword arguments whose name has _scratch_ prefix are
# scratch tensors for memory planning and should not be returned
assert len(
[
arg
for arg in self.arguments.out
if not arg.name.startswith("_scratch_")
]
) == len(
self.returns
assert (
len(
[
arg
for arg in self.arguments.out
if not arg.name.startswith("_scratch_")
]
)
== len(self.returns)
), "Must return as many arguments as there are out arguments, or no return at all"
if self.name.name.inplace:
@ -1590,9 +1591,7 @@ class FunctionSchema:
), "invariant: all scratch operators are expected to be out= operators too"
return SchemaKind.scratch
elif is_out:
assert (
not is_scratch
), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
assert not is_scratch, "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!" # noqa: B950
return SchemaKind.out
elif is_mutable:
return SchemaKind.mutable
@ -2701,9 +2700,7 @@ class NativeFunctionsViewGroup:
)
if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
if self.view_inplace is not None:
assert (
self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
), (
assert self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel, (
f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
" both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
)

View File

@ -263,7 +263,7 @@ def construct_register_size(register_size_from_yaml: int) -> str:
def construct_version_maps(
upgrader_bytecode_function_to_index_map: dict[str, Any]
upgrader_bytecode_function_to_index_map: dict[str, Any],
) -> str:
version_map = torch._C._get_operator_version_map()
sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return]
@ -305,7 +305,7 @@ def construct_version_maps(
def get_upgrader_bytecode_function_to_index_map(
upgrader_dict: list[dict[str, Any]]
upgrader_dict: list[dict[str, Any]],
) -> dict[str, Any]:
upgrader_bytecode_function_to_index_map = {}
index = 0

View File

@ -366,9 +366,9 @@ def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> N
arg_map["out_int32"] = "false"
else:
arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)"
arg_map[
"col_indices"
] = "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
arg_map["col_indices"] = (
"torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
)
arg_map["out_int32"] = "false"
return
if op_name == "_convert_indices_from_coo_to_csr":