mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

[BE] Format .ci/ / .github/ / benchmarks/ / functorch/ / tools/ / torchgen/ with ruff format (#132577)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132577
Approved by: https://github.com/malfet
Committed by PyTorch MergeBot
parent 04adb74d08
commit 267f82b860
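The diff below is one mechanical reformatting pass over these directories (plausibly something like `ruff format .ci .github benchmarks functorch tools torchgen`; the exact invocation is an assumption, not recorded in the commit). The recurring patterns are easiest to see on a standalone snippet. A minimal sketch of the first one, joining a binary expression that the previous formatter had split but that fits on one line (hypothetical variable names):

    from datetime import datetime, timedelta, timezone

    # Before the reformat: the expression was split across two lines.
    expiry = (
        datetime.now(timezone.utc)
        + timedelta(days=10)
    )

    # After: ruff format joins it, since the joined line fits the column limit.
    expiry = datetime.now(timezone.utc) + timedelta(days=10)
    print(expiry)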
@@ -45,8 +45,7 @@ def create_cert(path, C, ST, L, O, key):
         .not_valid_before(datetime.now(timezone.utc))
         .not_valid_after(
             # Our certificate will be valid for 10 days
-            datetime.now(timezone.utc)
-            + timedelta(days=10)
+            datetime.now(timezone.utc) + timedelta(days=10)
         )
         .add_extension(
             x509.BasicConstraints(ca=True, path_length=None),
@@ -91,8 +90,7 @@ def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key):
         .not_valid_before(datetime.now(timezone.utc))
         .not_valid_after(
             # Our certificate will be valid for 10 days
-            datetime.now(timezone.utc)
-            + timedelta(days=10)
+            datetime.now(timezone.utc) + timedelta(days=10)
             # Sign our certificate with our private key
         )
         .sign(private_ca_key, hashes.SHA256())
@@ -409,7 +409,7 @@ def generate_wheels_matrix(
                     "container_image": WHEEL_CONTAINER_IMAGES[arch_version],
                     "package_type": package_type,
                     "pytorch_extra_install_requirements": (
-                        PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version]  # fmt: skip
+                        PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version]
                         if os != "linux-aarch64"
                         else ""
                     ),
@@ -457,7 +457,7 @@ def generate_wheels_matrix(
                         ".", "_"
                     ),
                     "pytorch_extra_install_requirements": (
-                        PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.1"]  # fmt: skip
+                        PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.1"]
                         if os != "linux" and gpu_arch_type != "xpu"
                        else ""
                     ),
.github/scripts/trymerge.py (vendored, 2 lines changed)
@@ -1506,7 +1506,7 @@ def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:


 def checks_to_markdown_bullets(
-    checks: List[Tuple[str, Optional[str], Optional[int]]]
+    checks: List[Tuple[str, Optional[str], Optional[int]]],
 ) -> List[str]:
     return [
         f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
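The one-character change above recurs throughout this commit: when a signature is already split one-parameter-per-line, ruff format appends a "magic trailing comma" so the list stays exploded on later runs. A small self-contained sketch (hypothetical function name):

    from typing import List, Optional, Tuple


    def to_bullets(
        checks: List[Tuple[str, Optional[str]]],  # trailing comma keeps this expanded
    ) -> List[str]:
        return [f"- {name}" for name, url in checks]


    print(to_bullets([("lint", None), ("build", "https://example.com")]))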
@@ -51,9 +51,7 @@ def main():
     print()
     print(f"{'':>10s}", end="")  # noqa: E999
     for _ in [75, 95]:
-        print(
-            f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end=""
-        )  # noqa: E999
+        print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="")  # noqa: E999
     print()

     # Print measurements
@@ -209,9 +209,8 @@ def main():
         x_axis_variables
     ):  # run benchmark for every x axis variable
         if len(x_axis_variables) > 1:
-            args[
-                args["x_axis_name"]
-            ] = x_axis_variable  # set x axis variable for this benchmark iteration
+            # set x axis variable for this benchmark iteration
+            args[args["x_axis_name"]] = x_axis_variable
         processes = []
         start_time = time.time()
         for rank in range(args["world_size"]):
@@ -1391,9 +1391,7 @@ class AOTInductorModelCache:
                 strict=False,
             ).module()
             with torch.no_grad():
-                so_path = torch._inductor.aot_compile(
-                    gm, example_args, example_kwargs
-                )  # type: ignore[arg-type]
+                so_path = torch._inductor.aot_compile(gm, example_args, example_kwargs)  # type: ignore[arg-type]

             cls.cache[key] = torch._export.aot_load(so_path, device)

@@ -1559,12 +1557,10 @@ class OnnxModel(abc.ABC):
         return model_path

     @abc.abstractmethod
-    def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]:
-        ...
+    def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]: ...

     @abc.abstractmethod
-    def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]:
-        ...
+    def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]: ...

     def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, npt.NDArray]:
         pt_inputs = self.format_pt_inputs(pt_inputs)
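Here ruff format collapses an ellipsis-only body onto the `def` line, the newer "stub" style for abstract methods and overloads. A runnable sketch with a hypothetical ABC:

    import abc
    from typing import Any, Sequence


    class Exporter(abc.ABC):
        @abc.abstractmethod
        def format_inputs(self, inputs: Any) -> Sequence[int]: ...  # collapsed stub


    class Identity(Exporter):
        def format_inputs(self, inputs: Any) -> Sequence[int]:
            return list(inputs)


    print(Identity().format_inputs((1, 2, 3)))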
@@ -3134,9 +3130,9 @@ class BenchmarkRunner:
             experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
             experiment_kwargs["dynamo_stats"] = dynamo_stats
             if self.args.profile_dynamo_cache_lookup:
-                experiment_kwargs[
-                    "cache_lookup_latency"
-                ] = dynamo_cache_lookup_latency
+                experiment_kwargs["cache_lookup_latency"] = (
+                    dynamo_cache_lookup_latency
+                )

         if experiment.func is speedup_experiment_onnx:
             experiment = functools.partial(
@@ -3290,9 +3286,9 @@ class BenchmarkRunner:
             experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
             experiment_kwargs["dynamo_stats"] = dynamo_stats
             if self.args.profile_dynamo_cache_lookup:
-                experiment_kwargs[
-                    "cache_lookup_latency"
-                ] = dynamo_cache_lookup_latency
+                experiment_kwargs["cache_lookup_latency"] = (
+                    dynamo_cache_lookup_latency
+                )

         if experiment.func is coverage_experiment:
             ok, total = Stats.reset_counters()
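Both hunks above show the same rule: rather than splitting the subscript target across lines, ruff format parenthesizes the right-hand side of an over-long assignment. A sketch with hypothetical names (for a value this short the parentheses would normally be dropped again; they only appear when the joined line exceeds the limit):

    experiment_kwargs = {}
    dynamo_cache_lookup_latency = 0.000123

    # Over-long assignments break at the right-hand side, not inside the target.
    experiment_kwargs["cache_lookup_latency"] = (
        dynamo_cache_lookup_latency
    )
    print(experiment_kwargs)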
@@ -4324,7 +4320,14 @@ def run(runner, args, original_dir=None):
         runner.skip_models.clear()

     experiment = null_experiment
-    global current_name, current_device, current_batch_size, output_filename, disable_output, optimize_ctx, current_onnx_compiler
+    global \
+        current_name, \
+        current_device, \
+        current_batch_size, \
+        output_filename, \
+        disable_output, \
+        optimize_ctx, \
+        current_onnx_compiler
     optimize_ctx = contextlib.nullcontext()

     if args.disable_output:
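A `global` statement has no parenthesized form, so for a name list past the line limit ruff format falls back to backslash continuations, one name per line. A minimal runnable sketch:

    current_name = current_device = current_batch_size = None


    def configure() -> None:
        global \
            current_name, \
            current_device, \
            current_batch_size
        current_name, current_device, current_batch_size = "demo", "cpu", 8


    configure()
    print(current_name, current_device, current_batch_size)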
@@ -2,6 +2,7 @@
 A tool to merge multiple csv files (generated by torchbench.py/etc) into a single csv file.
 Performs an outer join based on the benchmark name, filling in any missing data with zeros.
 """
+
 import argparse
 import functools
 import operator
@@ -4,6 +4,7 @@ To generate output that can be fed into this script set the env var TORCHINDUCTOR

 That file can be fed into this script to generate the minimizes total, weighted matmul time as a function of allowed templates.
 """
+
 import json

 import click
@@ -214,8 +214,7 @@ def bench(rnn_runners, group_name, print_json=False, sep=" ", **params):
                 k: {"avg": v.avg_fwd, "std": v.std_fwd, "info": v.info_fwd}
                 for k, v in results.items()
             },
-            group_name
-            + "-backward": {
+            f"{group_name}-backward": {
                 k: {"avg": v.avg_bwd, "std": v.std_bwd, "info": v.info_bwd}
                 for k, v in results.items()
             },
@@ -184,7 +184,5 @@ class ConditionalFeedForwardInt8(nn.Module):
         ].to(x.dtype)
         expert_outs = torch.einsum(
             "tao, taio -> tai", (x1 * x3), w2_weights
-        ) * self.scales2[expert_indices].to(
-            x.dtype
-        )  # [T, A, D, D]
+        ) * self.scales2[expert_indices].to(x.dtype)  # [T, A, D, D]
         return expert_outs
@@ -1,5 +1,7 @@
 """Collect instruction counts for continuous integration."""
+
 # mypy: ignore-errors
+
 import argparse
 import hashlib
 import json
@@ -1,5 +1,7 @@
 """Key enums and structs used to handle data flow within the benchmark."""
+
 # mypy: ignore-errors
+
 import dataclasses
 import enum
 import itertools as it
@@ -2,7 +2,9 @@

 This is mostly string manipulation, with just a bit of importlib magic.
 """
+
 # mypy: ignore-errors
+
 import importlib.abc
 import importlib.util
 import itertools as it
@@ -1,5 +1,7 @@
 """Type annotations for various benchmark objects."""
+
 # mypy: ignore-errors
+
 from typing import Any, Dict, Optional, Tuple, Union

 from core.api import AutoLabels, GroupedBenchmark, TimerArgs
@@ -1,5 +1,7 @@
 """Define some common setup blocks which benchmarks can reuse."""
+
 # mypy: ignore-errors
+
 import enum

 from core.api import GroupedSetup
@@ -1,5 +1,7 @@
 """Run benchmarks while handling parallelism, isolation, and fault tolerance."""
+
 # mypy: ignore-errors
+
 import math
 import multiprocessing
 import subprocess
@@ -1,5 +1,7 @@
 """Handle the details of subprocess calls and retries for a given benchmark run."""
+
 # mypy: ignore-errors
+
 import dataclasses
 import json
 import os
@@ -5,7 +5,9 @@ expressive and robust components (e.g. better runner and result display
 components) in future iterations. However this allows us to excercise the
 underlying benchmark generation infrastructure in the mean time.
 """
+
 # mypy: ignore-errors
+
 import argparse
 import sys
 from typing import List
@@ -15,6 +15,7 @@ The life of a worker is very simple:
 Because this file only expects to run in a child context, error handling means
 plumbing failures up to the caller, not raising in this process.
 """
+
 import argparse
 import dataclasses
 import io
@@ -48,14 +48,20 @@ class LSTMBenchmark(op_bench.TorchBenchmarkBase):
         )[0]

         x = torch.randn(
-            sequence_len, batch_size, I  # sequence length  # batch size
-        )  # Number of features in X
+            sequence_len,  # sequence length
+            batch_size,  # batch size
+            I,  # Number of features in X
+        )
         h = torch.randn(
-            NL * (D + 1), batch_size, H  # layer_num * dir_num  # batch size
-        )  # hidden size
+            NL * (D + 1),  # layer_num * dir_num
+            batch_size,  # batch size
+            H,  # hidden size
+        )
         c = torch.randn(
-            NL * (D + 1), batch_size, H  # layer_num * dir_num  # batch size
-        )  # hidden size
+            NL * (D + 1),  # layer_num * dir_num
+            batch_size,  # batch size
+            H,  # hidden size
+        )

         self.inputs = {"x": x, "h": h, "c": c}
         self.set_module_name("QLSTM")
@@ -152,8 +152,8 @@ def run(
     result_entry["sequence_length"] = sequence_length
     result_entry["n_heads"] = num_heads
     result_entry["embed_dim"] = embed_dim
-    result_entry["time_native_mha_slow(\u00B5s)"] = f"{time_native_mha_slow:.3f}"
-    result_entry["time_native_mha_fast (\u00B5s)"] = f"{time_native_mha_fast:.3f}"
+    result_entry["time_native_mha_slow(\u00b5s)"] = f"{time_native_mha_slow:.3f}"
+    result_entry["time_native_mha_fast (\u00b5s)"] = f"{time_native_mha_fast:.3f}"
     result_entry["speedup flash_mha v native_mha"] = f"{speedup_fast_internal:.3f}"
     result_entry["padding"] = f"{padding:.3f}"
     return result_entry
@@ -82,10 +82,10 @@ class ExperimentResults:
     @classmethod
     def get_entry_names(cls) -> List[str]:
         return [
-            "nn_mha_time (\u00B5s)",
-            "compiled_nn_mha_time (\u00B5s)",
-            "composite_mha_time (\u00B5s)",
-            "compiled_composite_mha_time (\u00B5s)",
+            "nn_mha_time (\u00b5s)",
+            "compiled_nn_mha_time (\u00b5s)",
+            "composite_mha_time (\u00b5s)",
+            "compiled_composite_mha_time (\u00b5s)",
         ]


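The two hunks above change only escape-sequence case: ruff format normalizes hex digits in string escapes to lowercase, and `\u00B5` and `\u00b5` denote the same micro sign. A one-liner to convince yourself:

    # Both spellings decode to the identical one-character string.
    assert "\u00B5" == "\u00b5" == chr(0xB5)
    print("time (\u00b5s)")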
@@ -32,8 +32,7 @@ class Dim:
         if self._vmap_level is not None:
             _vmap_active_levels[self._vmap_stack].alive = False  # noqa: F821
             while (
-                not _vmap_levels[-1].alive
-                and current_level() == _vmap_levels[-1].level  # noqa: F821
+                not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level  # noqa: F821
             ):
                 _vmap_decrement_nesting()  # noqa: F821
                 _vmap_levels.pop()
@@ -22,6 +22,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
+
 from __future__ import annotations

 import keyword
@@ -283,16 +284,16 @@ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
         str: the comma-separated string

     Examples:
-        >>> comma_separate(('d0',))
+        >>> comma_separate(("d0",))
         'd0'

-        >>> comma_separate(('d0', 'd1', 'd2', 'd3'))
+        >>> comma_separate(("d0", "d1", "d2", "d3"))
         'd0, d1, d2, d3'

-        >>> comma_separate([('d1', 'd4')])
+        >>> comma_separate([("d1", "d4")])
         '(d1, d4)'

-        >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
+        >>> comma_separate([("d0",), (), ("d1",), ("d2",), ("d3", "d4")])
         '(d0,), (), (d1,), (d2,), (d3, d4)'
     """
     return ", ".join(
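ruff format also formats code inside doctests, which is why the `>>>` lines switch from single to double quotes while the expected outputs (plain repr strings) are untouched. A sketch showing the doctest passes either way:

    def comma_join(parts):
        """Join strings with commas.

        >>> comma_join(("d0", "d1"))
        'd0, d1'
        """
        return ", ".join(parts)


    if __name__ == "__main__":
        import doctest

        print(doctest.testmod())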
@@ -95,7 +95,7 @@ def _create_rearrange_callable(
             raise ValueError(f"Unexpected dimension: {dimension}")

     def composition_to_dims(
-        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
+        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
     ) -> List[Union[str, Tuple[str, ...]]]:
         """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
         class dims."""
@@ -171,31 +171,31 @@ def rearrange(
         >>> images = torch.randn((32, 30, 40, 3))

         >>> # stack along first (batch) axis, output is a single array
-        >>> rearrange(images, 'b h w c -> b h w c').shape
+        >>> rearrange(images, "b h w c -> b h w c").shape
         torch.Size([32, 30, 40, 3])

         >>> # concatenate images along height (vertical axis), 960 = 32 * 30
-        >>> rearrange(images, 'b h w c -> (b h) w c').shape
+        >>> rearrange(images, "b h w c -> (b h) w c").shape
         torch.Size([960, 40, 3])

         >>> # concatenated images along horizontal axis, 1280 = 32 * 40
-        >>> rearrange(images, 'b h w c -> h (b w) c').shape
+        >>> rearrange(images, "b h w c -> h (b w) c").shape
         torch.Size([30, 1280, 3])

         >>> # reordered axes to "b c h w" format for deep learning
-        >>> rearrange(images, 'b h w c -> b c h w').shape
+        >>> rearrange(images, "b h w c -> b c h w").shape
         torch.Size([32, 3, 30, 40])

         >>> # flattened each image into a vector, 3600 = 30 * 40 * 3
-        >>> rearrange(images, 'b h w c -> b (c h w)').shape
+        >>> rearrange(images, "b h w c -> b (c h w)").shape
         torch.Size([32, 3600])

         >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
-        >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
+        >>> rearrange(images, "b (h1 h) (w1 w) c -> (b h1 w1) h w c", h1=2, w1=2).shape
         torch.Size([128, 15, 20, 3])

         >>> # space-to-depth operation
-        >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
+        >>> rearrange(images, "b (h h1) (w w1) c -> b h w (c h1 w1)", h1=2, w1=2).shape
         torch.Size([32, 15, 20, 12])
     """
     if not isinstance(tensor, torch.Tensor):
@@ -152,7 +152,7 @@ def train(db, net, device, meta_opt, epoch, log):
             spt_logits = fnet(new_params, buffers, x_spt[i])
             spt_loss = F.cross_entropy(spt_logits, y_spt[i])
             grads = torch.autograd.grad(spt_loss, new_params, create_graph=True)
-            new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]
+            new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

         # The final set of adapted parameters will induce some
         # final loss and accuracy on the query dataset.
@@ -215,7 +215,7 @@ def test(db, net, device, epoch, log):
             spt_logits = fnet(new_params, buffers, x_spt[i])
             spt_loss = F.cross_entropy(spt_logits, y_spt[i])
             grads = torch.autograd.grad(spt_loss, new_params)
-            new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]
+            new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

         # The query loss and acc induced by these parameters.
         qry_logits = fnet(new_params, buffers, x_qry[i]).detach()
@@ -132,7 +132,7 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
     new_params = params
     for _ in range(n_inner_iter):
         grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
-        new_params = {k: new_params[k] - g * 1e-1 for k, g, in grads.items()}
+        new_params = {k: new_params[k] - g * 1e-1 for k, g in grads.items()}

     # The final set of adapted parameters will induce some
     # final loss and accuracy on the query dataset.
@@ -216,7 +216,7 @@ def test(db, net, device, epoch, log):
             spt_loss = F.cross_entropy(spt_logits, y_spt[i])
             grads = torch.autograd.grad(spt_loss, new_params.values())
             new_params = {
-                k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)
+                k: new_params[k] - g * 1e-1 for k, g in zip(new_params, grads)
             }

         # The query loss and acc induced by these parameters.
@@ -169,9 +169,7 @@ class OmniglotNShot:
                 ),
             )

-            temp = (
-                {}
-            )  # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
+            temp = {}  # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
             for img, label in self.x:
                 if label in temp.keys():
                     temp[label].append(img)
@@ -16,6 +16,7 @@ for-loops and speeding them up through vectorization.

 Let's demonstrate how to do this using an ensemble of simple CNNs.
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -8,6 +8,7 @@ deep learning models. It is difficult (or annoying) to compute these quantities
 efficiently using a standard autodiff system like PyTorch Autograd; functorch
 provides ways of computing various higher-order autodiff quantities efficiently.
 """
+
 from functools import partial

 import torch
@@ -9,6 +9,7 @@ Per-sample-gradient computation is computing the gradient for each and every
 sample in a batch of data. It is a useful quantity in differential privacy
 and optimization research.
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -9,7 +9,7 @@ from torchgen.utils import T

 # Like tools.api.context.with_native_function, but for
 # NativeFunctionWithDifferentiabilityInfo.
 def with_native_function_with_differentiability_info(
-    func: Callable[[NFWDI], T]
+    func: Callable[[NFWDI], T],
 ) -> Callable[[NFWDI], T]:
     @functools.wraps(func)
     def wrapper(f: NFWDI) -> T:
@@ -21,7 +21,7 @@ def with_native_function_with_differentiability_info(

 # Like the above but with an additional dispatch key string argument
 def with_native_function_with_differentiability_info_and_key(
-    func: Callable[[NFWDI, str], T]
+    func: Callable[[NFWDI, str], T],
 ) -> Callable[[NFWDI, str], T]:
     @functools.wraps(func)
     def wrapper(f: NFWDI, key: str) -> T:
@@ -70,9 +70,9 @@ def gen_autograd(
         ),
         key=lambda f: cpp.name(f.func),
     )
-    fns_with_diff_infos: list[
-        NativeFunctionWithDifferentiabilityInfo
-    ] = match_differentiability_info(fns, differentiability_infos)
+    fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo] = (
+        match_differentiability_info(fns, differentiability_infos)
+    )

     # Generate VariableType.h/cpp
     if not disable_autograd:
@@ -447,7 +447,7 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS


 def get_infos_with_derivatives_list(
-    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
 ) -> list[DifferentiabilityInfo]:
     diff_info_list = [
         info
@@ -590,8 +590,7 @@ def inplace_or_view_method_definition(
         # For functions that modify their inputs but don't return them,
         # we can't give them autograd support.
         # See https://github.com/pytorch/pytorch/issues/53796
-        not modifies_arguments(f)
-        or len(f.func.returns) == 0
+        not modifies_arguments(f) or len(f.func.returns) == 0
     ):
         return None
     return METHOD_DEFINITION.substitute(
@@ -386,9 +386,9 @@ def group_filter_overloads(
     pairs: Sequence[PythonSignatureNativeFunctionPair],
     pred: Callable[[NativeFunction], bool],
 ) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
-    grouped: dict[
-        BaseOperatorName, list[PythonSignatureNativeFunctionPair]
-    ] = defaultdict(list)
+    grouped: dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] = (
+        defaultdict(list)
+    )
     for pair in pairs:
         if pred(pair.function):
             grouped[pair.function.func.name.name].append(pair)
@@ -522,12 +522,12 @@ def create_python_bindings_sharded(
     grouped = group_filter_overloads(pairs, pred)

     def key_func(
-        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
+        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]],
     ) -> str:
         return kv[0].base

     def env_func(
-        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
+        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]],
     ) -> dict[str, list[str]]:
         name, fn_pairs = kv
         return {
@@ -679,9 +679,7 @@ def load_deprecated_signatures(
                     function=pair.function,
                 )
             )
-        assert (
-            any_schema_found
-        ), f"No native function with name {aten_name} matched signature:\n {str(schema)}"
+        assert any_schema_found, f"No native function with name {aten_name} matched signature:\n {str(schema)}"

     return results

@@ -128,9 +128,9 @@ def load_derivatives(
     # function schema is the complete declaration including mutability annotation / default value and etc.
     # signature is the canonical schema for a group of functions (in-place/out/functional variants)
     # that are semantically related.
-    functions_by_signature: dict[
-        FunctionSchema, list[NativeFunction]
-    ] = defaultdict(list)
+    functions_by_signature: dict[FunctionSchema, list[NativeFunction]] = (
+        defaultdict(list)
+    )
     functions_by_schema: dict[str, NativeFunction] = {}
     for function in native_functions:
         functions_by_signature[function.func.signature()].append(function)
@@ -991,7 +991,7 @@ def _create_op_prefix(name: str) -> str:
     OP names correspond to classes, hence the change to title case.

     Example::
-    >>> _create_op_prefix('add')
+    >>> _create_op_prefix("add")
     'AddBackward'
     """
     camel_case = "".join([p.title() for p in name.split("_")])
@@ -112,8 +112,8 @@ def build_groups_memberships(
             assert (
                 _groups[pg_guid].desc == desc
             ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}"
-            assert _memberships[pg_guid] == set(
-                ranks
+            assert (
+                _memberships[pg_guid] == set(ranks)
             ), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}"
     return groups, _groups, memberships, _memberships, _pg_guids

@@ -278,7 +278,7 @@ def get_version_detail(version: str) -> Tuple[int, int]:


 def align_trace_from_beginning(
-    entries: Dict[int, List[Dict[str, Any]]]
+    entries: Dict[int, List[Dict[str, Any]]],
 ) -> Dict[int, List[Dict[str, Any]]]:
     """
     Align the trace entries by record ID for entries.
@@ -115,7 +115,8 @@ RESULTS_RE: re.Pattern[str] = re.compile(

 def _test_results_re() -> None:
     """
-    >>> def t(s): return RESULTS_RE.search(s).groupdict()
+    >>> def t(s):
+    ...     return RESULTS_RE.search(s).groupdict()

     >>> t(r"file.py:80:1: E302 expected 2 blank lines, found 1")
     ... # doctest: +NORMALIZE_WHITESPACE
@@ -31,17 +31,11 @@ USE_BLACK_FILELIST = re.compile(
         [
             # **
             # .ci/**
-            ".ci/**",
             # .github/**
-            ".github/**",
             # benchmarks/**
-            "benchmarks/**",
             # functorch/**
-            "functorch/**",
             # tools/**
-            "tools/**",
             # torchgen/**
-            "torchgen/**",
             # test/**
             # test/[a-h]*/**
             "test/[a-h]*/**",
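This hunk is the bookkeeping half of the commit: the six newly ruff-formatted directories are dropped from the lintrunner adapter's black filelist, leaving only their commented-out placeholders. A sketch of how such an fnmatch-translated filelist matches (patterns abbreviated, not the full list):

    import fnmatch
    import re

    USE_BLACK_FILELIST = re.compile("|".join(map(fnmatch.translate, ["test/[a-h]*/**"])))

    print(bool(USE_BLACK_FILELIST.match("test/autograd/test_ops.py")))  # still black
    print(bool(USE_BLACK_FILELIST.match("torchgen/model.py")))  # now ruff format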
@@ -383,9 +383,9 @@ TORCH_API bool kernel_1();
 class TestNativeFunctionGeneratrion(unittest.TestCase):
     def setUp(self) -> None:
         self.native_functions: list[NativeFunction] = []
-        self.backend_indices: dict[
-            DispatchKey, dict[OperatorName, BackendMetadata]
-        ] = defaultdict(dict)
+        self.backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
+            defaultdict(dict)
+        )
         yaml_entry = """
 - func: op(Tensor self) -> Tensor
   dispatch:
@@ -442,9 +442,9 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
 # Test for static_dispatch
 class TestStaticDispatchGeneratrion(unittest.TestCase):
     def setUp(self) -> None:
-        self.backend_indices: dict[
-            DispatchKey, dict[OperatorName, BackendMetadata]
-        ] = defaultdict(dict)
+        self.backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
+            defaultdict(dict)
+        )
         yaml_entry = """
 - func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -474,7 +474,8 @@ class TestCalculateShards(unittest.TestCase):
             else:
                 # x.time is not None because of the above check
                 self.assertAlmostEqual(
-                    random_times[test], sum(x.time for x in sharded_tests)  # type: ignore[misc]
+                    random_times[test],
+                    sum(x.time for x in sharded_tests),  # type: ignore[misc]
                 )
                 self.assertListEqual(
                     list(range(sharded_tests[0].num_shards)),
@@ -52,13 +52,11 @@ class TestPrioritizations:
             files[test.test_file] |= test

         for test in files.values():
-            assert (
-                test.is_full_file()
-            ), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"
+            assert test.is_full_file(), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"  # noqa: B950

         # Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in
-        assert self._original_tests == set(
-            files.keys()
+        assert (
+            self._original_tests == set(files.keys())
         ), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in"

     def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]:
@@ -279,9 +277,9 @@ class AggregatedHeuristics:

         stats["heuristics"] = heuristics

-        stats[
-            "aggregated"
-        ] = self.get_aggregated_priorities().get_priority_info_for_test(test)
+        stats["aggregated"] = (
+            self.get_aggregated_priorities().get_priority_info_for_test(test)
+        )

         stats["aggregated_trial"] = self.get_aggregated_priorities(
             include_trial=True
@@ -68,12 +68,10 @@ class BenchmarkRunner:
         self.main(args.num_samples, args.num_reps)

     @abstractmethod
-    def run_benchmark(self, *args: Any) -> None:
-        ...
+    def run_benchmark(self, *args: Any) -> None: ...

     @abstractmethod
-    def create_input(self) -> Tuple[Any, ...]:
-        ...
+    def create_input(self) -> Tuple[Any, ...]: ...

     def main(self, num_samples: int, num_reps: int) -> None:
         for _ in tqdm(range(num_samples)):
@@ -449,8 +449,8 @@ class AHTrainDecisionTree(AHTrain):
             for row in group.itertuples():
                 choice2time[row.choice] = row.median_execution_time

-            assert len(unique_choices) == len(
-                group
+            assert (
+                len(unique_choices) == len(group)
             ), f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}"

             return pd.Series(
@@ -253,9 +253,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
     elif t.name == BaseTy.Scalar:
         return BaseCType(scalarT)
     elif isinstance(t, ListType):
-        assert (
-            not mutable
-        ), "Native functions should never return a mutable tensor list. They should return void."
+        assert not mutable, "Native functions should never return a mutable tensor list. They should return void."
         elem = returntype_type(t.elem, mutable=False)
         assert t.size is None, f"fixed size list returns not supported: {t}"
         return VectorCType(elem)
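Asserts get the same treatment as other wrapped statements: if the condition fits, ruff format puts `assert condition, message` back on one line (this commit adds `# noqa: B950` where the joined result exceeds the lint limit). A sketch:

    def check_return_type(mutable: bool) -> str:
        assert not mutable, "Native functions should never return a mutable tensor list."
        return "vector"


    print(check_return_type(False))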
@@ -378,7 +378,8 @@ class LazyIrSchema:
                     self.generator_arg is None
                 ), "We expect there is only one generator arg"
                 self.generator_arg = NamedCType(
-                    arg.name, arg.type  # type:ignore[arg-type]
+                    arg.name,
+                    arg.type,  # type:ignore[arg-type]
                 )
         keyword_args.extend(
             LazyArgument(arg, self.properties, symint=symint)
@@ -551,9 +551,9 @@ class PythonSignatureGroup:

         # Out overloads in C++ don't have TensorOptions arguments,
         # so take these from the functional variant
-        signature_kwargs[
-            "tensor_options_args"
-        ] = functional.signature.tensor_options_args
+        signature_kwargs["tensor_options_args"] = (
+            functional.signature.tensor_options_args
+        )

         return PythonSignatureGroup(
             signature=type(out.signature)(**signature_kwargs),
@@ -164,42 +164,42 @@ def translate(
             and isinstance(t.elem.elem, BaseCType)
             and str(t.elem.elem.type) == "at::Tensor"
         ):
-            ctx[
-                NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
-            ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
+            ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = (
+                f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
+            )

         if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
-            ctx[
-                NamedCType(t.name, BaseCType(optionalTensorRefT))
-            ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
+            ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = (
+                f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
+            )

         if t.type == ConstRefCType(BaseCType(scalarT)):
             ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"

         if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
-            ctx[
-                NamedCType(t.name, BaseCType(optionalScalarRefT))
-            ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
+            ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = (
+                f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
+            )

         if t.type == BaseCType(scalar_t):
-            ctx[
-                NamedCType(t.name, BaseCType(opmath_t))
-            ] = f"static_cast<opmath_t>({b.expr})"
+            ctx[NamedCType(t.name, BaseCType(opmath_t))] = (
+                f"static_cast<opmath_t>({b.expr})"
+            )

         # [Note: IOptTensorListRef]
         if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
-            ctx[
-                NamedCType(t.name, BaseCType(iOptTensorListRefT))
-            ] = f"at::IOptTensorListRef({b.expr})"
+            ctx[NamedCType(t.name, BaseCType(iOptTensorListRefT))] = (
+                f"at::IOptTensorListRef({b.expr})"
+            )

     # Add implicit bindings if the generated code is inside a Tensor method
     if method:
-        ctx[
-            NamedCType("self", MutRefCType(BaseCType(tensorT)))
-        ] = "const_cast<Tensor&>(*this)"
-        ctx[
-            NamedCType("self", ConstRefCType(BaseCType(tensorT)))
-        ] = "const_cast<Tensor&>(*this)"
+        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = (
+            "const_cast<Tensor&>(*this)"
+        )
+        ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = (
+            "const_cast<Tensor&>(*this)"
+        )
         # This is better! Byte-for-byte compat
         # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

|
@ -406,9 +406,7 @@ def kernel_signature(
|
||||
meta = backend_index.get_kernel(f)
|
||||
symint = meta is not None and meta.supports_symint()
|
||||
if symint:
|
||||
assert (
|
||||
f.func.has_symint()
|
||||
), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
|
||||
assert f.func.has_symint(), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
|
||||
if backend_index.external:
|
||||
return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
|
||||
else:
|
||||
|
@@ -194,9 +194,7 @@ if ({arg_name}_opt.has_value()) {{
 }} else {{
     {out_name} = {ctype.cpp_type(strip_ref=True)}();
 }}
-""".split(
-                "\n"
-            ),
+""".split("\n"),
             decl,
         )

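Many of the torchgen hunks that follow collapse a `.split(...)` call that the previous formatter had exploded around a triple-quoted template. ruff format keeps the multi-line f-string intact and joins the trailing call. A runnable sketch with hypothetical template variables:

    out_name, arg_name = "out", "arg"

    code = f"""
    auto {out_name} = {arg_name}.toIntList();
    """.split("\n")
    print(code)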
@@ -213,9 +211,7 @@ def _gen_code_list_type(
         code.extend(
             f"""
 {ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name});
-""".split(
-                "\n"
-            )
+""".split("\n")
         )
     # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
     elif isinstance(t.elem, OptionalType):
@@ -226,9 +222,7 @@ for (c10::IValue {elem_name}: {in_name}) {{
     {connector.join(res_code)}
     {out_name}.push_back({res_name});
 }}
-""".split(
-                "\n"
-            )
+""".split("\n")
         )
     else:
         # use ArrayRef as default.
@@ -242,8 +236,6 @@ for (c10::IValue {elem_name}: {in_name}) {{
     {vec_name}.push_back({res_name});
 }}
 {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
-""".split(
-                "\n"
-            )
+""".split("\n")
         )
     return code, decl
@@ -95,7 +95,7 @@ def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T


 def method_with_nested_native_function(
-    func: Callable[[S, F3], T]
+    func: Callable[[S, F3], T],
 ) -> Callable[[S, F3], T]:
     @functools.wraps(func)
     def wrapper(slf: S, f: F3) -> T:
@@ -108,7 +108,7 @@ def method_with_nested_native_function(
 # Convenience decorator for functions that explicitly take in a BackendIndex,
 # instead of indirectly taking one in as a closure
 def with_native_function_and_index(
-    func: Callable[[F, BackendIndex], T]
+    func: Callable[[F, BackendIndex], T],
 ) -> Callable[[F, BackendIndex], T]:
     @functools.wraps(func)
     def wrapper(f: F, backend_index: BackendIndex) -> T:
@@ -120,7 +120,7 @@ def with_native_function_and_index(

 # Convenience decorator for functions that explicitly take in a Dict of BackendIndices
 def with_native_function_and_indices(
-    func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
+    func: Callable[[F, dict[DispatchKey, BackendIndex]], T],
 ) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
     @functools.wraps(func)
     def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
@@ -184,9 +184,7 @@ def returntype_type(t: Type, *, mutable: bool) -> CType:
     elif t.name == BaseTy.Scalar:
         return BaseCType(scalarT)
     elif isinstance(t, ListType):
-        assert (
-            not mutable
-        ), "Native functions should never return a mutable tensor list. They should return void."
+        assert not mutable, "Native functions should never return a mutable tensor list. They should return void."
         elem = returntype_type(t.elem, mutable=False)
         assert t.size is None, f"fixed size list returns not supported: {t}"
         return VectorCType(elem)
@@ -127,9 +127,7 @@ class Unboxing:
             return (
                 f"""
 auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
-""".split(
-                    "\n"
-                ),
+""".split("\n"),
                 decl,
             )

@@ -147,9 +145,7 @@ class Unboxing:
             code.extend(
                 f"""
 auto {out_name} = {arg_name}.toTensorList();
-""".split(
-                    "\n"
-                )
+""".split("\n")
             )
         elif isinstance(t.elem, BaseType) and (
             t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
@@ -157,17 +153,13 @@ class Unboxing:
             code.extend(
                 f"""
 auto {out_name} = {arg_name}.toIntList();
-""".split(
-                    "\n"
-                )
+""".split("\n")
             )
         elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
             code.extend(
                 f"""
 auto {out_name} = {arg_name}.toDoubleList();
-""".split(
-                    "\n"
-                )
+""".split("\n")
             )
         elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
             # handle list type with size, e.g., bool[4]
@@ -183,9 +175,7 @@ for (auto {elem_name}: {in_name}) {{
 #else
 auto {out_name} = {arg_name}.toBoolList();
 #endif
-""".split(
-                    "\n"
-                )
+""".split("\n")
             )
         # pytorch codegen:
         # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
@@ -205,9 +195,7 @@ for (auto {elem_name}: {in_name}) {{
 #else
 auto {out_name} = {arg_name}.toListOptionalTensor();
 #endif
-""".split(
-                    "\n"
-                )
+""".split("\n")
             )
         else:
             # use ArrayRef as default.
@@ -223,8 +211,6 @@ auto {out_name} = {arg_name}.toListOptionalTensor();
 {vec_name}.push_back({res_name});
 }}
 {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
-""".split(
-                    "\n"
-                )
+""".split("\n")
             )
         return code, decl
@@ -96,7 +96,7 @@ class ETKernelKey:
                 )
                 assert (
                     dim_order in dim_order_alias_map
-                ), "Undefined dim_order alias: " + str(dim_order)
+                ), f"Undefined dim_order alias: {dim_order}"
                 dtype_alias_used.add(type_alias)

         # Generate all permutations of dtype alias values
@@ -172,11 +172,11 @@ class ETKernelIndex:

     @staticmethod
     def from_backend_indices(
-        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
+        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
     ) -> ETKernelIndex:
-        kernel_index: dict[
-            OperatorName, dict[ETKernelKey, BackendMetadata]
-        ] = defaultdict(dict)
+        kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = (
+            defaultdict(dict)
+        )
         ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
         return ETKernelIndex(kernel_index)

@@ -1362,7 +1362,7 @@ def get_grouped_by_view_native_functions(
     native_functions: Sequence[NativeFunction],
 ) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
     def maybe_create_view_group(
-        d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
+        d: dict[ViewSchemaKind | SchemaKind, NativeFunction],
     ) -> list[NativeFunction | NativeFunctionsViewGroup]:
         funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
         if ViewSchemaKind.aliasing in d:
@@ -1409,7 +1409,7 @@ def get_grouped_native_functions(
     native_functions: Sequence[NativeFunction],
 ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
     def flatten_pre_group(
-        d: dict[SchemaKind, NativeFunction]
+        d: dict[SchemaKind, NativeFunction],
     ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
         r = NativeFunctionsGroup.from_dict(d)
         if r is None:
@@ -1476,9 +1476,7 @@ def get_native_function_declarations_from_ns_grouped_kernels(
 {ns_helper.prologue}
 {newline.join(ordered_kernels)}
 {ns_helper.epilogue}
-""".split(
-            newline
-        )
+""".split(newline)
         )
     return declarations

@@ -1671,9 +1669,7 @@ def get_namespaced_declaration(
 {ns_helper.prologue}
 {newline.join(ordered_kernels)}
 {ns_helper.epilogue}
-""".split(
-            newline
-        )
+""".split(newline)
     )
     return declarations

@@ -2386,9 +2382,7 @@ def gen_source_files(
                 os.path.join(aoti_fm.install_dir, header_file_name)
             ) as old_file:
                 old_header = old_file.read()
-                assert (
-                    old_header == new_header
-                ), """
+                assert old_header == new_header, """

 WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
 indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
@@ -71,7 +71,10 @@ base_type_to_callsite_expr = {


 # convert args to C types, names in declarations, and expressions in function bodies
-def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]:  # type: ignore[return]
+def convert_arg_type_and_name(  # type: ignore[return]
+    typ: Type,
+    name: str,
+) -> tuple[list[str], list[str], list[str], list[str]]:
     if isinstance(typ, BaseType):
         if typ.name in base_type_to_c_type:
             return (
@@ -295,7 +295,7 @@ def gen_unboxing(
 ) -> None:
     # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
     def key_func(
-        item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
+        item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
     ) -> str:
         return item[0].root_name + ":" + item[1][0].to_native_string()

@@ -739,7 +739,7 @@ def parse_yaml(

     # (2) Return BackendIndices if kernel index is absent
     def map_index(
-        m: dict[OperatorName, BackendMetadata]
+        m: dict[OperatorName, BackendMetadata],
     ) -> dict[OperatorName, BackendMetadata]:
         return {op: m[op] for op in m if op in op_names}

@@ -278,13 +278,13 @@ def assert_view_op_properties(func: FunctionSchema) -> None:

     args = func.arguments.flat_non_out
     # The first argument is a tensor with an alias semantics (annotations)
-    assert len(args) > 0 and args[0].type == BaseType(
-        BaseTy.Tensor
+    assert (
+        len(args) > 0 and args[0].type == BaseType(BaseTy.Tensor)
     ), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
 but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
     # No other arguments have aliasing semantics
-    assert is_alias(args[0]) and not any(
-        is_alias(a) for a in args[1:]
+    assert (
+        is_alias(args[0]) and not any(is_alias(a) for a in args[1:])
     ), """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
 View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint"""

@@ -176,9 +176,9 @@ class default_args:
     tensor_class: str = "torch::lazy::LazyTensor"
     tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
     lazy_ir_generator: type[GenLazyIR] = GenLazyIR
-    native_func_definition_generator: type[
-        GenLazyNativeFuncDefinition
-    ] = GenLazyNativeFuncDefinition
+    native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
+        GenLazyNativeFuncDefinition
+    )
     backend_name: str = "TorchScript"


@@ -257,9 +257,9 @@ def main() -> None:
     lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
     if options.gen_ts_lowerings:
         lazy_ir_generator = GenTSLazyIR
-    native_func_definition_generator: type[
-        GenLazyNativeFuncDefinition
-    ] = default_args.native_func_definition_generator
+    native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
+        default_args.native_func_definition_generator
+    )

     run_gen_lazy_tensor(
         aten_path,
@@ -1484,14 +1484,15 @@ class FunctionSchema:
             else:
                 # mutable keyword arguments whose name has _scratch_ prefix are
                 # scratch tensors for memory planning and should not be returned
-                assert len(
-                    [
-                        arg
-                        for arg in self.arguments.out
-                        if not arg.name.startswith("_scratch_")
-                    ]
-                ) == len(
-                    self.returns
+                assert (
+                    len(
+                        [
+                            arg
+                            for arg in self.arguments.out
+                            if not arg.name.startswith("_scratch_")
+                        ]
+                    )
+                    == len(self.returns)
                 ), "Must return as many arguments as there are out arguments, or no return at all"

         if self.name.name.inplace:
@@ -1590,9 +1591,7 @@ class FunctionSchema:
             ), "invariant: all scratch operators are expected to be out= operators too"
             return SchemaKind.scratch
         elif is_out:
-            assert (
-                not is_scratch
-            ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
+            assert not is_scratch, "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"  # noqa: B950
             return SchemaKind.out
         elif is_mutable:
             return SchemaKind.mutable
@@ -2701,9 +2700,7 @@ class NativeFunctionsViewGroup:
             )
         if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
             if self.view_inplace is not None:
-                assert (
-                    self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
-                ), (
+                assert self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel, (
                     f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                     " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
                 )
@@ -263,7 +263,7 @@ def construct_register_size(register_size_from_yaml: int) -> str:


 def construct_version_maps(
-    upgrader_bytecode_function_to_index_map: dict[str, Any]
+    upgrader_bytecode_function_to_index_map: dict[str, Any],
 ) -> str:
     version_map = torch._C._get_operator_version_map()
     sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0))  # type: ignore[no-any-return]
@@ -305,7 +305,7 @@ def construct_version_maps(


 def get_upgrader_bytecode_function_to_index_map(
-    upgrader_dict: list[dict[str, Any]]
+    upgrader_dict: list[dict[str, Any]],
 ) -> dict[str, Any]:
     upgrader_bytecode_function_to_index_map = {}
     index = 0
@@ -366,9 +366,9 @@ def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> N
             arg_map["out_int32"] = "false"
         else:
             arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)"
-            arg_map[
-                "col_indices"
-            ] = "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
+            arg_map["col_indices"] = (
+                "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
+            )
             arg_map["out_int32"] = "false"
             return
     if op_name == "_convert_indices_from_coo_to_csr":