pytorch/test/inductor/test_aot_inductor.py
Mu-Chu Lee 9fccbdd4f0 Fix incorrect function signature in template (#165567)
Summary:
In https://github.com/pytorch/pytorch/pull/148305 we refactored the grid
argument out, but that change was not reflected in our template.
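
For context, a rough sketch of the shape of the mismatch (hypothetical wrapper names
written in Python, not the actual AOTI C++ template; it reuses the add_kernel helper
imported by this test file): once the refactored launcher computes the grid itself, a
template that still emits a grid parameter generates a stale signature.

    # Before #148305 (illustrative only): the caller threads the grid through.
    def launch_add_kernel(grid, x, y, out, n_elements):
        add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)

    # After: the launcher derives the grid from its own arguments, so the
    # template must stop emitting the grid parameter.
    def launch_add_kernel(x, y, out, n_elements):
        grid = (triton.cdiv(n_elements, 1024),)
        add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)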

Test Plan:
Included in commit.
python test/inductor/test_aot_inductor.py
AOTInductorTestABICompatibleGpu.test_cond_symint_input_disable_one_pass_cuda

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165567
Approved by: https://github.com/desertfire
2025-10-17 02:40:56 +00:00


# Owner(s): ["module: inductor"]
import itertools
import logging
import os
import pathlib
import subprocess
import sys
import tempfile
import unittest
import zipfile
from unittest import skip
from unittest.mock import patch
import torch
import torch._export
import torch._inductor
import torch._inductor.config
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.nn as nn
from torch._dynamo import config as dynamo_config
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.codecache import WritableTempFile
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.package import package_aoti
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase
from torch._inductor.utils import (
is_big_gpu,
maybe_aoti_standalone_config,
run_and_get_cpp_code,
)
from torch._library import capture_triton
from torch._utils_internal import full_aoti_runtime_assert
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.export import Dim, export
from torch.export.pt2_archive._package import load_pt2
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
IS_SM90,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
SM80OrLater,
tf32_on_and_off,
)
from torch.testing._internal.common_device_type import (
_has_sufficient_memory,
e4m3_type,
skipCUDAIf,
)
from torch.testing._internal.common_quantization import (
_group_quantize_tensor,
skip_if_no_torchvision,
skipIfNoFBGEMM,
)
from torch.testing._internal.common_utils import (
DeterministicGuard,
IS_CI,
IS_FBCODE,
IS_MACOS,
IS_WINDOWS,
MACOS_VERSION,
MI300_ARCH,
parametrize,
runOnRocm,
skipIfMPS,
skipIfRocm,
skipIfRocmArch,
skipIfWindows,
skipIfWindowsXPU,
skipIfXpu,
TEST_MPS,
TEST_WITH_ROCM,
)
from torch.testing._internal.custom_tensor import CustomTensorPlainOut
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch.testing._internal.triton_utils import requires_gpu
from torch.utils import _pytree as pytree
from torch.utils._triton import (
has_triton_experimental_host_tma,
has_triton_tensor_descriptor_host_tma,
)
if HAS_GPU:
import triton # @manual
from triton import language as tl
from torch.testing._internal.triton_utils import (
add_kernel,
add_kernel_2d_autotuned,
add_kernel_autotuned,
add_kernel_autotuned_weird_param_order,
add_kernel_on_device_tma_new_api,
add_kernel_on_device_tma_old_api,
add_kernel_with_boolean_param,
add_kernel_with_none_param_and_equal_to_1_arg,
add_kernel_with_optional_param,
add_kernel_with_scaling,
add_kernel_with_tma_1d_new_api,
add_kernel_with_tma_1d_old_api,
add_kernel_with_tma_2d_new_api,
add_kernel_with_tma_2d_old_api,
create_tensor_descriptor_shim,
mul2_inplace_kernel,
strange_config_matmul_kernel,
sub_kernel_autotuned,
)
if IS_WINDOWS and IS_CI:
sys.stderr.write(
"Windows CI does not have necessary dependencies for test_torchinductor yet\n"
)
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
try:
try:
from .test_aot_inductor_utils import (
AOTIRunnerUtil,
check_model,
check_model_with_multiple_inputs,
code_check_count,
)
from .test_control_flow import (
CondModels,
prepend_counters,
prepend_predicates,
WhileLoopModels,
)
from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
except ImportError:
from test_aot_inductor_utils import ( # @manual=fbcode//caffe2/test/inductor:aot_inductor_utils-library
AOTIRunnerUtil,
check_model,
check_model_with_multiple_inputs,
code_check_count,
)
from test_control_flow import ( # @manual=fbcode//caffe2/test/inductor:control_flow-library
CondModels,
prepend_counters,
prepend_predicates,
WhileLoopModels,
)
from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
copy_tests,
requires_multigpu,
TestFailure,
)
except (unittest.SkipTest, ImportError):
if __name__ == "__main__":
sys.exit(0)
raise
def get_module_ext_type():
if IS_WINDOWS:
return "pyd"
else:
return "so"
class AOTInductorTestsTemplate:
    # Temporarily skip this test: pytorch/cpuinfo cannot retrieve the cache size for
    # the AMD EPYC 9575F 64-Core Processor CPU in gfx942 VM runners.
@common_utils.parametrize("embed_kernel_binary", [False, True])
@common_utils.parametrize("max_autotune", [False, True])
@skipIfRocmArch(MI300_ARCH)
def test_simple(self, embed_kernel_binary, max_autotune):
if self.device == "cpu" and IS_MACOS and max_autotune:
raise unittest.SkipTest("max_autotune not supported on macos")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
model = Model()
with config.patch(
{
"aot_inductor.embed_kernel_binary": embed_kernel_binary,
"max_autotune": max_autotune,
}
):
self.check_model(model, example_inputs)
_, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
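            # Inspect the generated C++ wrapper: each device should go through its
            # kernel-launch API, and with embedded kernel binaries the launch must
            # not reference an external cubin file by name.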
if self.device == "mps":
FileCheck().check("aoti_torch_mps_get_kernel_function(").run(code)
elif self.device == GPU_TYPE:
FileCheck().check("launchKernel(").run(code)
if config.aot_inductor.embed_kernel_binary:
# Not expect to see launchKernel("CUBIN_FILE_NAME"
FileCheck().check_not('launchKernel("').run(code)
if self.use_minimal_arrayref_interface:
self.code_check_count(
model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
)
def test_triton_kernel_bool_param(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x):
out = torch.zeros_like(x)
add_kernel_with_boolean_param[1,](
in_ptr0=x,
in_ptr1=x,
out_ptr=out,
n_elements=x.numel(),
add_xy=True,
BLOCK_SIZE=1,
)
return out
inputs = (torch.randn(4, device=self.device),)
self.check_model(Model(), inputs)
@unittest.skipIf(
IS_FBCODE,
"toolchain doesn't support ptx to fatbin",
)
@skipIfMPS
@skipIfRocm
    # Skip embed_kernel_binary == True for now, as it shows random
    # failures on CI
@common_utils.parametrize("embed_kernel_binary", [False])
@unittest.skipIf(
_get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
)
def test_simple_multi_arch(self, embed_kernel_binary):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU_TYPE")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 16)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 16, device=self.device),
torch.randn(10, 10, device=self.device),
)
model = Model()
with config.patch(
{
"aot_inductor.embed_kernel_binary": embed_kernel_binary,
"aot_inductor.emit_multi_arch_kernel": True,
}
):
self.check_model(model, example_inputs)
if not embed_kernel_binary:
_, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
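                # Multi-arch emission without embedding should reference the external
                # kernel binary (.fatbin for CUDA, .spv for XPU) from the wrapper.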
file_extension = ".spv" if self.device == "xpu" else ".fatbin"
FileCheck().check(file_extension).run(code)
def test_small_constant(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x):
return self.linear(x)
example_inputs = (torch.randn(4, 4, device=self.device),)
with config.patch({"always_keep_tensor_constants": True}):
self.check_model(Model().to(self.device), example_inputs)
def test_output_path_1(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
with config.patch("aot_inductor.output_path", "tmp_output_"):
self.check_model(Model(), example_inputs)
def test_output_path_2(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
model = Model().to(device=self.device)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
expected_path = normalize_path_separator(
os.path.join(
tempfile.mkdtemp(dir=cache_dir()), f"model.{get_module_ext_type()}"
)
)
actual_path = AOTIRunnerUtil.legacy_compile(
model, example_inputs, options={"aot_inductor.output_path": expected_path}
)
self.assertTrue(actual_path == expected_path)
@unittest.skipIf(
config.triton.native_matmul,
"different # of input/output/constants in native matmul",
)
def test_empty_constant_folding(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.w = torch.randn(4, 4, device=device)
self.b = torch.randn(4, device=device)
def forward(self, x):
return torch.matmul(x, self.w) + self.b
model = Model(self.device)
example_inputs = (torch.randn(4, 4, device=self.device),)
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
so_path, code = run_and_get_cpp_code(
AOTIRunnerUtil.legacy_compile, model, example_inputs
)
# We should have 1 input, 1 output, 2 constants for the model.
FileCheck().check_count("AOTInductorModelBase(1,", 1).check_next(
"1,"
).check_next("2,").run(code)
def test_constant_folding(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.w_pre = torch.randn(4, 4, device=device)
self.b = torch.randn(4, device=device)
def forward(self, x):
w_transpose = torch.transpose(self.w_pre, 0, 1)
w_relu = torch.nn.functional.relu(w_transpose)
w = w_relu + self.b
return torch.matmul(x, w)
example_inputs = (torch.randn(4, 4, device=self.device),)
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
self.check_model(Model(self.device), example_inputs)
def test_constant_folding_with_update(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.w_pre = torch.randn(4, 4, device=device)
self.b = torch.randn(4, device=device)
def forward(self, x):
w_transpose = torch.transpose(self.w_pre, 0, 1)
w_relu = torch.nn.functional.relu(w_transpose)
w = w_relu + self.b
return torch.matmul(x, w)
example_inputs = (torch.randn(4, 4, device=self.device),)
with (
torch.no_grad(),
config.patch(
{
"always_keep_tensor_constants": True,
"aot_inductor.use_runtime_constant_folding": True,
}
),
):
model = Model(self.device)
so_path = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
def runner_call(*args, **kwargs):
import torch.fx._pytree as fx_pytree
call_spec = runner.get_call_spec()
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs)
return pytree.tree_unflatten(flat_outputs, out_spec)
test_inputs = torch.randn(4, 4, device=self.device)
expected = model(test_inputs)
output = runner_call(test_inputs)
self.assertEqual(expected, output)
# Update with new weights on active buffer
new_weights = {
"L__self___b": torch.randn(4, device=self.device),
"L__self___w_pre": torch.randn(4, 4, device=self.device),
}
model.w_pre = new_weights["L__self___w_pre"]
model.b = new_weights["L__self___b"]
expected = model(test_inputs)
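            # The two boolean flags select the inactive buffer and request full
            # validation, respectively (inferred from how they are exercised in this
            # test): False/False writes straight into the active buffer, so the next
            # run picks up the new weights immediately.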
runner.update_constant_buffer(new_weights, False, False)
output = runner_call(test_inputs)
self.assertEqual(expected, output)
# Update with new weights on inactive buffer
new_weights = {
"L__self___b": torch.randn(4, device=self.device),
"L__self___w_pre": torch.randn(4, 4, device=self.device),
}
model.w_pre = new_weights["L__self___w_pre"]
model.b = new_weights["L__self___b"]
expected = model(test_inputs)
runner.update_constant_buffer(new_weights, True, False)
new_output = runner_call(test_inputs)
# We have not yet swapped the buffer, new_output should be the same as the old one.
self.assertEqual(output, new_output)
# Swap the buffer, should get the correct result now.
runner.swap_constant_buffer()
new_output = runner_call(test_inputs)
self.assertEqual(expected, new_output)
@requires_gpu
def test_duplicate_constant_folding(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.w1 = torch.randn(4, 4, device=device)
self.w2 = torch.randn(4, 4, device=device)
self.w3 = torch.randn(4, 4, device=device)
self.w4 = torch.randn(4, 4, device=device)
def forward(self, x):
w_concat = torch.cat((self.w1, self.w2, self.w3, self.w4))
return torch.cat((x, w_concat))
example_inputs = (torch.randn(4, 4, device=self.device),)
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
self.check_model(Model(self.device), example_inputs)
def test_autotune_with_constant_folding(self):
class Model(torch.nn.Module):
def __init__(self, device) -> None:
super().__init__()
self.x = torch.randn(2048, 2048, dtype=torch.float16, device=device)
def _quantize(self, input):
return torch.abs(input)
def forward(self, y):
abs_weight = self._quantize(self.x)
abs_y = self._quantize(y)
return abs_weight, abs_y
input1 = (torch.rand(2048, 2048, dtype=torch.float16, device=self.device),)
model = Model(self.device).to(self.device)
_ = model(*input1)
ep = torch.export.export(model, input1, dynamic_shapes=None, strict=False)
torch._inductor.aoti_compile_and_package(
ep, inductor_configs={"aot_inductor.use_runtime_constant_folding": True}
)
@unittest.skipIf(
TEST_MPS and MACOS_VERSION < 14.0,
"Compilation error",
)
def test_aot_inductor_consts_cpp_build(self):
class Model(torch.nn.Module):
def __init__(self, device) -> None:
super().__init__()
self.x = torch.randn(2048, 2048, dtype=torch.float16, device=device)
def _quantize(self, input):
return torch.abs(input)
def forward(self, y):
abs_weight = self._quantize(self.x)
abs_y = self._quantize(y)
return abs_weight, abs_y
input1 = (torch.rand(2048, 2048, dtype=torch.float16, device=self.device),)
model = Model(self.device).to(self.device)
_ = model(*input1)
ep = torch.export.export(model, input1, dynamic_shapes=None, strict=False)
torch._inductor.aoti_compile_and_package(
ep,
inductor_configs={
"aot_inductor.use_runtime_constant_folding": True,
"aot_inductor.use_consts_asm_build": False,
},
)
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("tma_version", ["new", "old"])
def test_triton_kernel_on_device_tma(self, dynamic, tma_version):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
if tma_version == "new" and not has_triton_tensor_descriptor_host_tma():
self.skipTest("requires triton.tools.tensor_descriptor TMA support")
if tma_version == "old" and not has_triton_experimental_host_tma():
self.skipTest("requires triton.tools.experimental_descriptor TMA support")
kernel = (
add_kernel_on_device_tma_new_api
if tma_version == "new"
else add_kernel_on_device_tma_old_api
)
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, a, b):
BLOCK_SIZE = 32
out = torch.zeros_like(a)
m, n = out.size()
# Allocate workspace for on-device TMA descriptors
# Need 128 bytes per descriptor, 3 descriptors total
if tma_version == "old":
workspace = torch.zeros(3 * 128, dtype=torch.uint8, device=a.device)
else:
workspace = None
grid = (triton.cdiv(m, BLOCK_SIZE), triton.cdiv(n, BLOCK_SIZE))
kernel[grid](
a,
b,
out,
m,
n,
workspace,
BLOCK_SIZE=BLOCK_SIZE,
)
return out
a = torch.randn((32 * 4, 32 * 8), device=self.device)
b = torch.randn((32 * 4, 32 * 8), device=self.device)
example_inputs = (a, b)
triton.set_allocator(
lambda size, align, stream: torch.empty(
size, dtype=torch.int8, device=GPU_TYPE
)
)
dynamic_shapes = None
if dynamic:
dim0 = Dim("s0", min=2, max=1024)
dim1 = Dim("s1", min=2, max=1024)
dynamic_shapes = {
"a": {0: dim0, 1: None},
"b": {0: dim1, 1: None},
}
self.check_model(
Model(),
example_inputs=example_inputs,
dynamic_shapes=dynamic_shapes,
)
@requires_gpu
def test_multi_device(self):
if self.device == "cpu" and GPU_TYPE == "xpu":
raise unittest.SkipTest(
"In this scenario, the test case will run XPU code in "
"AOTIModelContainerRunnerCpu, which is not reasonable,"
"See issue #140805"
)
class Model(torch.nn.Module):
def forward(self, x):
x = x + 1
x = x.cpu()
x = x + 2
x = x.to(GPU_TYPE)
return x
example_inputs = (torch.randn(32, 64, device=self.device),)
self.check_model(Model(), example_inputs)
@unittest.skip(
"install_free_tensors leads to OOM - https://github.com/pytorch/pytorch/issues/164062"
)
def test_large_weight(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2048, 262144)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(1, 262144, device=self.device),
torch.randn(1, 2048, device=self.device),
)
# We only test compilation since we often get OOM running in CI.
model = Model()
model = model.to(self.device)
AOTIRunnerUtil.compile(model, example_inputs)
def test_constant_type_propagation(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.w_pre = torch.randn(4, 4, device=device)
self.b = torch.randn(4, device=device)
def forward(self, x):
w_transpose = torch.transpose(self.w_pre, 0, 1)
w_relu = torch.nn.functional.relu(w_transpose)
w = w_relu + self.b
return torch.matmul(x, w)
model = Model(self.device)
example_inputs = (torch.randn(4, 4, device=self.device),)
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
so_path, code = run_and_get_cpp_code(
AOTIRunnerUtil.legacy_compile, model, example_inputs
)
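            # Every folded constant should be given a concrete type in the generated
            # code, i.e. ConstantType::Unknown must not appear.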
FileCheck().check_not("torch::aot_inductor::ConstantType::Unknown").run(
code
)
def test_subclasses(self):
device_to_init = self.device
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.p1 = torch.nn.Parameter(torch.ones(3, 4, device=device_to_init))
self.p2 = torch.nn.Parameter(
CustomTensorPlainOut(
torch.ones(3, 4, device=device_to_init),
torch.ones(3, 4, device=device_to_init),
)
)
def forward(self, x):
a = (2 * self.p1 + self.p2).sum()
return x + a
m = Foo()
ref_x = torch.randn(3, 4, device=device_to_init)
with torch.no_grad():
result = AOTIRunnerUtil.run(
m,
(ref_x,),
)
actual = m(ref_x)
self.assertTrue(same(result, actual))
def test_large_mmaped_weights(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(512, 250112)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(1, 250112, device=self.device),
torch.randn(1, 512, device=self.device),
)
with config.patch({"aot_inductor.force_mmap_weights": True}):
self.check_model(Model(), example_inputs)
def test_large_mmaped_weights_on_disk(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(512, 250112)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(1, 250112, device=self.device),
torch.randn(1, 512, device=self.device),
)
with config.patch(
{"aot_inductor.package_constants_on_disk_format": "binary_blob"}
):
self.check_model(Model(), example_inputs)
def test_with_offset(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
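                # orig_tensor is a view (row 0) of a larger allocation, and tensor is
                # a further slice with a nonzero storage offset; both must survive
                # constant lifting with their offsets intact.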
self.orig_tensor = torch.randn(2, 15, 10, device=device)[0]
self.tensor = self.orig_tensor[5:, :]
def forward(self, x, y):
return (
x
+ torch.nn.functional.linear(y, self.orig_tensor[:10, :])
+ self.tensor
)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_freezing(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(9, 10, device=device)
self.padding = torch.randn(1, 10, device=device)
def forward(self, x, y):
padded_weight = torch.cat((self.weight, self.padding), dim=0)
return x + torch.nn.functional.linear(y, padded_weight)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_conv_freezing(self):
dtypes = [torch.bfloat16, torch.float] if SM80OrLater else [torch.float]
for dtype, groups in itertools.product(dtypes, [1, 2]):
iC = 2
oC = 3
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
dtype
)
def forward(self, y):
return torch.nn.functional.conv2d(y, self.weight, groups=groups)
example_inputs = (
torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
@tf32_on_and_off(0.005)
def test_deconv_freezing(self):
dtypes = [torch.float]
if torch._C._has_mkldnn and torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
for dtype, groups in itertools.product(dtypes, [2, 1]):
iC = 4
oC = 2
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
dtype
)
def forward(self, y):
return torch.nn.functional.conv_transpose2d(
y, self.weight, groups=groups
)
example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_linear_freezing(self):
dtypes = [torch.bfloat16, torch.float] if SM80OrLater else [torch.float]
for dtype in dtypes:
class LinearModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device).to(dtype)
self.bias = torch.randn(10, device=device).to(dtype)
def forward(self, y):
return torch.nn.functional.linear(y, self.weight, self.bias)
example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)
with config.patch({"freezing": True}):
model = LinearModel(device=self.device)
self.check_model(model, example_inputs)
def test_same_backing(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define(
"mylib::foo2",
"(Tensor a, Tensor b) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl("mylib::foo2", "CompositeExplicitAutograd", lib=lib)
def foo_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a + b
class M(torch.nn.Module):
def forward(self, a, b):
x = a.shape[0]
y = b.shape[0]
a = torch.cat([a, a])
a = torch.ops.mylib.foo2(a, a)
a = a * x
b = torch.cat([b, b])
b = torch.ops.mylib.foo2(b, b)
b = b * y
return a, b
inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
self.check_model(M(), inp)
@unittest.skipIf(
TEST_MPS and MACOS_VERSION < 14.0,
"MPS BFloat16 is only supported on MacOS 14+",
)
def test_empty_cat_dtype_promotion(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
z = torch.cat([x, y], dim=1)
z = z.to(dtype=torch.bfloat16)
return z * 2
model = Foo()
inps = (torch.randn(4, 10, dtype=torch.bfloat16), torch.randn(4, 0))
self.check_model(model, inps)
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
def test_linear_dynamic_maxautotune(self):
if self.device == "cpu":
raise unittest.SkipTest("using triton backend only is not supported on CPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
model = Model().to(device=self.device)
compile_inputs = (torch.randn(2048, 1, device=self.device),)
dim0_x = Dim("dim0_x", min=2, max=2048)
dynamic_shapes = {"x": {0: dim0_x}}
ep = torch.export.export(
model, compile_inputs, dynamic_shapes=dynamic_shapes, strict=True
)
optimized = torch._inductor.aoti_load_package(
torch._inductor.aoti_compile_and_package(
ep,
inductor_configs={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
)
)
runtime_input = torch.randn(10, 1, device=self.device)
self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
runtime_input = torch.randn(16, 1, device=self.device)
self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
runtime_input = torch.randn(100, 1, device=self.device)
self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
@torch._inductor.config.patch(
pre_grad_fusion_options={
"normalization_pass": {},
"remove_split_with_size_one_pass": {},
"merge_getitem_cat_pass": {},
"merge_stack_tahn_unbind_pass": {},
"merge_splits_pass": {},
"mutate_cat_pass": {},
"split_cat_pass": {},
"unbind_stack_pass": {},
},
post_grad_fusion_options={},
)
def test_simple_split(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)
example_inputs = (torch.randn(2, 8, device=self.device),)
counters.clear()
model = Model().to(device=self.device)
actual = AOTIRunnerUtil.legacy_run(self.device, model, example_inputs)
self.assertTrue(same(model(*example_inputs), actual))
self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)
def test_amp_fallback_random(self):
def fn(x, w):
return torch.functional.F.linear(x, w)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
with config.patch({"fallback_random": True}):
with torch.amp.autocast(device_type=self.device):
self.check_model(fn, example_inputs)
def test_missing_output(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
a = torch.sin(x)
b = torch.mm(a, y)
c = torch.cos(b)
return c
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
self.check_model(Model(), example_inputs)
def test_output_misaligned(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
x_unsqueeze = torch.unsqueeze(x, dim=0)
y_unsqueeze = torch.unsqueeze(y, dim=0)
cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
x_getitem = cat[0]
y_getitem = cat[1]
x_sigmoid = torch.sigmoid(x_getitem)
return x_sigmoid, y_getitem
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
self.check_model(Model(), example_inputs)
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
@skip("Test was marked as expected failure, but does not fail always anymore.")
def test_dynamic_smem_above_default_limit(self):
if self.device == "cpu":
raise unittest.SkipTest("using triton backend only is not supported on CPU")
class Model(torch.nn.Module):
def forward(self, x, y):
return x @ y
model = Model().to(self.device)
# on A100, the generated Triton kernel for this MM
# requires 55296 bytes of dynamic SMEM which is above
# the A100's default dynamic SMEM limit of 49152 bytes.
example_inputs = (
torch.randn(10285, 96, device=self.device),
torch.randn(96, 1, device=self.device),
)
self.check_model(
model,
example_inputs,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
)
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
def test_seq(self):
layernorm = torch.nn.LayerNorm(10)
net = torch.nn.Sequential(
layernorm,
torch.nn.ReLU(),
layernorm,
torch.nn.ReLU(),
)
example_inputs = (torch.randn(10, device=self.device),)
self.check_model(net.eval(), example_inputs)
def test_addmm(self):
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M = 8
N = 6
K = 16
model = Model(N, K, self.device)
batch = 2
a = torch.randn(batch, M, K, device=self.device)
        # We should be able to call self.check_model here, but torch.export.export
        # of plain constants (non-parameter, non-buffer tensors) doesn't work today.
example_inputs = (a,)
self.check_model(model, example_inputs)
def test_aliased_buffer_reuse(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
x = 2 * x
y = 2 * y
c = torch.cat([x, y], dim=-1)
d = 1 + c
m = torch.mm(d, d)
return m[:, :2] + x
example_inputs = (
torch.randn(4, 2, device=self.device),
torch.randn(4, 2, device=self.device),
)
self.check_model(Model(), example_inputs)
def test_buffer_reuse(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
a = torch.sin(x)
b = torch.cos(y)
c = torch.mm(a, b)
d = torch.relu(c)
e = torch.sigmoid(d)
f = torch.mm(x, y)
g = e + f
return g
example_inputs = (
torch.randn(4, 4, device=self.device),
torch.randn(4, 4, device=self.device),
)
self.check_model(Model(), example_inputs)
def test_duplicated_params(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.p = torch.nn.Parameter(torch.rand(6))
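                # Intentionally alias the same Parameter under a second attribute name.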
self.q = self.p
def forward(self, x):
return self.p * x + self.q
example_inputs = (torch.rand(6, device=self.device),)
self.check_model(Model(), example_inputs)
@unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
def test_inf(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
x = torch.randn(10, 10, device=self.device)
x[0][0] = float("Inf")
example_inputs = (
x,
torch.randn(10, 10, device=self.device),
)
self.check_model(
Model().to(self.device),
example_inputs,
options={"debug_check_inf_and_nan": True},
)
@unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
def test_nan(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
x = torch.randn(10, 10, device=self.device)
x[0][0] = float("nan")
example_inputs = (
x,
torch.randn(10, 10, device=self.device),
)
self.check_model(
Model().to(self.device),
example_inputs,
options={"debug_check_inf_and_nan": True},
)
@skipIfWindowsXPU(msg="crash on Windows XPU.")
def test_assert_async(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU_TYPE")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
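                # .item() yields an unbacked symint; the torch._check below becomes a
                # deferred runtime assertion in the compiled model (hence the test name).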
u0 = x.item()
torch._check(u0 > 3)
return torch.ones(u0)[0]
x = torch.tensor(23, device=self.device)
example_inputs = (x,)
self.check_model(Model(), example_inputs)
def test_simple_dynamic(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
add_0 = x + y
return torch.nn.functional.relu(input=add_0, inplace=False)
x = torch.randn(128, 2048, device=self.device)
y = torch.randn(128, 2048, device=self.device)
dim0_x = Dim("dim0_x", min=1, max=2048)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
example_inputs = (x, y)
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
@skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation")
def test_large_dynamic_dim(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
add_0 = x + y
return torch.nn.functional.relu(input=add_0, inplace=False)
x = torch.randn(128, 2048, device=self.device)
y = torch.randn(128, 2048, device=self.device)
# Use a dimension that exceeds the maximum value of a C long long (2^63 - 1)
dim0_x = Dim("dim0_x", min=1, max=1171368248680556527362)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
example_inputs = (x, y)
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@skipIfXpu
def test_fp8(self):
        # cuda only
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")
class Model(torch.nn.Module):
def __init__(self, dtype):
super().__init__()
self.out_dtype = dtype
def forward(self, x, weight, bias, scale_a, scale_b):
weight = weight.to(e4m3_type)
output = torch._scaled_mm(
x,
weight,
bias=input_bias,
out_dtype=self.out_dtype,
scale_a=scale_a,
scale_b=scale_b,
)
return output
dtype = torch.float16
a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
weight_shape = (32, 16)
weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
a_inverse_scale = 1 / a_scale
b_inverse_scale = 1 / b_scale
x_shape = (16, 16)
x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(e4m3_type)
dim0_x = Dim("dim0_x", min=1, max=2048)
dynamic_shapes = ({0: dim0_x}, None, None, None, None)
self.check_model(
Model(dtype),
(x, weight, input_bias, a_inverse_scale, b_inverse_scale),
dynamic_shapes=dynamic_shapes,
)
@unittest.skipIf(
TEST_WITH_ROCM or not IS_SM90,
"scaled_grouped_mm is only supported on SM90",
)
@skipIfXpu
def test_scaled_grouped_mm(self):
# Test torch._scaled_grouped_mm AOTI lowering
# cuda only
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight, scale_a, scale_b, offsets):
# x: [num_groups, batch, in_features] - FP8 inputs
# weight: [total_out_features, in_features] - FP8 weights (transposed)
# scale_a: [num_groups] - input scales
# scale_b: [num_groups] - weight scales
# offsets: [num_groups] - cumulative output sizes
output = torch._scaled_grouped_mm(
x,
weight.t(),
scale_a=scale_a,
scale_b=scale_b,
offs=offsets,
use_fast_accum=True,
)
return output.half()
dtype = torch.float16
num_groups = 3
batch_size = 64
in_features = 128
out_features_list = [64, 128, 256] # Different output sizes for each group
device = GPU_TYPE
# Calculate offsets (cumulative output sizes)
offsets = torch.cumsum(torch.tensor(out_features_list), dim=0).to(
device, dtype=torch.int32
)
total_out_features = sum(out_features_list)
# Create FP8 input tensors - stacked for all groups
x_fp16 = torch.randn(
num_groups, batch_size, in_features, dtype=dtype, device=device
)
x_fp8 = x_fp16.to(torch.float8_e4m3fn)
# Create FP8 weight tensor - concatenated and transposed
weight_fp16 = torch.randn(
total_out_features, in_features, dtype=dtype, device=device
)
weight_fp8 = weight_fp16.to(torch.float8_e4m3fn)
# Create scales
scale_a = torch.ones(num_groups, batch_size, device=device, dtype=torch.float32)
scale_b = torch.ones(total_out_features, device=device, dtype=torch.float32)
self.check_model(
Model(),
(x_fp8, weight_fp8, scale_a, scale_b, offsets),
)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@skipIfXpu
def test_fp8_view_of_param(self):
        # cuda only
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires CUDA")
class Model(torch.nn.Module):
def __init__(self, dtype, weight):
super().__init__()
self.out_dtype = dtype
self.weight = weight
def forward(self, x, bias, scale_a, scale_b):
# test: do the view inside of the graph,
# AOTI needs to materialize this view before passing
# it into the scaled_mm extern kernel
weight = self.weight.T
output = torch._scaled_mm(
x,
weight,
bias=input_bias,
out_dtype=self.out_dtype,
scale_a=scale_a,
scale_b=scale_b,
)
return output
dtype = torch.float16
a_scale = torch.Tensor([1.0]).to(device=self.device)
b_scale = torch.Tensor([1.0]).to(device=self.device)
input_bias = torch.rand(32, device=self.device, dtype=dtype)
weight_shape = (32, 16)
weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).to(
e4m3_type
)
a_inverse_scale = 1 / a_scale
b_inverse_scale = 1 / b_scale
x_shape = (16, 16)
x = torch.rand(*x_shape, device=self.device, dtype=dtype).to(e4m3_type)
dim0_x = Dim("dim0_x", min=1, max=2048)
dynamic_shapes = ({0: dim0_x}, None, None, None)
self.check_model(
Model(dtype, weight),
(x, input_bias, a_inverse_scale, b_inverse_scale),
dynamic_shapes=dynamic_shapes,
)
def test_poi_multiple_dynamic(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
add_0 = x + y
return torch.nn.functional.relu(input=add_0, inplace=False)
x = torch.randn(128, 2048, device=self.device)
y = torch.randn(128, 2048, device=self.device)
dim0_x = Dim("dim0_x", min=1, max=2048)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
list_example_inputs = [(x, y)]
list_example_inputs.append(
(
torch.randn(64, 2048, device=self.device),
torch.randn(64, 2048, device=self.device),
),
)
list_example_inputs.append(
(
torch.randn(211, 2048, device=self.device),
torch.randn(211, 2048, device=self.device),
),
)
self.check_model_with_multiple_inputs(
Model(), list_example_inputs, dynamic_shapes=dynamic_shapes
)
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
def test_addmm_multiple_dynamic(self):
if self.device == "cpu":
raise unittest.SkipTest("using triton backend only is not supported on CPU")
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M = 8
N = 6
K = 16
model = Model(N, K, self.device)
batch = 2
a = torch.randn(batch, M, K, device=self.device)
dim0_a = Dim("dim0_a", min=1, max=2048)
dynamic_shapes = {"a": {0: dim0_a}}
list_example_inputs = [(a,)]
batch = 2048
list_example_inputs.append(
(torch.randn(batch, M, K, device=self.device),),
)
batch = 128
list_example_inputs.append(
(torch.randn(batch, M, K, device=self.device),),
)
self.check_model_with_multiple_inputs(
model,
list_example_inputs,
dynamic_shapes=dynamic_shapes,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
)
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
def test_bmm_multiple_dynamic(self):
if self.device == "cpu":
raise unittest.SkipTest("using triton backend only is not supported on CPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, a, b):
return torch.bmm(a, b)
M = 8
N = 6
K = 16
model = Model()
batch = 1024
a = torch.randn(batch, M, K, device=self.device)
b = torch.randn(batch, K, N, device=self.device)
dim0_a = Dim("dim0_a", min=1, max=2048)
dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_a}}
list_example_inputs = [(a, b)]
batch = 2048
list_example_inputs.append(
(
torch.randn(batch, M, K, device=self.device),
torch.randn(batch, K, N, device=self.device),
),
)
batch = 128
list_example_inputs.append(
(
torch.randn(batch, M, K, device=self.device),
torch.randn(batch, K, N, device=self.device),
),
)
self.check_model_with_multiple_inputs(
model,
list_example_inputs,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
dynamic_shapes=dynamic_shapes,
)
@skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation")
def test_foreach_multiple_dynamic(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
x_unsqueeze = torch.unsqueeze(x, dim=0)
y_unsqueeze = torch.unsqueeze(y, dim=0)
cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
return cat
model = Model()
x = torch.randn(128, 2048, device=self.device)
y = torch.randn(128, 2048, device=self.device)
dim0_x = Dim("dim0_x", min=1, max=2048)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
list_example_inputs = [(x, y)]
list_example_inputs.append(
(
torch.randn(64, 2048, device=self.device),
torch.randn(64, 2048, device=self.device),
),
)
list_example_inputs.append(
(
torch.randn(211, 2048, device=self.device),
torch.randn(211, 2048, device=self.device),
),
)
self.check_model_with_multiple_inputs(
model,
list_example_inputs,
dynamic_shapes=dynamic_shapes,
)
# scaled_dot_product_flash_attention
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
def test_sdpa(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, q, k, v):
return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]
example_inputs = (
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
)
self.check_model(Model(), example_inputs)
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
@unittest.skipIf(
# for archs where this isn't lowered to flash attention, the math
# backend will be used and it doesn't work for bfloat16
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Some archs don't support SDPA with bfloat16",
)
def test_sdpa_2(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, q, k, v, x):
t = torch.nn.functional.scaled_dot_product_attention(
q, k, v, is_causal=True
)[0]
return x + t
example_inputs = (
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
)
self.check_model(Model(), example_inputs)
@skipIfNoFBGEMM
def test_quantized_linear(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device)
self.bias = torch.randn(10, device=device)
def forward(self, x):
return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
x, self.weight, self.bias
)
example_inputs = (torch.randn(10, 10, device=self.device),)
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
self.check_model(Model(self.device), example_inputs)
@skipIfNoFBGEMM
def test_quantized_linear_bias_none(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device)
def forward(self, x):
return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
x, self.weight, None
)
example_inputs = (torch.randn(10, 10, device=self.device),)
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
self.check_model(Model(self.device), example_inputs)
@skipIfNoFBGEMM
def test_quanatized_int8_linear(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device)
self.bias = torch.randn(10, device=device)
self.input_scale = torch.tensor(0.1)
self.input_zero_point = torch.tensor(0)
self.weight_scale = torch.tensor(0.1)
self.weight_zero_point = torch.tensor(0)
self.output_scale = torch.tensor(0.1)
self.output_zero_point = torch.tensor(0)
self.out_channel = 10
def forward(self, x):
return torch.ops._quantized.wrapped_quantized_linear(
x,
self.input_scale,
self.input_zero_point,
self.weight,
self.weight_scale,
self.weight_zero_point,
self.bias,
self.output_scale,
self.output_zero_point,
self.out_channel,
)
example_inputs = (torch.randn(10, 10, device=self.device),)
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
self.check_model(Model(self.device), example_inputs)
def test_zero_grid_with_unbacked_symbols(self):
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
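                # nonzero() has a data-dependent output size, so the kernels below are
                # launched with grids built from an unbacked symint that may be zero.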
nz = torch.nonzero(x)
b = torch.ones_like(nz, dtype=torch.float16)
c = torch.zeros_like(nz, dtype=torch.float16)
d = (b + c) @ y
return d.sum()
example_inputs = (
torch.tensor([1, 1, 1], device=self.device),
torch.randn((1, 32), dtype=torch.float16, device=self.device),
)
self.check_model(Repro(), example_inputs)
@skipIfMPS
@config.patch({"unbacked_symint_fallback": 12})
@parametrize("shift_k", [0, 1, 2, 3])
@parametrize("use_static_size", [True, False])
def test_unbacked_expr_replacements(self, shift_k, use_static_size):
"""
Test parameters
- shift_k: Validates that torch._check assertion order doesn't affect
results by shifting the order of torch._checks
- use_static_size: Tests torch._check compatibility between unbacked
symbolic expressions and static shapes
"""
if self.device != GPU_TYPE:
raise unittest.SkipTest("Need triton for user-defined triton kernel")
def realize_out_tensor_with_size(size):
STATIC_DIM = 256 # large enough to hit IMA w/o compute-sanitizer
tensor = torch.ones((size, STATIC_DIM), device=self.device)
# Realize the tensor as an intermediate buffer
nrows, ncols = tensor.shape
numel = tensor.numel()
add_kernel[nrows,](
in_ptr0=tensor,
in_ptr1=tensor,
out_ptr=tensor,
n_elements=numel,
BLOCK_SIZE=ncols,
)
return tensor
class Repro(torch.nn.Module):
def forward(self, x, y, lst):
STATIC_SIZE = 300
s0, s1 = x.shape
s2, s3 = y.shape
u0, u1, u2, u3, u100 = lst.tolist()
expr1 = s0 + u0
expr2 = s1 + u1
expr3 = (s2 * s3) + (u2 // u3) # make this one a lil complicated
expr4 = STATIC_SIZE if use_static_size else u100
t1 = realize_out_tensor_with_size(expr1)
t2 = realize_out_tensor_with_size(expr2)
t3 = realize_out_tensor_with_size(expr3)
t4 = realize_out_tensor_with_size(expr4)
# shift tensors to change up the torch._check order
tensors = [t1, t2, t3, t4]
shifted_tensors = tensors[shift_k:] + tensors[:shift_k]
# torch.cat implicitly runs torch._check(lhs == rhs)
cat = torch.cat(shifted_tensors, dim=1)
return cat * cat
        # Disable the CUDA caching allocator so out-of-bounds accesses surface as
        # illegal memory accesses (IMA) instead of hiding in cached slack memory.
torch.cuda.caching_allocator_enable(False)
model = Repro()
example_inputs = (
# s0, s1
torch.randn((100, 200), device=self.device),
# s2, s3
torch.randn((100, 3), device=self.device),
# u0, u1, u2, u3, u100
torch.tensor([200, 100, 0, 1, 300], device=self.device, dtype=torch.int),
)
spec = {
"x": (Dim.DYNAMIC, Dim.DYNAMIC),
"y": (Dim.DYNAMIC, Dim.DYNAMIC),
"lst": (Dim.STATIC,),
}
self.check_model(model, example_inputs, dynamic_shapes=spec)
torch.cuda.caching_allocator_enable(True)
@skipIfMPS
@config.patch({"unbacked_symint_fallback": 12})
@config.patch({"triton.autotune_at_compile_time": None})
def test_replace_unbacked_symbol_with_backed_expr(self):
        # This tests how autotune_at_compile_time generates sample inputs when the
        # user calls torch._check(s0 + s1 == u0). We may fail with an IMA (illegal
        # memory access) if the generated input sizes aren't correct.
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires triton")
def force_realize(tensor):
# Realize the tensor as an intermediate buffer
nrows, ncols = tensor.shape
numel = tensor.numel()
add_kernel[nrows,](
in_ptr0=tensor,
in_ptr1=tensor,
out_ptr=tensor,
n_elements=numel,
BLOCK_SIZE=ncols,
)
INNER_DIM = 256
class Repro(torch.nn.Module):
def forward(self, x, y, lengths):
# Realize an intermediate buffer with backed shape: s0 + s1
relevant_embeddings = torch.cat([x, y], dim=0)
force_realize(relevant_embeddings)
# Realize an intermediate buffer with unbacked shape: u0
num_relevant_embeddings = lengths.nonzero().size(0)
ones = torch.ones((num_relevant_embeddings, INNER_DIM), device=x.device)
force_realize(ones)
# Add deferred runtime assertion: s0 + s1 == u0
torch._check(relevant_embeddings.size(0) == ones.size(0))
relevant_embeddings += ones
return relevant_embeddings * relevant_embeddings
torch.cuda.caching_allocator_enable(False)
model = Repro()
example_inputs = (
torch.randn((1000, INNER_DIM), device=self.device),
torch.randn((2000, INNER_DIM), device=self.device),
torch.ones(3000),
)
spec = {
"x": (Dim.DYNAMIC, Dim.STATIC),
"y": (Dim.DYNAMIC, Dim.STATIC),
"lengths": (Dim.DYNAMIC,),
}
self.check_model(model, example_inputs, dynamic_shapes=spec)
torch.cuda.caching_allocator_enable(True)
@config.patch({"triton.autotune_at_compile_time": None})
def test_stride_with_unbacked_expr(self):
class Repro(torch.nn.Module):
def forward(self, x, y):
u0 = x.item()
torch._check(u0 >= 1)
s0 = y.size(0)
expr = u0 * s0
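                # Both the size and the strides of the allocation below are
                # expressions in the unbacked symint u0.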
sevens = torch.empty_strided(
size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device
).fill_(7)
return sevens * 3
example_inputs = (
torch.scalar_tensor(2, dtype=torch.int, device=self.device),
torch.ones(8, device=self.device),
)
self.check_model(Repro(), example_inputs)
@unittest.skipIf(
TEST_MPS and MACOS_VERSION < 14.0,
"bfloat16 is only supported on MacOS 14+",
)
def test_size_with_unbacked_add_expr(self):
# Tests AOTI autotuning to make sure the correct input tensor sizes
# are generated for sizes that include an expr such as s0 + u0.
class Repro(torch.nn.Module):
def forward(self, values, repeats, mask, embeddings, x, z, scalar):
repeat_interleave = torch.repeat_interleave(values, repeats)
index = torch.clamp(repeat_interleave, min=0, max=400).int()
index_select = torch.index_select(embeddings, 0, index)
backed = z.size(0)
unbacked = scalar.item()
unbacked_add_expr = backed + unbacked
repeated = x.repeat(unbacked_add_expr, 1)
return torch.cat([repeated, index_select], dim=1)
example_inputs = (
torch.ones(64, dtype=torch.int64, device=self.device),
torch.ones(64, dtype=torch.int64, device=self.device) * 12,
torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
torch.randn((1, 256), dtype=torch.bfloat16, device=self.device),
torch.ones(758, 127, dtype=torch.int64, device=self.device),
torch.scalar_tensor(10, dtype=torch.int32, device=self.device),
)
spec = {
"values": (Dim.DYNAMIC,),
"repeats": (Dim.DYNAMIC,),
"mask": (Dim.DYNAMIC,),
"embeddings": (Dim.DYNAMIC, Dim.STATIC),
"x": (Dim.STATIC, Dim.STATIC),
"z": (Dim.DYNAMIC, Dim.STATIC),
"scalar": (),
}
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
@skipIfWindowsXPU(msg="crash on Windows XPU.")
def test_size_with_unbacked_add_expr_transitive(self):
        # Edge case: torch._check(expr1 == expr2) + torch._check(expr2 == unbacked).
# When generating example input sizes for autotuning, it should coalesce
# expr1, expr2, unbacked into a single size.
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Repro(torch.nn.Module):
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
index = torch.repeat_interleave(values, repeats)
index_select = torch.index_select(embeddings, 0, index)
u0, u1 = lst.tolist()
backed0, backed1 = z.size(0), z.size(1)
repeated0 = y.repeat(backed0 + u0, 1)
repeated1 = x.repeat(backed1 + u1, 1)
out1 = torch.empty_like(repeated1)
add_kernel[(out1.numel(),)](
repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
)
                # Implicitly add torch._check(expr2 == unbacked)
cat = torch.cat([out1, index_select], dim=1)
add = repeated0 + repeated1
                # Explicitly add torch._check(expr1 == expr2)
torch._check(repeated0.size(0) == out1.size(0))
return cat, add
example_inputs = (
torch.ones(64, dtype=torch.int64, device=self.device),
torch.ones(64, dtype=torch.int64, device=self.device) * 24,
torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
torch.ones(758, 758, dtype=torch.int64, device=self.device),
torch.tensor([10, 10], dtype=torch.int32, device=self.device),
)
spec = {
"values": (Dim.DYNAMIC,),
"repeats": (Dim.DYNAMIC,),
"mask": (Dim.DYNAMIC,),
"embeddings": (Dim.DYNAMIC, Dim.STATIC),
"x": (Dim.DYNAMIC, Dim.STATIC),
"y": (Dim.DYNAMIC, Dim.STATIC),
"z": (Dim.DYNAMIC, Dim.DYNAMIC),
"lst": (Dim.STATIC,),
}
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
@config.patch({"unbacked_symint_fallback": 128})
def test_size_with_unbacked_add_and_mul_expr(self):
        # Edge case with torch._check(add_expr == mul_expr). When generating example
# input sizes for autotuning, make sure they coalesce into a single size.
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Repro(torch.nn.Module):
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
u0, u1, u2 = lst.tolist()
backed = z.size(0)
backed1 = z.size(1)
unbacked_add_expr = backed + u0
unbacked_mul_expr = backed1 + (u1 * u2)
repeated0 = x.repeat(unbacked_add_expr, 1)
repeated1 = y.repeat(unbacked_mul_expr, 1)
out0 = torch.empty_like(repeated0)
out1 = torch.empty_like(repeated1)
add_kernel[(out0.numel(),)](
repeated0, repeated0, out0, out0.numel(), BLOCK_SIZE=2
)
add_kernel[(out1.numel(),)](
repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
)
return torch.cat([out1, out0], dim=1)
example_inputs = (
torch.ones(64, dtype=torch.int64, device=self.device),
torch.ones(64, dtype=torch.int64, device=self.device) * 24,
torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
torch.ones(758, 758, dtype=torch.int64, device=self.device),
torch.tensor([10, 5, 2], dtype=torch.int32, device=self.device),
)
spec = {
"values": (Dim.DYNAMIC,),
"repeats": (Dim.DYNAMIC,),
"mask": (Dim.DYNAMIC,),
"embeddings": (Dim.DYNAMIC, Dim.STATIC),
"x": (Dim.DYNAMIC, Dim.STATIC),
"y": (Dim.DYNAMIC, Dim.STATIC),
"z": (Dim.DYNAMIC, Dim.DYNAMIC),
"lst": (Dim.STATIC,),
}
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
)
def test_fallback_kernel_with_symexpr_output(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Module(torch.nn.Module):
def forward(self, q, k, v):
q = q.reshape(
q.shape[0],
2,
q.shape[2] * q.shape[3],
q.shape[1] // 2,
)
k = k.reshape(
k.shape[0],
2,
k.shape[2] * k.shape[3],
k.shape[1] // 2,
)
v = v.reshape(
v.shape[0],
2,
v.shape[2] * v.shape[3],
v.shape[1] // 2,
)
res = torch.ops.aten._scaled_dot_product_flash_attention.default(
q,
k,
v,
)
return res[0]
m = Module().to(device=self.device)
tensor_shape = (4, 32, 4, 4)
inputs = (
torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
)
dynamic_shapes = {
"q": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
"k": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
"v": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
}
ep = torch.export.export(m, inputs, dynamic_shapes=dynamic_shapes, strict=False)
path = torch._inductor.aot_compile(ep.module(), inputs)
aot_model = torch._export.aot_load(path, device=self.device)
torch.testing.assert_close(m(*inputs), aot_model(*inputs))
def test_aoti_constant_tensor(self):
class Foo(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.a = torch.ones(4, 4, device=device)
self.b = torch.ones(4, 4, device=device)
def forward(self, x):
return torch.ops.aten.linear.default(x, self.a, self.b)
example_inputs = (torch.ones(4, 4, device=self.device),)
self.check_model(Foo(self.device), example_inputs)
def test_aoti_constant_tensor_name_collision(self):
class SubModule(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.register_buffer(
"_tensor_constant1",
torch.ones(1, device=device, dtype=torch.float32),
persistent=True,
)
def forward(self, x):
return self.linear(x)
class Foo(torch.nn.Module):
def __init__(self, user_float_feature_idx, device):
super().__init__()
self.user_float_feature_idx = user_float_feature_idx
self.register_buffer(
"_tensor_constant0",
torch.ones(5, device=device, dtype=torch.float32),
persistent=True,
)
self.register_buffer(
"_tensor_constant1",
torch.ones(1, device=device, dtype=torch.float32),
persistent=True,
)
self.sub_mod = SubModule(device)
def forward(self, x):
self._tensor_constant0[1:2] = 1
return (
torch.index_select(
x, 1, torch.tensor(self.user_float_feature_idx, device=x.device)
),
self._tensor_constant0,
self._tensor_constant1,
self.sub_mod._tensor_constant1,
)
example_inputs = (torch.ones(4, 4, device=self.device),)
user_float_feature_idx = [1]
        # We have to call run_decompositions() first to trigger the name collision
        # between Foo._tensor_constant1 and SubModule._tensor_constant1.
ep = torch.export.export(
Foo(user_float_feature_idx, self.device), example_inputs, strict=False
).run_decompositions()
gm = ep.module()
self.check_model(gm.to(self.device), example_inputs)
def test_large_grid(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, primals_5):
view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
primals_5 = None
permute = torch.ops.aten.permute.default(view, [0, 2, 1])
clone = torch.ops.aten.clone.default(
permute, memory_format=torch.contiguous_format
)
return clone
        # s0 = 65537 * 256, so the generated kernel launches with y_grid = 65537,
        # just above the 65535 per-dimension grid limit.
s0 = 16777472
s1 = 8
example_inputs = (torch.rand(s0, s1, device=self.device),)
self.check_model(Model(), example_inputs)
def test_cond_simple(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"p": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.Simple(),
prepend_predicates(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_cond_nested(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_abc = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"p0": {},
"p1": {},
"p2": {},
"a": {0: dim0_abc, 1: None},
"b": {0: dim0_abc, 1: None},
"c": {0: dim0_abc, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.Nested(),
prepend_predicates(inputs, num_predicates=3),
dynamic_shapes=dynamic_shapes,
)
def test_cond_with_parameters(self):
inputs = (torch.randn((10, 20), device=self.device),)
dim0_abc = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"p": {},
"a": {0: dim0_abc, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.Parameters(self.device),
prepend_predicates(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_cond_with_reinterpret_view_inputs_outputs(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
        # TODO: the min value needs to be 5 because the body_fn slices over z1[2:];
        # since the output size is [dim0_ab - 3], when we extract tensor metadata out
        # of the output we call guard_size_oblivious, which assumes dim0_ab - 3 is
        # neither 0 nor 1. So we have to set the minimum to 5 for now. We should relax
        # this restriction, e.g. by writing less constrained shape checking in the
        # fake impl of cond.
dim0_ab = Dim("s0", min=5, max=1024)
dynamic_shapes = {
"p": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.ReinterpretView(),
prepend_predicates(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_cond_with_multiple_outputs(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
torch.randn((30, 40), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dim0_c = Dim("s1", min=2, max=1024)
dynamic_shapes = {
"p": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
"c": {0: dim0_c, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.MultipleOutputs(),
prepend_predicates(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_cond_with_outer_code_before_after(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"p": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.OuterCode(),
prepend_predicates(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_cond_use_buffers_from_outer_scope(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_abc = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"p": {},
"a": {0: dim0_abc, 1: None},
"b": {0: dim0_abc, 1: None},
"c": {0: dim0_abc, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.OuterBuffers(),
prepend_predicates(inputs),
dynamic_shapes=dynamic_shapes,
)
@common_utils.parametrize("dynamic", [False, True])
def test_cond_non_tensor_predicates(self, dynamic):
inputs1 = (
torch.randn((10, 20), device=self.device),
torch.randn((15, 20), device=self.device),
)
inputs2 = (
torch.randn((10, 20), device=self.device),
torch.randn((5, 20), device=self.device),
)
inputs = (inputs1,)
dynamic_shapes = None
if dynamic:
inputs = (inputs1, inputs2)
dim0_a = Dim("s0", min=2, max=1024)
dim0_b = Dim("s1", min=2, max=1024)
dynamic_shapes = {
"a": {0: dim0_a, 1: None},
"b": {0: dim0_b, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.WithNonTensorPredicate(),
inputs,
dynamic_shapes=dynamic_shapes,
)
@common_utils.parametrize("dynamic", [False, True])
def test_cond_unbacked_symint_closure(self, dynamic):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((15, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dynamic_shapes = None
if dynamic:
dim0_a = Dim("s0", min=2, max=1024)
dim0_b = Dim("s1", min=2, max=1024)
dynamic_shapes = {
"p": {},
"x": {0: dim0_a, 1: None},
"y": {0: dim0_b, 1: None},
"z": {0: dim0_a, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.UnbackedSymIntClosure(),
prepend_predicates(inputs),
dynamic_shapes=dynamic_shapes,
)
@skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation")
@common_utils.parametrize("dynamic", [False, True])
def test_cond_mismatched_branch_output(self, dynamic):
inputs = (
torch.randn(10, 20, device=self.device),
torch.randn(10, 20, device=self.device),
torch.randn(10, 20, device=self.device),
)
dynamic_shapes = None
if dynamic:
            # Note: the minimum has to be 4 because the model slices the first dim
            # with [2:]; if the first dim is 2 or 3, the resulting size would be
            # 0/1-specialized, causing a constraint violation error.
dim0_a = Dim("s0", min=4, max=1024)
dim0_b = Dim("s1", min=4, max=1024)
dynamic_shapes = {
"p": {},
"x": {0: dim0_a, 1: None},
"y": {0: dim0_b, 1: None},
"z": {0: dim0_a, 1: None},
}
self.check_model_with_multiple_inputs(
CondModels.MismatchedOutputSize(),
prepend_predicates(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_cond_symint_input(self):
class M(torch.nn.Module):
def forward(self, x, y, z):
a = y.shape[0]
b = z.shape[0]
def true_fn(x):
return x + a
def false_fn(x):
return x + b * z
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
input1 = (
torch.ones(3, 3, device=self.device),
torch.ones(5, device=self.device),
torch.ones(3, 3, device=self.device),
)
input2 = (
torch.ones(10, 3, device=self.device),
torch.ones(6, device=self.device),
torch.ones(10, 3, device=self.device),
)
inputs = (input1, input2)
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
self.check_model_with_multiple_inputs(
M(),
inputs,
dynamic_shapes=dynamic_shapes,
)
def test_cond_symint_input_disable_one_pass(self):
class M(torch.nn.Module):
def forward(self, x, y, z):
a = y.shape[0]
b = z.shape[0]
def true_fn(x):
return x + a
def false_fn(x):
return x + b * z
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
input1 = (
torch.ones(3, 3, device=self.device),
torch.ones(5, device=self.device),
torch.ones(3, 3, device=self.device),
)
input2 = (
torch.ones(10, 3, device=self.device),
torch.ones(6, device=self.device),
torch.ones(10, 3, device=self.device),
)
inputs = (input1, input2)
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
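        # Same model as test_cond_symint_input, but with compile-time Triton
        # autotuning disabled so the run-time autotuning / grid-computation path
        # of the generated code is exercised.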
with torch._inductor.config.patch({"triton.autotune_at_compile_time": False}):
self.check_model_with_multiple_inputs(
M(),
inputs,
dynamic_shapes=dynamic_shapes,
)
def test_while_loop_simple(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"ci": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.Simple(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_while_loop_nested(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"ci": {},
"cj": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.Nested(),
prepend_counters(inputs, num_counters=2),
dynamic_shapes=dynamic_shapes,
)
def test_while_loop_with_outer_code(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"c": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.OuterCode(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)
# mps doesn't support float64
@skipIfMPS
@unittest.skipIf(
config.triton.native_matmul,
"FIXME: cannot do get_size on FakeTensor during lowering.",
)
def test_while_loop_with_parameters(self):
inputs = (
torch.randn(
(
10,
20,
),
dtype=torch.float64,
device=self.device,
),
)
dim0_a = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"c": {},
"a": {0: dim0_a, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.Parameters(self.device),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_while_loop_with_outer_buffers(self):
inputs = (
torch.randn((10, 20), device=self.device),
torch.randn((10, 20), device=self.device),
)
# dynamic shapes don't work now due to
# https://github.com/pytorch/pytorch/issues/123596
# dim0_ab = Dim("s0", min=2, max=1024)
# dynamic_shapes = {
# "c": {},
# "a": {0: dim0_ab, 1: None},
# "b": {0: dim0_ab, 1: None},
# }
dynamic_shapes = None
self.check_model_with_multiple_inputs(
WhileLoopModels.OuterBuffers(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)
def test_while_loop_with_pytree_inputs(self):
inputs = (
torch.tensor(0, device=self.device),
(
[torch.randn(10, 20, device=self.device)],
{
"x": torch.randn(10, 20, device=self.device),
"y": torch.randn(10, 20, device=self.device),
},
),
)
self.check_model_with_multiple_inputs(
WhileLoopModels.PytreeCarry(),
[inputs],
dynamic_shapes=None,
)
@common_utils.parametrize("dynamic", [False, True])
def test_while_loop_with_unbacked_symint_closure(self, dynamic):
inputs = (
torch.randn(10, 20, device=self.device),
torch.randn(10, 20, device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = None
if dynamic:
dynamic_shapes = {
"c": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.UnbackedSymIntClosure(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)
@common_utils.parametrize("dynamic", [False, True])
def test_while_loop_with_mixed_device(self, dynamic):
inputs = (
torch.randn(10, 20, device=self.device),
torch.randn(10, 20, device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = None
if dynamic:
dynamic_shapes = {
"c": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.MixedDevice(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)
@common_utils.parametrize("dynamic", [False, True])
def test_while_loop_with_sym_expr_cond(self, dynamic):
inputs = (
torch.randn(10, 20, device=self.device),
torch.randn(10, 20, device=self.device),
)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = None
if dynamic:
dynamic_shapes = {
"c": {},
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.SymExprCond(),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)
@common_utils.parametrize("dynamic", [False, True])
def test_while_loop_with_conv(self, dynamic):
inputs = (torch.randn(2, 4, 4, 4, device=self.device, dtype=torch.float64),)
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = None
if dynamic:
dynamic_shapes = {
"c": {},
"x": {0: dim0_ab, 1: None},
}
self.check_model_with_multiple_inputs(
WhileLoopModels.Conv(self.device),
prepend_counters(inputs),
dynamic_shapes=dynamic_shapes,
)
@config.patch({"is_predispatch": True})
def test_constant(self):
class M(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.device = device
def forward(self, x):
t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
t = torch.sqrt(t * 3)
return x * t
self.check_model(M(self.device), (torch.randn(5, 5, device=self.device),))
@unittest.skipIf(IS_MACOS, "no CUDA on Mac")
def test_zero_grid_with_backed_symbols(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, b):
return x + b
example_inputs = (
torch.randn((3, 2), device=self.device),
torch.randn((1, 2), device=self.device),
)
dynamic_shapes = {
"x": {0: Dim("dx"), 1: Dim.STATIC},
"b": None,
}
# Compile & run model where dynamic dim size > 0.
package_path: str = AOTIRunnerUtil.compile(
Repro(),
example_inputs,
dynamic_shapes=dynamic_shapes,
)
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
aot_inductor_module(*example_inputs)
# Re-run where dynamic dim size is 0.
example_inputs = (
torch.randn((0, 2), device=self.device),
torch.randn((1, 2), device=self.device),
)
actual = aot_inductor_module(*example_inputs)
expected = Repro()(*example_inputs)
torch.testing.assert_close(actual, expected)
def test_repeat_interleave(self):
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.ops.aten.repeat_interleave.Tensor(x, output_size=12)
example_inputs = (torch.ones((1,), dtype=torch.int32, device=self.device) * 12,)
self.check_model(Repro(), example_inputs)
def test_dynamic_cat(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, a, b):
return torch.cat([a, b], dim=0)
a = torch.randn(2, 4, device=self.device)
b = torch.randn(3, 4, device=self.device)
dim0_a = Dim("dim0_a", min=1, max=10)
dim0_b = Dim("dim0_b", min=1, max=20)
dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
example_inputs = (a, b)
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
def test_buffer_mutation_1(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.foo = torch.nn.Buffer(torch.randn(4, 4, device=device))
def forward(self, x):
self.foo.add_(1)
return self.foo + x
example_inputs = (torch.rand(4, 4, device=self.device),)
self.check_model(Model(self.device), example_inputs)
def test_non_tensor_input(self):
class Model(torch.nn.Module):
def forward(self, a, b, alpha=1.0):
return torch.add(a, b, alpha=alpha)
a = torch.randn(10, device=self.device)
b = torch.randn(10, device=self.device)
for simdlen in [0, None]:
with torch._inductor.config.patch({"cpp.simdlen": simdlen}):
so_path = torch._export.aot_compile(
torch.ops.aten.add,
args=(a, b),
kwargs={"alpha": 2.0},
)
kernel_runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
res = kernel_runner.run([a, b])
self.assertTrue(isinstance(res, list))
self.assertTrue(len(res) == 1)
self.assertEqual(Model()(a, b, alpha=2.0), res[0])
def test_buffer_mutation_2(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.foo = torch.nn.Buffer(torch.arange(10, device=device))
self.bar = torch.nn.Buffer(torch.arange(10, device=device))
def forward(self, x):
self.bar.mul_(2)
self.foo[5] = self.bar[0]
return x + self.bar, x * self.foo
example_inputs = (torch.randn(10, device=self.device),)
self.check_model(Model(self.device), example_inputs)
@skipIfWindows(
msg="OpenMP crashed application on windows"
) # TODO: (xuhancn) need to root cause and fix.
def test_buffer_mutation_3(self):
class KVCache(torch.nn.Module):
def __init__(
self,
max_batch_size,
max_seq_length,
n_heads,
head_dim,
dtype=torch.float,
):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.k_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))
self.v_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
return k_out, v_out
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.kv_cache = KVCache(1, 256, 6, 48)
def forward(self, inp_pos, k, v):
self.kv_cache.update(inp_pos, k, v)
return self.kv_cache.k_cache + 1, self.kv_cache.v_cache / 2
example_inputs = (
torch.tensor([0], device=self.device),
torch.randn(1, 6, 1, 48, device=self.device),
torch.randn(1, 6, 1, 48, device=self.device),
)
model = Model(self.device)
self.check_model(model, example_inputs)
self.code_check_count(model, example_inputs, "empty_strided", 2)
def test_buffer_mutation_4(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"_tensor_constant0",
torch.randint(1, size=[38], dtype=torch.int64, device="cpu"),
)
def forward(self, x):
return x + self._tensor_constant0.to(
torch.device(type=GPU_TYPE, index=0)
)
example_inputs = (
torch.randint(1, size=[38], dtype=torch.int64, device=GPU_TYPE),
)
torch._export.aot_compile(Model(), example_inputs)
@skipCUDAIf(True, "Test for x86 backend")
@unittest.skipIf(IS_FBCODE, "Need newer ideep")
def test_buffer_mutation_and_force_mmap_weights(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(16, 15)
self.linear2 = torch.nn.Linear(15, 14)
def forward(self, x):
x = self.linear1(x)
out = self.linear2(x)
return out
example_inputs = (torch.randn(32, 16),)
model = Model().eval()
with (
config.patch({"freezing": True, "aot_inductor.force_mmap_weights": True}),
torch.no_grad(),
):
exported_model = export(model, example_inputs, strict=True).module()
quantizer = X86InductorQuantizer()
quantizer.set_global(
xiq.get_default_x86_inductor_quantization_config(reduce_range=True)
)
prepared_model = prepare_pt2e(exported_model, quantizer)
prepared_model(*example_inputs)
converted_model = convert_pt2e(prepared_model)
torch.ao.quantization.move_exported_model_to_eval(converted_model)
self.check_model(converted_model, example_inputs)
@skipIfMPS
def test_fallback_mem_leak_fix(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y, idx):
tmp = x + y
w = torch.ops.aten.as_strided(tmp, x.shape, x.stride())
out = torch.ops.aten.index.Tensor(w, [idx])
return w, out
example_inputs = (
torch.randn(4, 1, 4, device=GPU_TYPE),
torch.randn(4, 1, 4, device=GPU_TYPE),
torch.randn(4, device=GPU_TYPE) > 0,
)
dim0 = Dim("dim0", min=1, max=2048)
dynamic_shapes = {
"x": {0: dim0},
"y": {0: dim0},
"idx": {0: dim0},
}
package_path: str = AOTIRunnerUtil.compile(
Model(),
example_inputs,
dynamic_shapes=dynamic_shapes,
)
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
device_interface = get_interface_for_device(GPU_TYPE)
device: int = device_interface.current_device()
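        # Call once without keeping the outputs; if the fallback index op leaked the
        # buffers it allocates, memory_allocated would be higher after the call.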
mem_before = device_interface.memory_allocated(device)
aot_inductor_module(*example_inputs)
mem_after = device_interface.memory_allocated(device)
self.assertEqual(mem_before, mem_after)
actual = aot_inductor_module(*example_inputs)
expected = Model()(*example_inputs)
torch.testing.assert_close(actual, expected)
@requires_multigpu()
@skipIfMPS
def test_replicate_on_devices(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, w1, w2):
super().__init__()
self.w1 = w1
self.w2 = w2
def forward(self, x, y):
a = x * self.w1
b = y * self.w2
return a + b
w1 = torch.randn(10, 10)
w2 = torch.randn(10, 10)
inputs = (torch.randn(10, 10), torch.randn(10, 10))
result_cpu = Model(w1, w2)(*inputs)
# Compile model with AOTInductor
device_interface = get_interface_for_device(GPU_TYPE)
with device_interface.device(0):
package_path = AOTIRunnerUtil.compile(
model=Model(
w1.to(torch.device(GPU_TYPE, 0)), w2.to(torch.device(GPU_TYPE, 0))
),
example_inputs=tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
)
# Run model on gpu:N
for i in range(device_interface.device_count()):
with device_interface.device(i):
example_inputs = tuple(t.to(torch.device(GPU_TYPE, i)) for t in inputs)
optimized = torch._inductor.aoti_load_package(package_path)
result_gpu = optimized(*example_inputs)
self.assertTrue(same(result_cpu, result_gpu.cpu()))
@requires_multigpu()
@skipIfMPS
def test_on_gpu_device1(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
device_interface = get_interface_for_device(GPU_TYPE)
try:
device_interface.get_device_properties(1)
except AssertionError:
raise unittest.SkipTest("GPU device 1 is not available") from None
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(16, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
device = f"{GPU_TYPE}:1"
model = Model().to(device)
example_inputs = (torch.randn(8, 10, device=device),)
expected = model(*example_inputs)
so_path = AOTIRunnerUtil.legacy_compile(model, example_inputs)
optimized = AOTIRunnerUtil.legacy_load(device, so_path)
actual = optimized(*example_inputs)
torch.testing.assert_close(actual, expected)
def test_pytree_inputs(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: dict[str, torch.Tensor]):
device = next(iter(x.values())).device
add_ = torch.zeros(5, device=device)
mul_ = torch.ones(5, device=device)
for v in x.values():
add_ += v
mul_ *= v
return [add_, mul_]
self.check_model(
M(),
(
{
"x": torch.ones(5, device=self.device),
"y": torch.ones(5, device=self.device),
},
),
)
@requires_multigpu()
def test_non_default_gpu_device(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight
def forward(self, x, y):
return x + torch.nn.functional.linear(y, self.weight)
weight = torch.randn(10, 10)
inputs = (torch.randn(10, 10), torch.randn(10, 10))
result_cpu = Model(weight)(*inputs)
device_interface = get_interface_for_device(GPU_TYPE)
with device_interface.device(0), torch.no_grad():
result_gpu_0 = AOTIRunnerUtil.run(
Model(weight.to(torch.device(GPU_TYPE, 0))),
tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
)
with device_interface.device(1), torch.no_grad():
result_gpu_1 = AOTIRunnerUtil.run(
Model(weight.to(torch.device(GPU_TYPE, 1))),
tuple(t.to(torch.device(GPU_TYPE, 1)) for t in inputs),
)
self.assertTrue(same(result_cpu, result_gpu_0.cpu()))
self.assertTrue(same(result_cpu, result_gpu_1.cpu()))
@requires_multigpu()
def test_load_package_multiple_gpus(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight
def forward(self, x, y):
return x + torch.nn.functional.linear(y, self.weight)
weight = torch.randn(10, 10, device=self.device)
inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
model = Model(weight).to(device=self.device)
result_ref = model(*inputs)
package_path = AOTIRunnerUtil.compile(model, inputs)
# Load AOT package on gpu:N
device_interface = get_interface_for_device(GPU_TYPE)
for i in range(device_interface.device_count()):
device = torch.device(GPU_TYPE, i)
with device_interface.device(i), torch.no_grad():
model_package = torch._inductor.aoti_load_package(
package_path, device_index=i
)
inputs_on_device = [input.to(device=device) for input in inputs]
result_package = model_package(*inputs_on_device)
self.assertTrue(same(result_ref.cpu(), result_package.cpu()))
@unittest.skipIf(
config.triton.native_matmul, "sin and mm are fused in native matmul"
)
def test_reuse_kernel(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
a = torch.sin(x)
b = torch.mm(a, y)
c = torch.sin(b)
d = torch.mm(b, c)
return d
example_inputs = (
torch.randn(87, 87, device=self.device),
torch.randn(87, 87, device=self.device),
)
model = Model()
# 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
self.check_model(model, example_inputs, atol=1e-4, rtol=1e-4)
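        # Both torch.sin calls lower to the same pointwise kernel, so the generated
        # code should load/look up that kernel exactly once (checked below).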
if self.device == "mps":
self.code_check_count(
model, example_inputs, "aoti_torch_mps_get_kernel_function(", 1
)
elif self.device == GPU_TYPE:
self.code_check_count(
model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1
)
def test_reuse_kernel_dynamic(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.cst = torch.randn(48, device=device, dtype=torch.float)
self.weights = torch.randn(6, 48, 48, device=device, dtype=torch.float)
self.cst_1 = torch.randn(48, device=device, dtype=torch.float)
self.weights_1 = torch.randn(
6, 48, 48, device=device, dtype=torch.float
)
def forward(self, x, y, z):
dim0 = x.size(1)
add_0 = z + z
expand_2 = add_0.expand(-1, -1, 48)
# [s0, 6, 48]
mul_3 = add_0 * expand_2
# [6, s0, 48]
permute_4 = torch.permute(mul_3, (1, 0, 2))
# [6, s0, 48]
bmm_5 = torch.bmm(permute_4, self.weights)
add_6 = bmm_5 + self.cst
reshape_7 = torch.reshape(add_6, [6, dim0 * 6, 8])
# [6*s0, 6, 8]
permute_8 = torch.permute(reshape_7, (1, 0, 2))
mul_9 = permute_8 * 0.123
reshape_10 = torch.reshape(y, [8, dim0 * 6, 4])
# [6*s0, 8, 4]
permute_11 = torch.permute(reshape_10, (1, 0, 2))
bmm_12 = torch.bmm(mul_9, permute_11)
add_0_1 = z + z
expand_2_1 = add_0_1.expand(-1, -1, 48)
# [s0, 6, 48]
mul_3_1 = add_0_1 * expand_2_1
# [6, s0, 48]
permute_4_1 = torch.permute(mul_3_1, (1, 0, 2))
# [6, s0, 48]
bmm_5_1 = torch.bmm(permute_4_1, self.weights_1)
add_6_1 = bmm_5_1 + self.cst_1
reshape_7_1 = torch.reshape(add_6_1, [6, dim0 * 6, 8])
# [6*s0, 6, 8]
permute_8_1 = torch.permute(reshape_7_1, (1, 0, 2))
mul_9_1 = permute_8_1 * 0.123
reshape_10_1 = torch.reshape(y, [8, dim0 * 6, 4])
# [6*s0, 8, 4]
permute_11_1 = torch.permute(reshape_10_1, (1, 0, 2))
bmm_12_1 = torch.bmm(mul_9_1, permute_11_1)
return bmm_12 + bmm_12_1
x = torch.randn(6, 2, 48, device=self.device, dtype=torch.float)
y = torch.randn(48, 2, 4, device=self.device, dtype=torch.float)
z = torch.randn(2, 6, 1, device=self.device, dtype=torch.float)
dim0 = Dim("dim0", min=1, max=2048)
dynamic_shapes = {
"x": {1: dim0},
"y": {1: dim0},
"z": {0: dim0},
}
example_inputs = (x, y, z)
model = Model(self.device).to(dtype=torch.float)
self.check_model(
model,
example_inputs,
dynamic_shapes=dynamic_shapes,
atol=1e-5,
rtol=1e-5,
)
def test_fake_tensor_device_validation(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
return x + y
example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
# Export on CPU
exported_program = export(Model(), example_inputs, strict=True)
# Compile exported model on GPU
gm = exported_program.graph_module.to(self.device)
with self.assertRaisesRegex(ValueError, "Device mismatch between fake input"):
torch._inductor.aot_compile(
gm, tuple(i.to(self.device) for i in example_inputs)
)
def test_fx_gm_return_tuple_validation(self):
from torch.fx.experimental.proxy_tensor import make_fx
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
return x + y
example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
gm = make_fx(Model(), tracing_mode="symbolic")(*example_inputs)
with self.assertRaisesRegex(
AssertionError,
r"Graph output must be a tuple\(\). This is so that we can avoid "
"pytree processing of the outputs.",
):
torch._inductor.aot_compile(gm, example_inputs)
def test_consecutive_compiles(self):
"""Test that compilation behaves correctly with cache hits"""
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x + 1
mod = TestModule()
inp = torch.rand(1)
mod(inp)
mod2 = torch.fx.symbolic_trace(mod, concrete_args=[inp])
so = torch._export.aot_compile(mod2, (inp,))
assert so is not None
# compile the 2nd time with cache hit
so = torch._export.aot_compile(mod2, (inp,))
assert so is not None
def test_normal_functional(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.ops.aten.normal_functional.default(x)
self.check_model(Model(), (torch.empty(4, 1, 4, 4, device=self.device),))
def test_empty_graph(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x
example_inputs = (torch.randn(8, 4, 4, device=self.device),)
self.check_model(Model(), example_inputs)
@patch("torch._dynamo.utils.CompileEventLogger.log_instant_event")
def test_backward_no_op_logging(self, mock_log_instant_event):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x
model = Model()
dummy_input = torch.randn(1, 5)
from torch._dynamo.utils import CompileEventLogLevel
from torch._inductor import compile_fx
graph_module = torch.fx.symbolic_trace(model)
compile_fx._compile_fx_inner(graph_module, (dummy_input,))
mock_log_instant_event.assert_called_once_with(
"backward no-op",
metadata={"compile_id": None},
log_level=CompileEventLogLevel.PT2_COMPILE,
)
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
def test_dup_unbacked_sym_decl(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
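                # index_1 and index_2 share the same data-dependent (unbacked) output
                # size; the generated code must declare that symbol only once.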
abs_1 = torch.ops.aten.abs.default(x)
lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
eq = torch.ops.aten.eq.Scalar(lt, 0)
index_1 = torch.ops.aten.index.Tensor(x, [eq])
sin = torch.ops.aten.sin.default(index_1)
index_2 = torch.ops.aten.index.Tensor(x, [eq])
div_3 = torch.ops.aten.div.Tensor(sin, index_2)
return div_3
example_inputs = (torch.randn(4, 4, 4, 4).to(self.device),)
self.check_model(Model(), example_inputs)
    # This exercises the _eliminate_unbacked path in ShapeEnv
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
def test_dup_unbacked_sym_decl_with_refinement(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
abs_1 = torch.ops.aten.abs.default(x)
lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
eq = torch.ops.aten.eq.Scalar(lt, 0)
index_1 = torch.ops.aten.index.Tensor(x, [eq])
torch._check(index_1.size(0) == 4**4)
sin = torch.ops.aten.sin.default(index_1)
index_2 = torch.ops.aten.index.Tensor(x, [eq])
div_3 = torch.ops.aten.div.Tensor(sin, index_2)
return div_3
example_inputs = (torch.ones(4, 4, 4, 4).to(self.device),)
self.check_model(Model(), example_inputs)
def test_run_with_grad_enabled(self):
class Model(torch.nn.Module):
def forward(self, x, weight, bias):
return torch.ops.aten.addmm(bias, weight, x)
m = Model().to(device=self.device)
x = torch.rand(8, 8, device=self.device, requires_grad=True)
weight = torch.rand(8, 8, device=self.device, requires_grad=True)
bias = torch.rand(8, device=self.device, requires_grad=True)
example_inputs = (x, weight, bias)
expected = m(*example_inputs)
expected = pytree.tree_leaves(expected)
        # compile under no_grad
with torch.no_grad():
package_path = AOTIRunnerUtil.compile(m, example_inputs)
# run under grad enabled
self.assertTrue(torch.is_grad_enabled())
optimized = torch._inductor.aoti_load_package(package_path)
actual = optimized(*example_inputs)
actual = pytree.tree_leaves(actual)
self.assertTrue(same(actual, expected))
def test_return_constant(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.cst = torch.randn(5, 5, device=device)
def forward(self, x):
a = self.cst.clone()
return (x, a)
x = torch.randn(5, device=self.device)
self.check_model(Model(self.device), (x,))
def test_return_view_constant(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.cst = torch.randn(5, 5, device=device)
def forward(self, x):
a = torch.transpose(self.cst, 0, 1)
return (x, a)
x = torch.randn(5, device=self.device)
self.check_model(Model(self.device), (x,))
def test_profile_benchmark_harness(self):
batch_size = 32
seq_length = 50
hidden_size = 768
def create_test_fn():
def test_fn():
inp = torch.randn(
batch_size, seq_length, hidden_size, device=self.device
)
weight = torch.randn(hidden_size, hidden_size, device=self.device)
matmul_output = inp @ weight
torch.nn.LayerNorm(hidden_size, device=self.device)(matmul_output)
return True
return test_fn
fn = torch.compile(
options={"profile_bandwidth_output": "foo", "benchmark_harness": False}
)(create_test_fn())
fn()
def test_with_profiler(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
self.check_model(Model(), example_inputs)
def test_with_no_triton_profiler(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.permute(x, (1, 0))
example_inputs = (torch.randn(10, 10, device=self.device),)
with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
self.check_model(Model(), example_inputs)
def test_repeat_output(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
y = torch.sin(x)
return y, y
example_inputs = (torch.randn(3, 10, device=self.device),)
self.check_model(Model(), example_inputs)
def test_repeated_calling(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.sin(x)
example_inputs = (torch.randn(10, 10, device=self.device),)
optimized = torch._inductor.aoti_load_package(
torch._inductor.aoti_compile_and_package(
torch.export.export(Model(), example_inputs, strict=True)
)
)
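        # A 10x10 fp32 output is 400 bytes; repeated calls should keep reusing that
        # allocation rather than growing the CUDA caching allocator.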
try:
torch.cuda.memory.empty_cache()
torch.cuda.memory._record_memory_history(context=None)
for _ in range(10):
optimized(*example_inputs)
finally:
torch.cuda.memory._record_memory_history(False)
segments = torch.cuda.memory._snapshot()["segments"]
self.assertEqual(segments[0]["requested_size"], 400)
def test_view_outputs(self):
class Model(torch.nn.Module):
def forward(self, x):
y = torch.sin(x)
y_same_size = y.view(*y.shape)
y_diff_size = y.view(1, *y.shape)
return y, y_same_size, y_diff_size
example_inputs = (torch.randn(3, 10, device=self.device),)
self.check_model(Model(), example_inputs)
@skip_if_no_torchvision
def test_missing_cubin(self):
from torchvision.models.resnet import Bottleneck, ResNet
class Model(ResNet):
def __init__(self) -> None:
super().__init__(
block=Bottleneck,
layers=[3, 4, 6, 3],
replace_stride_with_dilation=[False, False, True],
norm_layer=None,
)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
f1 = x
x = self.maxpool(x)
x = self.layer1(x)
f2 = x
x = self.layer2(x)
f3 = x
x = self.layer3(x)
x = self.layer4(x)
f4 = x
return [f1, f2, f3, f4]
# Call eval() here so that batch_norm won't update the running stats
        # Use float64 (float32 on MPS, which lacks float64) to avoid numeric difference failures
dtype = torch.float32 if self.device == "mps" else torch.float64
model = Model().to(device=self.device, dtype=dtype).eval()
example_inputs = (torch.randn(4, 3, 64, 64, device=self.device, dtype=dtype),)
self.check_model(model, example_inputs)
def test_triton_next_power_of_2(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
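        # max_len is a data-dependent value, so triton.next_power_of_2(max_len) ends
        # up as a host-side expression that the generated launcher must evaluate.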
class Model(torch.nn.Module):
def forward(self, a, b, lengths):
n_elements = a.numel()
out = torch.empty_like(a)
max_len = int(lengths.max())
scaling_factor = triton.next_power_of_2(max_len)
add_kernel_with_scaling[(n_elements,)](
a,
b,
out,
n_elements,
scaling_factor,
BLOCK_SIZE=16,
)
return out
example_inputs = (
torch.randn(2, device=self.device),
torch.randn(2, device=self.device),
torch.arange(end=4, device=self.device),
)
self.check_model(Model(), example_inputs)
@common_utils.parametrize("minmax", [min, max])
@skipIfWindowsXPU(msg="crash on Windows XPU.")
def test_sympy_cpp_printer_min_max(self, minmax):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
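        # scaling_factor is a min/max over a backed size, an unbacked value, and a
        # constant; the test checks that the C++ sympy printer emits valid min/max
        # expressions for this kernel argument.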
class Model(torch.nn.Module):
def forward(self, a, b, ranks):
n_elements = a.numel()
out = torch.empty_like(a)
backed = a.size(0)
unbacked = int(ranks.max())
scaling_factor = minmax(backed, unbacked, 100)
add_kernel_with_scaling[(n_elements,)](
a,
b,
out,
n_elements,
scaling_factor,
BLOCK_SIZE=16,
)
return out
example_inputs = (
torch.randn(16, device=self.device),
torch.randn(16, device=self.device),
torch.arange(end=4, device=self.device, dtype=torch.int16),
)
torch._dynamo.mark_dynamic(example_inputs[0], 0)
torch._dynamo.mark_dynamic(example_inputs[1], 0)
self.check_model(Model(), example_inputs)
@skipIfMPS
@common_utils.parametrize("grid_type", [1, 2, 3])
@common_utils.parametrize("num_dims", [1, 2])
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("autotune", [False, True])
def test_triton_kernel(self, grid_type, num_dims, dynamic, autotune):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
output = torch.zeros_like(x)
if autotune and num_dims == 2:
x_elements = output.size()[0]
y_elements = output.size()[1]
else:
n_elements = output.numel()
# Select grid
if autotune and num_dims == 2:
if grid_type == 1:
grid = (x_elements, y_elements)
elif grid_type == 2:
grid = lambda meta: ( # noqa: E731
triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
)
else:
def grid_fn(meta):
return (
triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
)
grid = grid_fn
else:
if grid_type == 1:
grid = (n_elements,)
elif grid_type == 2:
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
else:
def grid_fn(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
grid = grid_fn
# Select kernel
if autotune:
if num_dims == 1:
add_kernel_autotuned[grid](x, y, output, n_elements)
else:
add_kernel_2d_autotuned[grid](
x, y, output, x_elements, y_elements
)
else:
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
return output
dims = [10] * num_dims
x = torch.randn(*dims, device=self.device)
y = torch.randn(*dims, device=self.device)
dynamic_shapes = []
if dynamic:
dim0_x = Dim("dim0_x", min=1, max=10)
dim0_y = Dim("dim0_y", min=1, max=10)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
self.check_model(Model(), (x, y), dynamic_shapes=dynamic_shapes)
def test_triton_kernel_dynamic_shape_with_div(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.jit
def pass_kernel(x, num):
pass
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
num = x.numel() // 4
grid = lambda meta: (triton.cdiv(num, 16),) # noqa: E731
pass_kernel[grid](x, num)
return x
x = torch.randn(10, device=self.device)
dim0_x = Dim("dim0_x", min=1, max=10)
dynamic_shapes = {"x": {0: dim0_x}}
self.check_model(Model(), (x,), dynamic_shapes=dynamic_shapes)
def test_triton_kernel_reinterpret_view(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.jit
def pass_kernel(x, y):
pass
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
out = torch.zeros_like(x[:, 4:])
                # the slicing below creates two ReinterpretView instances,
                # with offsets 3 and 4
add_kernel[(10,)](
in_ptr0=x[:, 3:-1],
in_ptr1=x[:, 4:],
out_ptr=out,
n_elements=160,
BLOCK_SIZE=16,
)
return out
example_inputs = (torch.randn(10, 20, device=self.device),)
self.check_model(Model(), example_inputs)
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("tma_version", ["new", "old"])
def test_triton_kernel_tma_descriptor_1d(self, dynamic, tma_version):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
if tma_version == "new" and not has_triton_tensor_descriptor_host_tma():
self.skipTest("requires triton.tools.tensor_descriptor TMA support")
if tma_version == "old" and not has_triton_experimental_host_tma():
self.skipTest("requires triton.tools.experimental_descriptor TMA support")
kernel = (
add_kernel_with_tma_1d_new_api
if tma_version == "new"
else add_kernel_with_tma_1d_old_api
)
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, a, b):
BLOCK_SIZE = 256
out = torch.zeros_like(a)
n_elements = out.numel()
desc_a, desc_b, desc_out = (
create_tensor_descriptor_shim(
t, [BLOCK_SIZE], new_api=(tma_version == "new")
)
for t in (a, b, out)
)
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
kernel[grid](
desc_a,
desc_b,
desc_out,
BLOCK_SIZE=BLOCK_SIZE,
)
return out
a = torch.randn(301, device=self.device)
b = torch.randn(301, device=self.device)
example_inputs = (a, b)
dynamic_shapes = None
if dynamic:
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model(
Model(),
example_inputs=example_inputs,
dynamic_shapes=dynamic_shapes,
)
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("tma_version", ["new", "old"])
def test_triton_kernel_tma_descriptor_2d(self, dynamic, tma_version):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
if tma_version == "new" and not has_triton_tensor_descriptor_host_tma():
self.skipTest("requires triton.tools.tensor_descriptor TMA support")
if tma_version == "old" and not has_triton_experimental_host_tma():
self.skipTest("requires triton.tools.experimental_descriptor TMA support")
kernel = (
add_kernel_with_tma_2d_new_api
if tma_version == "new"
else add_kernel_with_tma_2d_old_api
)
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, a, b):
BLOCK_SIZE_X = 16
BLOCK_SIZE_Y = 32
out = torch.zeros_like(a)
x_size, y_size = out.size()
desc_a, desc_b, desc_out = (
create_tensor_descriptor_shim(
t,
[BLOCK_SIZE_X, BLOCK_SIZE_Y],
new_api=(tma_version == "new"),
)
for t in (a, b, out)
)
grid = lambda meta: ( # noqa: E731
triton.cdiv(x_size, meta["BLOCK_SIZE_X"]),
triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]),
)
kernel[grid](
desc_a,
desc_b,
desc_out,
BLOCK_SIZE_X=BLOCK_SIZE_X,
BLOCK_SIZE_Y=BLOCK_SIZE_Y,
)
return out
a = torch.randn((25, 16), device=self.device)
b = torch.randn((25, 16), device=self.device)
example_inputs = (a, b)
dynamic_shapes = None
if dynamic:
dim0_ab = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"a": {0: dim0_ab, 1: None},
"b": {0: dim0_ab, 1: None},
}
self.check_model(
Model(),
example_inputs=example_inputs,
dynamic_shapes=dynamic_shapes,
)
def test_triton_kernel_sympy_expr_arg(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, e):
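                # n_elements is max(1, e.item()), i.e. a sympy expression over an
                # unbacked value, passed directly as a Triton kernel argument.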
sympy_expr = max(1, e.item())
out = torch.zeros_like(x)
add_kernel[(1,)](
in_ptr0=x,
in_ptr1=x,
out_ptr=out,
n_elements=sympy_expr,
BLOCK_SIZE=1,
)
return out
NUMEL = 64
inputs = (
torch.randn(NUMEL, device=self.device),
torch.tensor(NUMEL, device=self.device),
)
self.check_model(Model(), inputs)
def test_triton_kernel_sympy_fn_like_arg(self):
# This test should hit sympy.expand("sqrt") which crashes with
# AttributeError: 'function' object has no attribute 'expand'.
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x):
out = torch.zeros_like(x)
add_kernel_with_optional_param[1,](
in_ptr0=x,
in_ptr1=x,
out_ptr=out,
n_elements=x.numel(),
BLOCK_SIZE=1,
ARGS_PASSED="sqrt", # sqrt is a valid sympy fn
)
return out
inputs = (torch.randn(4, device=self.device),)
self.check_model(Model(), inputs)
def test_triton_kernel_with_none_input(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
n_elements = x.size()[0]
BLOCK_SIZE = 1024
output_wo_y = torch.empty_like(x)
output_with_y = torch.empty_like(x)
add_kernel_with_optional_param[(1,)](
x,
None,
output_wo_y,
n_elements,
ARGS_PASSED="one",
BLOCK_SIZE=BLOCK_SIZE,
)
add_kernel_with_optional_param[(1,)](
x,
y,
output_with_y,
n_elements,
ARGS_PASSED="two",
BLOCK_SIZE=BLOCK_SIZE,
)
return 2.71 * output_wo_y + 3.14 * output_with_y
example_inputs = (
torch.randn(1023, device=self.device),
torch.randn(1023, device=self.device),
)
self.check_model(Model(), example_inputs)
def test_triton_kernel_equal_to_1_arg(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
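        # With size-1 inputs, n_elements == 1, exercising Triton's equal_to_1
        # argument specialization in the AOTI launcher.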
class Model(torch.nn.Module):
def forward(self, x, y):
out = torch.empty_like(x)
n_elements = x.numel()
add_kernel[(n_elements,)](x, y, out, n_elements, BLOCK_SIZE=16)
return out
example_inputs = (
torch.randn(1, device=self.device),
torch.randn(1, device=self.device),
)
self.check_model(Model(), example_inputs)
def test_triton_kernel_with_none_inputs_and_equal_to_1_arg(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
n_elements = x.size()[0]
BLOCK_SIZE = 1024
out1 = torch.empty_like(x)
out2 = torch.empty_like(x)
                # Run the same kernel multiple times to test the optimization of
                # removing None arguments and then updating the indices of the
                # equal_to_1 arguments. The None arguments need to come before
                # the equal_to_1 arguments.
add_kernel_with_none_param_and_equal_to_1_arg[(1,)](
x,
None,
out1,
n_elements,
x.stride(0), # equal to 1
ARGS_PASSED="one",
BLOCK_SIZE=BLOCK_SIZE,
)
add_kernel_with_none_param_and_equal_to_1_arg[(1,)](
2.71 * out1,
None,
out2,
n_elements,
x.stride(0), # equal to 1
ARGS_PASSED="one",
BLOCK_SIZE=BLOCK_SIZE,
)
return out2
example_inputs = (torch.randn(1023, device=self.device),)
self.check_model(Model(), example_inputs)
@common_utils.parametrize("dynamic", [False, True])
def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y):
out = torch.empty_like(x)
n_elements = x.numel()
scaling_factor = (n_elements**0) / 1.0
add_kernel_with_scaling[(n_elements,)](
x,
y,
out,
n_elements,
scaling_factor,
BLOCK_SIZE=16,
)
return out
dynamic_shapes = None
if dynamic:
dim0_xy = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"x": {0: dim0_xy},
"y": {0: dim0_xy},
}
example_inputs = (
torch.randn(2, device=self.device),
torch.randn(2, device=self.device),
)
self.check_model(
Model(),
example_inputs,
dynamic_shapes=dynamic_shapes,
)
def test_triton_kernel_weird_param_order(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
out = torch.empty_like(x)
add_kernel_autotuned_weird_param_order[16,](
in_ptr0=x,
in_ptr1=x,
n_elements=x.numel(),
out_ptr=out,
)
return out
x = torch.randn(16, 16, device=self.device)
self.check_model(Model(), (x,))
@skipIfWindowsXPU(msg="crash on Windows XPU.")
def test_triton_kernel_dynamic_grid(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
import math
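        # The grid callable below mixes an unbacked SymInt (from .item()) with float
        # division and math.trunc, so the launch grid must be computed at run time
        # by the generated code.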
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y, n_elements_tensor):
output = torch.zeros_like(x)
n_elements_symint = n_elements_tensor.item()
n_elements = x.numel()
def grid(meta):
n_elements_complicated = n_elements_symint // 1.0
return (math.trunc(n_elements_complicated / meta["BLOCK_SIZE"]),)
add_kernel_autotuned[grid](
x,
y,
output,
n_elements,
)
return output
x = torch.randn(128, device=self.device)
y = torch.randn(128, device=self.device)
n_elem = torch.tensor(128)
dim0_x = Dim("dim0_x", min=8, max=256)
dim0_y = Dim("dim0_y", min=8, max=256)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}, "n_elements_tensor": {}}
self.check_model(Model(), (x, y, n_elem), dynamic_shapes=dynamic_shapes)
def test_shifted_constraint_ranges(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
x: torch.Tensor,
y: torch.Tensor,
):
torch._check(y.size(0) == x.size(0) + 1)
return x.sum(0) + y.sum(0)
a = torch.randn((4, 5), device=self.device)
b = torch.randn((5, 5), device=self.device)
dim0_x = Dim("dim0_x", min=2, max=1024)
dim0_y = dim0_x + 1
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
self.check_model(
Model(),
(a, b),
dynamic_shapes=dynamic_shapes,
)
def test_scatter_fallback(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
inp: torch.Tensor,
index: torch.Tensor,
src: torch.Tensor,
):
return torch.scatter(inp, 1, index, src)
inputs = (
torch.ones((3, 5), device=self.device, dtype=torch.int64),
torch.tensor([[0, 1, 2, 0]], device=self.device, dtype=torch.int64),
torch.zeros((2, 5), device=self.device, dtype=torch.int64),
)
self.check_model(Model(), inputs)
def test_scatter_reduce_fallback(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
inp: torch.Tensor,
index: torch.Tensor,
src: torch.Tensor,
):
return torch.scatter_reduce(inp, 0, index, src, reduce="sum")
inputs = (
torch.tensor([1, 10, 100, 1000], device=self.device, dtype=torch.int64),
torch.tensor([0, 1, 0, 1, 2, 1], device=self.device, dtype=torch.int64),
torch.tensor([1, 2, 3, 4, 5, 6], device=self.device, dtype=torch.int64),
)
self.check_model(Model(), inputs)
def test_index_put_fallback(self):
# index_put falls back in the deterministic mode
with DeterministicGuard(True):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
self_tensor: torch.Tensor,
indices: tuple[torch.Tensor],
values: torch.Tensor,
):
return torch.index_put(
self_tensor, indices, values, accumulate=True
)
inputs = (
torch.ones(4, device=self.device, dtype=torch.int64),
(torch.tensor([1, 1, 2, 2], device=self.device, dtype=torch.bool),),
torch.ones(4, device=self.device, dtype=torch.int64),
)
self.check_model(Model(), inputs)
def test_narrow_fallback(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, inp: torch.Tensor, dim: int, start: int, length: int):
return torch.ops.aten.narrow(inp, dim, start, length)
inputs = (torch.rand((3, 4), device=self.device), 0, 0, 2)
self.check_model(Model(), inputs)
def test_pad_fallback(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
inp: torch.Tensor,
pad: tuple[int, ...],
):
return torch.ops.aten.pad(inp, pad)
inputs = (torch.rand((3, 3, 4, 2), device=self.device), (0, 1, 2, 1, 3, 3))
self.check_model(Model(), inputs)
def test_fill__fallback(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, inp: torch.Tensor, scalar: float):
torch.ops.aten.fill_(inp, scalar)
return inp
inputs = (torch.rand((3, 3, 4, 2), device=self.device), 0.5)
self.check_model(Model(), inputs)
@common_utils.parametrize("embed_kernel_binary", [False, True])
def test_repeated_user_defined_triton_kernel(self, embed_kernel_binary):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
for _ in range(3):
mul2_inplace_kernel[4,](x, n_elements=4, BLOCK_SIZE=16)
return x
inputs = (torch.randn(4, 4, device=self.device),)
with config.patch({"aot_inductor.embed_kernel_binary": embed_kernel_binary}):
model = Model()
self.check_model(model, inputs)
_, code = run_and_get_cpp_code(AOTIRunnerUtil.compile, model, inputs)
FileCheck().check("launchKernel(").run(code)
if config.aot_inductor.embed_kernel_binary:
                # Do not expect to see launchKernel("CUBIN_FILE_NAME", ...) when the
                # kernel binary is embedded.
FileCheck().check_not('launchKernel("').run(code)
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
def test_convolution(self):
if self.device == "cpu":
raise unittest.SkipTest("using triton backend only is not supported on CPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, w, b):
return torch.ops.aten.convolution(x, w, b, [4], [0], [1], True, [0], 1)
example_inputs = (
torch.randn([2, 32, 90], device=self.device),
torch.randn([32, 16, 8], device=self.device),
torch.randn([16], device=self.device),
)
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "Triton",
}
):
self.check_model(Model(), example_inputs)
def test_zero_size_weight(self):
class Model(torch.nn.Module):
def __init__(self, channel, r=8):
super().__init__()
self.pool = torch.nn.AdaptiveAvgPool2d(1)
self.net = torch.nn.Sequential(
torch.nn.Linear(channel, channel // r, bias=False),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(channel // r, channel, bias=False),
torch.nn.Sigmoid(),
)
def forward(self, inp):
b, c, _, _ = inp.shape
x = self.pool(inp).view(b, c)
x = self.net(x).view(b, c, 1, 1)
x = inp * x
return x
inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
self.check_model(Model(4), inputs)
def test_zero_size_buffer(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.foo = torch.nn.Buffer(torch.zeros((0, 0), device=device))
def forward(self, x):
return x + 1, self.foo
example_inputs = (torch.rand(4, 4, device=self.device),)
self.check_model(Model(self.device), example_inputs)
def test_no_args(self):
class Model(torch.nn.Module):
def __init__(self, m, n):
super().__init__()
self.weight = torch.nn.Parameter(
torch.randn(m, n),
)
self.alpha = torch.nn.Parameter(torch.randn(m, n))
def forward(self):
return self.weight * self.alpha
self.check_model(Model(6, 4), ())
def test_dynamic_scalar(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.criterion_ce = torch.nn.CrossEntropyLoss(reduction="none")
def forward(self, inputs, targets, split_index=None):
statistics = {}
total_loss = self.criterion_ce(inputs, targets).sum()
statistics["dl"] = total_loss.item()
return total_loss, statistics
inputs = (
torch.rand(4, 4, 4, 4, device=self.device),
torch.rand(4, 4, 4, 4, device=self.device),
)
self.check_model(Model(), inputs)
def test_symint_item(self):
class Model(torch.nn.Module):
def forward(self, tensor):
return tensor.item()
inputs = (torch.tensor([1], dtype=torch.int, device=self.device),)
self.check_model(Model(), inputs)
def test_symbool_item(self):
class Model(torch.nn.Module):
def forward(self, tensor):
return tensor.item()
inputs = (torch.tensor([0], dtype=torch.bool, device=self.device),)
self.check_model(Model(), inputs)
def test_symfloat_item(self):
class Model(torch.nn.Module):
def forward(self, tensor):
return tensor.item()
inputs = (torch.tensor([3.14], dtype=torch.float, device=self.device),)
self.check_model(Model(), inputs)
def test_constant_original_fqn_and_dtype(self):
class FooBarModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4)))
self.test_buf = torch.nn.Buffer(torch.randn(3, 4))
self.register_parameter(
"test_param", torch.nn.Parameter(torch.randn(3, 4))
)
def forward(self, x):
return ((x + self.test_buf) * getattr(self, "0")) / self.test_param
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo_bar = FooBarModule()
self.register_parameter(
"test_param", torch.nn.Parameter(torch.randn(3, 4))
)
self.test_buf = torch.nn.Buffer(torch.randn(3, 4))
def forward(self, x):
return (self.foo_bar(x) + self.test_param) * self.test_buf
with torch.no_grad():
so_path = AOTIRunnerUtil.legacy_compile(
model=TestModule().to(device=self.device),
example_inputs=(torch.rand(3, 4, device=self.device),),
)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
expected_original_fqns = {
"L__self___test_param": "test_param",
"L__self___test_buf": "test_buf",
"L__self___foo_bar_0": "foo_bar.0",
"L__self___foo_bar_test_param": "foo_bar.test_param",
"L__self___foo_bar_test_buf": "foo_bar.test_buf",
}
self.assertEqual(
expected_original_fqns, runner.get_constant_names_to_original_fqns()
)
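# 6 is the ScalarType enum value for torch.float32.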
expected_dtypes = {
"L__self___test_param": 6,
"L__self___test_buf": 6,
"L__self___foo_bar_0": 6,
"L__self___foo_bar_test_param": 6,
"L__self___foo_bar_test_buf": 6,
}
self.assertEqual(expected_dtypes, runner.get_constant_names_to_dtypes())
def test_masked_select_dynamic(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
mask = x.ge(0.5)
return torch.masked_select(x, mask)
example_args = (torch.randn(3, 4, 5, device=self.device),)
dim0_x_max, dim1_x_max = 100, 7
dynamic_shapes = {
"x": {
0: Dim("dim0_x", max=dim0_x_max),
1: Dim("dim1_x_max", max=dim1_x_max),
}
}
m = M()
self.check_model(m, example_args, dynamic_shapes=dynamic_shapes)
def test_proxy_executor_permute(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.ops.aten.permute.default(x, [0, 2, 1])
example_args = (torch.randn((1, 3001, 201), dtype=torch.complex64),)
m = M()
self.check_model(m, example_args)
def test_proxy_executor_abs(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.ops.aten.abs.default(x)
example_args = (torch.randn((1, 3001, 201), dtype=torch.complex64),)
m = M()
self.check_model(m, example_args)
def test_proxy_executor_squeeze(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.ops.aten.squeeze.dim(x, 0)
example_args = (torch.randn((1, 300, 201), dtype=torch.complex64),)
m = M()
self.check_model(m, example_args)
def test_proxy_executor_hann(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self):
return torch.ops.aten.hann_window.default(400)
example_args = ()
m = M()
self.check_model(m, example_args)
def test_fqn(self):
class NestedChild(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nestedchild3buffer = torch.nn.Buffer(torch.ones(2, 3) * 3)
def forward(self, x):
return x / self.nestedchild3buffer
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3) * 2)
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3) * 4)
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
self.check_model(MyModule(), (torch.randn(2, 3, device=self.device),))
def test_model_modified_weights(self):
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M = 16
N = 10
K = 128
example_inputs = (torch.randn(2, M, K, device=self.device),)
model = Model(N, K, self.device)
self.check_model(model, example_inputs)
# Update the model weights; after this, AOTInductor should regenerate model.so
# if the weights are stored in the model.so.
model.weight += 1
self.check_model(model, example_inputs)
@skipIfWindowsXPU(msg="crash on Windows XPU.")
def test_triton_kernel_extern_kernel_arg(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y):
out = torch.zeros_like(x)
# torch.mm is ExternKernelOut
add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
return out
example_inputs = (
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(4, 4, device=GPU_TYPE),
)
self.check_model(Model(), example_inputs)
def test_triton_kernel_multi_output_arg(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y):
out = torch.zeros_like(x)
# torch.sort creates fallback kernel and hence MultiOutput
add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
return out
example_inputs = (
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(4, 4, device=GPU_TYPE),
)
self.check_model(Model(), example_inputs)
def test_triton_kernel_reinterpret_view_mem_leak(self):
# Check for memory leak when using user-defined Triton Kernel + AOTI.
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
out = torch.zeros_like(x)
yy = y * y
# reshape creates a ReinterpretView
add_kernel[(4,)](x, yy.reshape_as(x), out, 4, 16)
return out
example_inputs = (
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(1, 16, device=GPU_TYPE),
)
package_path: str = AOTIRunnerUtil.compile(
Model(),
example_inputs,
)
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
# Don't assign outputs to a variable b/c it will allocate GPU memory.
device_interface = get_interface_for_device(GPU_TYPE)
device: int = device_interface.current_device()
mem_before = device_interface.memory_allocated(device)
aot_inductor_module(*example_inputs)
aot_inductor_module(*example_inputs)
mem_after = device_interface.memory_allocated(device)
self.assertEqual(mem_before, mem_after)
actual = aot_inductor_module(*example_inputs)
expected = Model()(*example_inputs)
torch.testing.assert_close(actual, expected)
@skipIfMPS
@torch._dynamo.config.patch(capture_scalar_outputs=True)
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("autotuning", [False, True])
def test_triton_kernel_unbacked_symint_in_grid(self, dynamic, autotuning):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y, n_elements_tensor):
output = torch.zeros_like(x)
n_elements_symint = n_elements_tensor.item()
n_elements = x.numel()
def grid(meta):
return (triton.cdiv(n_elements_symint, meta["BLOCK_SIZE"]),)
if autotuning:
add_kernel_autotuned[grid](
x,
y,
output,
n_elements,
)
else:
add_kernel[grid](
x,
y,
output,
n_elements,
BLOCK_SIZE=16,
)
return output
example_inputs = (
torch.randn(123, device=GPU_TYPE),
torch.randn(123, device=GPU_TYPE),
torch.tensor(123),
)
dynamic_shapes = None
if dynamic:
dim0 = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"x": {0: dim0},
"y": {0: dim0},
"n_elements_tensor": {},
}
self.check_model(
Model(),
example_inputs,
dynamic_shapes=dynamic_shapes,
)
@unittest.skipIf(
not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
)
def test_scaled_dot_product_efficient_attention(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, q, k, v, attn_bias):
return torch.ops.aten._scaled_dot_product_efficient_attention(
q, k, v, attn_bias, False
)[0]
example_inputs = (
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
)
self.check_model(Model(), example_inputs)
def test_aoti_runtime_asserts(self):
from torch.export._draft_export import draft_export, FailureType
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define(
"mylib::foo",
"(Tensor a, Tensor b) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl("mylib::foo", "cpu", lib=lib)
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a[: b.item()]
@torch.library.register_fake("mylib::foo", lib=lib)
def foo_fake_impl(a, b):
ctx = torch.library.get_ctx()
u = ctx.new_dynamic_size()
return torch.empty(u)
class M(torch.nn.Module):
def forward(self, a, b):
res = torch.ops.mylib.foo(a, b)
s = res.shape[0]
torch._check(s > 3)
torch._check(s < a.shape[0])
return a[s - 3]
example_inputs = (torch.randn(100), torch.tensor(10))
ep = draft_export(M(), example_inputs)
report = ep._report
need_config_patch = any(
not f.xfail and f.failure_type == FailureType.MISMATCHED_FAKE_KERNEL
for f in report.failures
)
m = ep.module()
# This should no longer be needed after #150093
from torch._functorch import config as functorch_config
with functorch_config.patch(
{"generate_fake_kernels_from_real_mismatches": need_config_patch}
):
pt2_file = torch._inductor.aoti_compile_and_package(ep)
optimized = torch._inductor.aoti_load_package(pt2_file)
self.assertTrue(same(optimized(*example_inputs), m(*example_inputs)))
with self.assertRaisesRegex(Exception, "run_func_(.*) API call failed "):
optimized(torch.randn(100), torch.tensor(2))
@patch.dict(os.environ, {"TORCHINDUCTOR_SCALAR_ASSERTS_FULL": "1"})
def test_aoti_runtime_asserts_backed_symint(self):
if not full_aoti_runtime_assert():
raise unittest.SkipTest("full runtime assert not turned on")
class Model(torch.nn.Module):
def forward(self, x):
y = x.reshape(100, -1).clone()
y = y + 1
return y
model = Model().to(self.device)
input1 = (torch.rand(100, device=self.device),)
input2 = (torch.rand(2099, device=self.device),)
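# 2099 is not divisible by 100, so the reshape(100, -1) guard should fail at runtime for input2.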
dynamic_shapes = {
"x": {0: torch.export.Dim.DYNAMIC},
}
package_path = AOTIRunnerUtil.compile(
model,
input1,
dynamic_shapes=dynamic_shapes,
)
optimized = torch._inductor.aoti_load_package(package_path)
self.assertEqual(model(*input1), optimized(*input1))
with self.assertRaisesRegex(Exception, "run_func_(.*) API call failed "):
optimized(*input2)
@skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation")
def test_index_put_with_none_index(self):
# index_put falls back in the deterministic mode
with DeterministicGuard(True):
class Model(torch.nn.Module):
def forward(self, x, i1, i2, y):
return torch.ops.aten.index_put(
x,
(None, None, i1, i2.transpose(0, 1)),
y,
accumulate=True,
)
example_inputs = (
torch.rand(8, 192, 30, 30, device=self.device),
torch.zeros(3, 14, 1, 1, dtype=torch.int64, device=self.device),
torch.ones(14, 3, dtype=torch.int64, device=self.device),
torch.randn(8, 192, 3, 14, 3, 14, device=self.device),
)
self.check_model(Model(), example_inputs)
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
def test_runtime_checks(self):
class Model(torch.nn.Module):
def forward(self, inputs):
return list(inputs.values())
inputs = {}
dtypes = [
torch.float16,
torch.float32,
torch.bool,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
]
if not TEST_MPS:
dtypes.append(torch.float64)
if SM80OrLater:
dtypes.append(torch.bfloat16)
for dtype in dtypes:
inputs[f"x_{str(dtype)}"] = torch.ones(
4, 8, 10, dtype=dtype, device=self.device
)
dim0 = Dim("s0", min=2, max=1024)
dim1 = Dim("s1", min=2, max=512)
dim2 = Dim("s2", min=2, max=128)
dynamic_shapes = {
"x_torch.float16": {0: dim0},
"x_torch.float32": {0: dim0},
"x_torch.bool": {1: dim1},
"x_torch.int8": {1: dim1},
"x_torch.int16": {},
"x_torch.int32": {2: dim2},
"x_torch.int64": {2: dim2},
"x_torch.uint8": {2: dim2},
}
if not TEST_MPS:
dynamic_shapes["x_torch.float64"] = {0: dim0}
if SM80OrLater:
dynamic_shapes["x_torch.bfloat16"] = {1: dim1}
m = Model()
inputs = (inputs,)
dynamic_shapes = (dynamic_shapes,)
with torch.no_grad():
so_path = AOTIRunnerUtil.legacy_compile(
m, inputs, dynamic_shapes=dynamic_shapes
)
# Expected results for the following checks:
# ("unmatched dtype", "unmatched dim value at", "dim value is too", "unmatched stride value at")
if SM80OrLater:
# 10 dynamic dims
expected_results = (10, 21, 18, 21)
elif TEST_MPS:
# 8 dynamic dims
expected_results = (8, 17, 14, 16)
else:
# 9 dynamic dims
expected_results = (9, 19, 16, 19)
with open(os.path.splitext(so_path)[0] + ".cpp") as cpp:
src_code = cpp.read()
FileCheck().check_count(
"unmatched dtype",
expected_results[0],
exactly=True,
).run(src_code)
FileCheck().check_count(
"unmatched dim value at",
expected_results[1],
exactly=True,
).run(src_code)
FileCheck().check_count(
"dim value is too",
expected_results[2],
exactly=True,
).run(src_code)
FileCheck().check_count(
"unmatched stride value at",
expected_results[3],
exactly=True,
).run(src_code)
self.check_model(m, inputs)
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
def test_runtime_checks_fp8(self):
# cuda only
if self.device != "cuda":
return
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x0, x1):
t = x0.to(torch.float) + x1.to(torch.float)
return t
inputs = []
for dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
# FP8 fnuz dtypes are for AMD
# see https://github.com/pytorch/pytorch/issues/126734
# torch.float8_e4m3fnuz,
# torch.float8_e5m2fnuz,
):
inputs.append(torch.ones(8, 8, 8, dtype=dtype, device=self.device))
dim0 = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"x0": {0: dim0},
"x1": {0: dim0},
}
with torch.no_grad():
self.check_model(
Model(),
tuple(inputs),
dynamic_shapes=dynamic_shapes,
)
@skipIfXpu(msg="Total size of kernel arguments exceeds driver limit on XPU")
def test_runtime_checks_large(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, *inputs):
result = inputs[0]
for i in range(1, len(inputs)):
result = result + inputs[i]
return result
inputs = []
for i in range(1000):
inputs.append(torch.ones(8, 8, 8, dtype=torch.float16, device=self.device))
inputs = tuple(inputs)
model = Model()
with torch.no_grad():
AOTIRunnerUtil.compile(
model,
inputs,
)
def test_runtime_checks_complex(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x0, x1, x2):
return (x0, x1, x2)
inputs = []
x0 = torch.tensor([1, -1], dtype=torch.complex32, device=self.device)
x1 = torch.tensor(
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1],
dtype=torch.complex64,
device=self.device,
)
x2 = torch.tensor(128, dtype=torch.complex128, device=self.device)
inputs.append(x0)
inputs.append(x1)
inputs.append(x2)
dim0 = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"x0": {0: dim0},
"x1": {},
"x2": {},
}
with torch.no_grad():
self.check_model(
Model(),
tuple(inputs),
dynamic_shapes=dynamic_shapes,
)
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
def test_runtime_checks_dtype_failed(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
y = x.type(torch.float)
return y
x = torch.randn(1, 4, dtype=torch.float16, device=self.device)
model = Model()
with torch.no_grad():
package_path: str = AOTIRunnerUtil.compile(
model,
(x,),
)
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
x_casted = x.float()
with self.assertRaisesRegex(Exception, ""):
aot_inductor_module(x_casted)
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
def test_runtime_checks_device_type_failed(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x + 1
x = torch.randn(1, 4, dtype=torch.float16, device="cpu")
model = Model()
with torch.no_grad():
package_path: str = AOTIRunnerUtil.compile(
model,
(x,),
)
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
aot_inductor_module(x)
x_casted = x.to(GPU_TYPE)
with self.assertRaisesRegex(Exception, ""):
aot_inductor_module(x_casted)
def test_non_contiguous_output_alias(self):
# Test return x, x.contiguous() where x is non-contiguous.
class Model(torch.nn.Module):
def forward(self, x):
squared = x * x
transposed = squared.t() # non-contiguous
contig = transposed.contiguous()
return transposed, contig
x = torch.randn(3, 4, dtype=torch.float16, device=self.device)
model = Model()
with torch.no_grad():
result = AOTIRunnerUtil.run(
model,
(x,),
)
actual = model(x)
self.assertTrue(same(result, actual))
# contiguous() should create a new tensor
self.assertTrue(result[0].data_ptr() != result[1].data_ptr())
def test_multiple_output_alias(self):
# Test when multiple outputs alias the same tensor
class Model(torch.nn.Module):
def forward(self, x):
squared = x * x
contig = squared.contiguous() # alias
reshaped = squared.reshape(squared.shape) # alias
cubed = squared * x
return squared, contig, reshaped, cubed
x = torch.randn(3, 4, dtype=torch.float32, device=self.device)
model = Model()
with torch.no_grad():
result = AOTIRunnerUtil.run(
model,
(x,),
)
actual = model(x)
self.assertTrue(same(result, actual))
# squared, contig and reshaped alias the same tensor.
self.assertTrue(result[0].data_ptr() == result[1].data_ptr())
self.assertTrue(result[0].data_ptr() == result[2].data_ptr())
# cubed shouldn't be an alias.
self.assertTrue(result[0].data_ptr() != result[3].data_ptr())
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
def test_runtime_checks_shape_failed(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x
x = torch.randn(4, 4, 4, dtype=torch.float16, device=self.device)
y0 = torch.randn(8, 4, 4, dtype=torch.float16, device=self.device)
y1 = torch.randn(4, 8, 4, dtype=torch.float16, device=self.device)
y2 = rand_strided(
(4, 4, 4), (16, 1, 4), dtype=torch.float16, device=self.device
)
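# y1 has a mismatched static dim size; y2 has mismatched strides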
# batch size is outside of the range
y3 = torch.randn(2048, 3, 4, dtype=torch.float16, device=self.device)
y4 = torch.randn(2048, 4, 4, dtype=torch.float16, device=self.device)
dim0 = Dim("s0", min=4, max=1024)
dynamic_shapes = {
"x": {0: dim0},
}
model = Model()
with torch.no_grad():
package_path: str = AOTIRunnerUtil.compile(
model, (x,), dynamic_shapes=dynamic_shapes
)
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
# dynamic dim works fine
_ = aot_inductor_module(y0)
with self.assertRaisesRegex(Exception, ""):
aot_inductor_module(y1)
with self.assertRaisesRegex(Exception, ""):
aot_inductor_module(y2)
with self.assertRaisesRegex(Exception, ""):
aot_inductor_module(y3)
with self.assertRaisesRegex(Exception, ""):
aot_inductor_module(y4)
def test_add_complex(self):
class Model(torch.nn.Module):
def forward(self, a, b):
return torch.add(a, b)
x = torch.tensor(
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device
)
y = torch.tensor(
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device
)
self.check_model(Model(), (x, y))
def test_embedding_bag(self):
class Model(torch.nn.Module):
def forward(self, w, i, o):
return torch.ops.aten._embedding_bag(w, i, o, False, 0, False, None)
example_inputs = (
torch.randn([10, 4], device=self.device),
torch.randint(10, [8], device=self.device),
torch.tensor([0, 2, 6], device=self.device),
)
self.check_model(Model(), example_inputs)
@unittest.skipIf(
TEST_MPS and MACOS_VERSION < 14.0,
"FFT operations are only supported on MacOS 14+",
)
def test_fft_c2c(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.fft.fftn(x), torch.fft.fftn(x).real
example_inputs = (torch.randn(16, 16, 16, device=self.device),)
self.check_model(Model(), example_inputs)
def test_bool_input(self):
# Specialize on whichever branch the example input for b selects
class Model(torch.nn.Module):
def forward(self, x, b):
if b:
return x * x
else:
return x + x
example_inputs = (torch.randn(3, 3, device=self.device), True)
self.check_model(Model(), example_inputs)
def test_int_list_input(self):
class Model(torch.nn.Module):
def forward(self, x, i):
return x * i[0] * i[1]
example_inputs = (torch.randn(3, 3, device=self.device), [3, 4])
self.check_model(Model(), example_inputs)
def test_nested_tensor_from_jagged(self):
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid()
)
def forward(self, values, offsets):
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
res = self.mlp(nt)
return res.values()
model = Model().to(device=self.device)
example_inputs_1 = (
torch.randn((15, 128), device=self.device),
torch.tensor([0, 3, 4, 10, 15], device=self.device),
)
# same "NT batch size", different actual amount of data
example_inputs_2 = (
torch.randn((31, 128), device=self.device),
torch.tensor([0, 1, 20, 25, 31], device=self.device),
)
# same actual amount of data, different "NT batch size"
example_inputs_3 = (
torch.randn((15, 128), device=self.device),
torch.tensor([0, 3, 10, 15], device=self.device),
)
# different "NT batch size"
example_inputs_4 = (
torch.randn((37, 128), device=self.device),
torch.tensor([0, 5, 16, 25, 29, 37], device=self.device),
)
dim0_values = Dim("dim0_values", min=1, max=128)
dim0_offsets = Dim("dim0_offsets", min=1, max=9)
dynamic_shapes = {"values": {0: dim0_values}, "offsets": {0: dim0_offsets}}
example_inputs_list = [
example_inputs_1,
example_inputs_2,
example_inputs_3,
example_inputs_4,
]
for example_input in example_inputs_list:
actual = AOTIRunnerUtil.legacy_run(
self.device,
model,
example_input,
dynamic_shapes=dynamic_shapes,
)
self.assertTrue(same(model(*example_input), actual))
# Temporarily skipping this test, as pytorch/cpuinfo is not able to retrieve the cache size for the
# AMD EPYC 9575F 64-Core Processor CPU in gfx942 VM runners
@common_utils.parametrize("max_autotune", [True, False])
@skipIfRocmArch(MI300_ARCH)
def test_misc_1(self, max_autotune):
if self.device == "cpu" and IS_MACOS and max_autotune:
raise unittest.SkipTest("max_autotune not supported on macos")
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid()
)
self.emb = nn.EmbeddingBag(num_embeddings=128, embedding_dim=32)
self.over_arch = nn.Sequential(
nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32), nn.Sigmoid()
)
def forward(self, x, y):
mlp_output = self.mlp(x)
emb_output = self.emb(y)
return self.over_arch(torch.concat([mlp_output, emb_output], dim=1))
example_inputs = (
torch.randn(16, 128, device=self.device),
torch.randint(0, 128, (16, 10), device=self.device),
)
self.check_model(
Model(), example_inputs, options=dict(max_autotune=max_autotune)
)
@skip_if_no_torchvision
def test_torchvision_transforms_functional_tensor_resize(self):
import torchvision
# https://fb.workplace.com/groups/1075192433118967/permalink/1501860707118802/
class A(torch.nn.Module):
def forward(self, image: torch.Tensor, target_size: torch.Tensor):
target_h, target_w = target_size.tolist()
torch._check(target_h > 0)
torch._check(target_w > 0)
torch._check(target_h <= 4000)
torch._check(target_w <= 4000)
return torchvision.transforms._functional_tensor.resize(
image,
size=[target_h, target_w],
interpolation="bilinear",
antialias=False,
)
model = A()
example_inputs = (
torch.ones([3, 800, 600], device=self.device),
torch.tensor([448, 336], device=self.device),
)
dynamic_shapes = {
"image": {
1: torch.export.Dim("height", min=1, max=4000),
2: torch.export.Dim("width", min=1, max=4000),
},
"target_size": None,
}
self.check_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
@unittest.skipIf(config.triton.native_matmul, "matmul is generated")
def test_aoti_debug_printer_codegen(self):
# basic addmm model to test codegen for aoti intermediate debug printer
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M = 8
N = 6
K = 16
model = Model(N, K, self.device)
batch = 2
a = torch.randn(batch, M, K, device=self.device)
example_inputs = (a,)
if self.device == "mps":
kernel_calls = [("aoti_torch_mps_addmm_out", 2)]
elif self.device == GPU_TYPE:
kernel_calls = [
("triton_poi_fused_0", 1),
(f"aoti_torch_{GPU_TYPE}_addmm_out", 2),
]
else:
kernel_calls = [("aoti_torch_cpu_addmm_out", 2)]
# test default debug printing all tensor values codegen
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
result, code = run_and_get_cpp_code(
AOTIRunnerUtil.legacy_compile, model, example_inputs
)
# check that the c shim print_tensor_handle call is triggered by the config and injected into the cpp output code as expected
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
# check that the codegen for debug printing around the actual kernel call is as expected
for kernel_call, count in kernel_calls:
FileCheck().check_count(
f"before_launch - {kernel_call}",
count,
).run(code)
FileCheck().check_count(
f"after_launch - {kernel_call}",
count,
).run(code)
# test printing selected kernel's tensor values codegen
filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out"
with config.patch(
{
"aot_inductor.debug_intermediate_value_printer": "2",
"aot_inductor.filtered_kernel_names": filtered_kernel_name,
}
):
result, code = run_and_get_cpp_code(
AOTIRunnerUtil.legacy_compile, model, example_inputs
)
filtered_kernel_calls = [
(filtered_kernel_name, 2),
]
for kernel_call, count in filtered_kernel_calls:
FileCheck().check_count(
f"before_launch - {kernel_call}",
count,
).run(code)
FileCheck().check_count(
f"after_launch - {kernel_call}",
count,
).run(code)
kernel_calls_not_to_print = [
kernel_call
for kernel_call in kernel_calls
if kernel_call[0] != filtered_kernel_name
]
for kernel_name, _ in kernel_calls_not_to_print:
FileCheck().check_not(f"before_launch - {kernel_name}").run(code)
FileCheck().check_not(f"after_launch - {kernel_name}").run(code)
@unittest.skipIf(
config.triton.native_matmul, "different kernel name when native matmul"
)
@common_utils.parametrize("enable_kernel_profile", (True, False))
def test_aoti_profiler(self, enable_kernel_profile):
# basic addmm model
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
if sys.platform not in ["linux", "win32"]:
raise unittest.SkipTest(
"enable_kernel_profile only supported on linux and win32"
)
M = 8
N = 6
K = 16
model = Model(N, K, self.device)
batch = 2
a = torch.randn(batch, M, K, device=self.device)
example_inputs = (a,)
kernel_calls = (
f"aoti_torch_{GPU_TYPE}_addmm_out"
if self.device == GPU_TYPE
else "aoti_torch_cpu_addmm_out"
)
with config.patch({"cpp.enable_kernel_profile": enable_kernel_profile}):
_, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
shim_fn_codes = f'RAIIAtenRecordFunctionHandle .*\\("{kernel_calls}"'
if enable_kernel_profile:
FileCheck().check_regex(shim_fn_codes).run(code)
else:
FileCheck().check_not("RAIIAtenRecordFunctionHandle").run(code)
self.check_model(Model(N, K, self.device), example_inputs)
def test_aoti_user_defined_triton_kernel_profiling(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
out = torch.zeros_like(x)
add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16)
return out
example_inputs = (
torch.randn(4, 4, device=self.device),
torch.randn(4, 4, device=self.device),
)
with (
config.patch({"cpp.enable_kernel_profile": True}),
torch.profiler.profile(
record_shapes=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
getattr(torch.profiler.ProfilerActivity, GPU_TYPE.upper()),
],
) as prof,
):
self.check_model(Model(), example_inputs)
with common_utils.TemporaryFileName(mode="w+") as fname:
prof.export_chrome_trace(fname)
with open(fname) as f:
import json
j = json.load(f)
op_events = [
e
for e in j["traceEvents"]
if e.get("name", "") == "kernels_.add_kernel_0"
]
self.assertEqual(len(op_events), 1)
self.assertEqual(
op_events[0]["args"].get("Input Args", ""),
["in_ptr0", "in_ptr1", "out_ptr", "n_elements"],
)
def test_aoti_debug_printer_user_defined_triton_kernel(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
out = torch.zeros_like(x)
add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16)
return out
example_inputs = (
torch.randn(4, 4, device=self.device),
torch.randn(4, 4, device=self.device),
)
kernel_calls = [
("add_kernel_0", 3),
]
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
result, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, Model(), example_inputs
)
# check that the c shim print_tensor_handle call is triggered by the config and injected into the cpp output code as expected
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
# check that the codegen for debug printing around the actual kernel call is as expected
for kernel_call, count in kernel_calls:
FileCheck().check_count(
f"before_launch - {kernel_call}",
count,
).run(code)
FileCheck().check_count(
f"after_launch - {kernel_call}",
count,
).run(code)
def test_aoti_debug_printer_cpp_kernel(self):
if self.device != "cpu":
raise unittest.SkipTest("cpu test case only")
# a simple cpp kernel test case for testing the debug printer codegen
# for cpp kernels on the cpu device.
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
t = torch.tensor(x.size(-1), device="cpu", dtype=torch.float)
t = torch.sqrt(t * 3)
return x * t
example_inputs = (torch.randn(4, 4, device="cpu"),)
kernel_calls = [
("cpp_fused_mul_sqrt_0", 2),
]
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
result, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, Model(), example_inputs
)
# check that the c shim print_tensor_handle call is triggered by the config and injected into the cpp output code as expected
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
# check that the codegen for debug printing around the actual kernel call is as expected
for kernel_call, count in kernel_calls:
FileCheck().check_count(
f"before_launch - {kernel_call}",
count,
).run(code)
FileCheck().check_count(
f"after_launch - {kernel_call}",
count,
).run(code)
def test_aoti_debug_printer_sym_inputs(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
from torch.testing._internal.triton_utils import add_kernel
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
maxlen = max(x.item(), 512)
a = torch.ones(maxlen, device=GPU_TYPE)
b = torch.ones(maxlen, device=GPU_TYPE)
out = torch.zeros_like(a)
# unbacked symint in grid
add_kernel[(1, 1, maxlen)](a, b, out, maxlen, 32)
return out
example_inputs = (torch.randint(high=1024, size=(1,), device=self.device),)
expected_scalar_args = [
"triton_poi_fused_zeros_like_0_xnumel",
"triton_poi_fused_ones_1_xnumel",
"std::max(static_cast<int64_t>(512L), static_cast<int64_t>(u0))",
]
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
result, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, Model(), example_inputs
)
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
for scalar in expected_scalar_args:
FileCheck().check_count(
f"{scalar}",
2,
).run(code)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@skipIfXpu
def test_aoti_debug_printer_fp8_dtype(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, dtype):
super().__init__()
self.out_dtype = dtype
def forward(self, x, weight, bias, scale_a, scale_b):
weight = weight.to(e4m3_type)
output = torch._scaled_mm(
x,
weight,
bias=bias,
out_dtype=self.out_dtype,
scale_a=scale_a,
scale_b=scale_b,
)
return output
dtype = torch.float16
a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
weight_shape = (32, 16)
weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
a_inverse_scale = 1 / a_scale
b_inverse_scale = 1 / b_scale
x_shape = (16, 16)
x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(e4m3_type)
kernel_calls = [
(f"aoti_torch_{GPU_TYPE}__scaled_mm_out", 5),
]
# test default debug printing all tensor values codegen
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
result, code = run_and_get_cpp_code(
AOTIRunnerUtil.legacy_compile,
Model(dtype),
(x, weight, input_bias, a_inverse_scale, b_inverse_scale),
)
# check that the c shim print_tensor_handle call is triggered by the config and injected into the cpp output code as expected
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
# check that the codegen for debug printing around the actual kernel call is as expected and that the float8 dtype is printed as expected
for kernel_call, count in kernel_calls:
FileCheck().check_count(
f"before_launch - {kernel_call}",
count,
).run(code)
FileCheck().check_count(
f"after_launch - {kernel_call}",
count,
).run(code)
def test_aoti_debug_printing_model_inputs_codegen(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b, c):
x = a * 3.14
y = torch.addmm(c, x, b)
z = torch.nn.functional.gelu(y)
return z
example_inputs = (
torch.randn(10, 20, device=GPU_TYPE),
torch.randn(20, 30, device=GPU_TYPE),
torch.randn(10, 30, device=GPU_TYPE),
)
model = Model()
kernel_calls = [
("aoti_model_inputs", 3),
]
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
result, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
# check if the triton kernel is printed as a comment
self.assertEqual("def triton_" in code, True)
# check that the codegen for debug printing around the aoti model inputs is as expected
for kernel_call, count in kernel_calls:
FileCheck().check_count(
f"{kernel_call}",
count,
).run(code)
def test_size_from_multi_output(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
_x, _i = torch.unique(x, sorted=True, return_inverse=True)
_x = _x.detach().clone()
return self.relu(_x), _i
example_inputs = (torch.randn(8, device=self.device),)
self.check_model(Model(), example_inputs)
@dynamo_config.patch({"capture_scalar_outputs": True})
def test_sym_i64_input_codegen(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
from torch.testing._internal.triton_utils import add_kernel
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
x_symint = x.item()
a = torch.ones(x_symint, device=GPU_TYPE)
b = torch.ones(x_symint, device=GPU_TYPE)
out = torch.zeros_like(a)
# unbacked symint in grid
add_kernel[(1, 1, x_symint)](a, b, out, x_symint, 32)
return out
example_inputs = (
torch.randint(high=1024, size=(1,), device=self.device, dtype=torch.int32),
)
# This simple unit test case model generates two triton kernels:
# 1. triton_poi_fused_ones_1:
# triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i64'}
# 2. add_kernel:
# triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr': '*fp32', 'n_elements': 'i64'}
# input u0 was defined as int32_t initially; verify that every downstream kernel var arg derived from it
# gets explicitly declared with its data type in the cpp wrapper codegen code.
expected_scalar_args = [
"buf3, u0",
"buf4, u0",
"buf4, buf5, buf3, u0",
]
if full_aoti_runtime_assert():
# we'll have one more assertion
expected_scalar_args = [
"buf4, u0",
"buf5, u0",
"buf5, buf6, buf4, u0",
]
# check that the new codegen behavior is as expected
result, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, Model(), example_inputs
)
for scalar_line in expected_scalar_args:
FileCheck().check_count(
scalar_line,
1,
).run(code)
self.check_model(Model(), example_inputs)
def test_input_codegen_with_sympy_expr(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class MyModel(torch.nn.Module):
def forward(self, getitem_54, getitem_52, getitem_19, values_2, offsets):
bitwise_or = torch.bitwise_or(getitem_54, getitem_52)
combined = torch.cat([getitem_19, values_2], dim=0)
add = combined + bitwise_or
sliced = values_2[:-1] + offsets
return add, sliced
inps = (
torch.randint(0, 1, (240,), device=GPU_TYPE, dtype=torch.uint8),
torch.randint(0, 1, (240,), device=GPU_TYPE, dtype=torch.uint8),
torch.randn((192,), device=GPU_TYPE),
torch.randn((48,), device=GPU_TYPE),
torch.randint(0, 100, (47,), device=GPU_TYPE, dtype=torch.uint8),
)
dim = torch.export.Dim("dimensionality")
derived_dim = 2 * dim
spec = {
"getitem_54": (Dim.AUTO,), # [s33 + 2*s40 + 1]
"getitem_52": (Dim.AUTO,), # [s33 + 2*s40 + 1]
"getitem_19": (derived_dim,), # [2*s40]
"values_2": (Dim.AUTO,), # [s33 + 1]
"offsets": (Dim.AUTO,), # [s33]
}
self.check_model(MyModel(), inps, dynamic_shapes=spec)
@common_utils.parametrize("mark_unbacked", (True, False))
def test_unbacked_equals_input_size_runtime_assertion(self, mark_unbacked: bool):
# This test checks the unbacked symint runtime assertions, for the following cases:
# (A) an unbacked symint equals an unbacked symint (mark_unbacked=True)
# (B) an unbacked symint equals a backed symint (mark_unbacked=False)
class Model(torch.nn.Module):
def forward(self, a, b, c):
nz = torch.nonzero(a)
ones = a.new_ones([nz.size(0), b.size(0)])
torch._check(ones.size(0) >= 1)
equals = torch.add(ones, c)
return equals
model = Model()
example_inputs = (
torch.ones(64, device=self.device),
b := torch.randn((32,), device=self.device),
c := torch.randn((64, 32), device=self.device),
)
if mark_unbacked:
torch._dynamo.decorators.mark_unbacked(c, 0)
else:
torch._dynamo.mark_dynamic(c, 0)
# Check the runtime assertion is codegen'ed.
so_path, code = run_and_get_cpp_code(
AOTIRunnerUtil.legacy_compile, model, example_inputs
)
lowerbound_check = "u1 >= 1" if mark_unbacked else "u0 >= 2"
FileCheck().check_count(lowerbound_check, 1).run(code)
compiled = AOTIRunnerUtil.legacy_load(self.device, so_path)
compiled(*example_inputs)
# Check the runtime assertion.
with self.assertRaisesRegex(Exception, ""):
unexpected_inputs = (torch.ones(0, device=self.device), b, c)
compiled(*unexpected_inputs)
# Try it again without runtime assertions.
with config.patch({"scalar_asserts": False}):
AOTIRunnerUtil.run_multiple(model, [example_inputs, unexpected_inputs])
def test_none_args_aot_codegen(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
],
key=["n_elements"],
)
@triton.jit
def sin_kernel(
in_ptr0,
out_ptr,
# We want to include an arg known to be 1 at compile time.
# This is because we remove None args from the arg list, which changes the eq_1/constexpr arg indices.
# We want to make sure we recompute these correctly.
EQ_1_ARG,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
if in_ptr0 is not None:
x = tl.load(in_ptr0 + offsets, mask=mask)
else:
x = 0.0
output = tl.sin(x) + EQ_1_ARG
tl.store(out_ptr + offsets, output, mask=mask)
def sin_triton(x, out):
n_elements = out.numel()
sin_kernel[(n_elements,)](x, out, 1, n_elements)
return out
x = torch.randn(65, device=self.device)
out = torch.empty_like(x)
not_none_inputs = (x, out)
none_inputs = (None, out)
# AOTI compilation specializes on either None or non-None inputs
# So we have to check twice here
self.check_model(sin_triton, none_inputs)
self.check_model(sin_triton, not_none_inputs)
@skipIfRocm  # ROCm does not support the config block size in the test suite.
def test_autotune_int64_user_defined_triton_kernel(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
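# The input below has 1379584 x 2048 ~ 2.8e9 elements, which exceeds the int32 range,
# hence the explicit .to(tl.int64) cast on the program id in the user-defined kernel.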
@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0).to(tl.int64)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.library.triton_op("mylib::add", mutates_args=())
def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.empty_like(x)
n_elements = output.numel()
def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
return output
class Model(torch.nn.Module):
def forward(self, x):
x = custom_add(x, x)
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(
x, [512, 512, 512, 512], 1
)
getitem_29 = split_with_sizes_1[0]
return getitem_29 * 3
n = 1379584
try:
buf196 = torch.randint(
0, 100, (n, 2048), dtype=torch.int8, device=self.device
)
example_inputs = (buf196,)
self.check_model(
Model(),
example_inputs,
dynamic_shapes={
"x": (Dim("x", max=1379584), Dim.STATIC),
},
options={"max_autotune": True},
)
except torch.OutOfMemoryError:
# CI can OOM because this test uses too much memory
raise unittest.SkipTest("OOM. Test is too large") from None
@skipIfWindows(
msg="OpenMP crashed application on windows"
) # TODO: (xuhancn) need to root cause and fix.
def test_issue_140766(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(128, 512),
torch.nn.ReLU(),
torch.nn.Linear(512, 128),
)
self.norm = torch.nn.LayerNorm(128)
self.attn = torch.nn.functional.scaled_dot_product_attention
def forward(self, x):
# [2, 128, 4096]
x = x.transpose(1, 2)
# [2, 4096, 128]
for _ in range(2):
x = self.forward_block(x)
return x
def forward_block(self, x):
# x: B, H*W, C
B = x.shape[0]
H, W, C = 64, 64, 128
shortcut = x
x = self.norm(x)
x = x.reshape(B, H, W, C)
# B, H, W, C
x = self.attn(x, x, x)
x = x.reshape(B, H // 8, W // 8, 8, 8, -1)
x = x.transpose(2, 3).reshape(B, H * W, -1)
x = shortcut + x
x = x + self.mlp(self.norm(x))
return x
bs = torch.export.Dim("bs", max=12)
example_inputs = (torch.randn(2, 128, 4096, device=self.device),)
self.check_model(Model(), example_inputs, dynamic_shapes={"x": {0: bs}})
@requires_gpu
def test_d2h_copy(self):
# a device-to-host copy should always keep the same strides
if "cuda" not in self.device:
raise unittest.SkipTest("This test is only for CUDA")
class ToCpuModel(nn.Module):
def forward(self, x):
predictions = x.permute(1, 0)
predictions = torch.nan_to_num(
predictions, nan=0.0, posinf=0.0, neginf=0.0
)
predictions = predictions.to("cpu", non_blocking=True)
p = predictions[0]
ones = p.new_ones(1)
return p, ones
model = ToCpuModel().to(GPU_TYPE)
input_tensor = torch.randn(5, 10, device=GPU_TYPE).to(dtype=torch.float16)
ep = torch.export.export(model, (input_tensor,))
package_path = torch._inductor.aoti_compile_and_package(ep)
aoti_model = torch._inductor.aoti_load_package(package_path)
expect_res = model(input_tensor)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
) as prof:
true_res = aoti_model(input_tensor)
self.assertEqual(expect_res, true_res)
all_ops = [event.key for event in prof.key_averages()]
self.assertTrue(not any("aten::contiguous" in op for op in all_ops))
def test_so_without_weight(self):
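# Compile the same model with and without packaging constants into the .so, compare the file
# sizes, then attach the weights at runtime via update_constant_buffer and check the output.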
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M, N, K = 128, 2048, 4096
model = Model(N, K, self.device)
a = torch.randn(M, K, device=self.device)
example_inputs = (a,)
with (
torch.no_grad(),
config.patch(
{
"always_keep_tensor_constants": True,
"aot_inductor.package_constants_in_so": True,
}
),
):
so_path = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
with (
torch.no_grad(),
config.patch(
{
"always_keep_tensor_constants": True,
"aot_inductor.package_constants_in_so": False,
}
),
):
so_path_weightless = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
self.assertTrue(os.path.getsize(so_path) > 10_000_000)
self.assertTrue(os.path.getsize(so_path_weightless) < 10_000_000)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path_weightless)
# Let's check whether the model has the correct constant name mapping.
expected_original_fqns = {
"L__self___weight": "L__self___weight",
"L__self___bias": "L__self___bias",
}
self.assertEqual(
expected_original_fqns, runner.get_constant_names_to_original_fqns()
)
def runner_call(*args, **kwargs):
import torch.fx._pytree as fx_pytree
call_spec = runner.get_call_spec()
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs)
return pytree.tree_unflatten(flat_outputs, out_spec)
test_inputs = torch.randn(M, K, device=self.device)
attach_weights = {
"L__self___weight": model.weight,
"L__self___bias": model.bias,
}
runner.update_constant_buffer(attach_weights, False, False)
expected = model(test_inputs)
output = runner_call(test_inputs)
atol, rtol = 3e-4, 3e-4
self.assertEqual(expected, output, atol=atol, rtol=rtol)
def test_weight_on_disk_legacy(self):
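# Package the weights on disk (pickle_weights format) instead of inside the .so, then reload
# them through load_pt2 with load_weights_from_disk=True and check the output.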
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M, N, K = 128, 2048, 4096
model = Model(N, K, self.device)
a = torch.randn(M, K, device=self.device)
example_inputs = (a,)
with (
torch.no_grad(),
config.patch(
{
"always_keep_tensor_constants": True,
"aot_inductor.package_constants_in_so": False,
"aot_inductor.package_constants_on_disk_format": "pickle_weights",
"aot_inductor.package": True,
}
),
):
aoti_files = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
with WritableTempFile(suffix=".pt2") as f:
package_path = package_aoti(
f.name,
{"model": aoti_files},
)
pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
loaded1 = pt2_contents.aoti_runners["model"]
atol, rtol = 3e-4, 3e-4
self.assertEqual(loaded1(a), model(a), atol=atol, rtol=rtol)
def test_extract_constants_map(self):
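# Verify that extract_constants_map returns the active or inactive constants depending on the
# use_inactive flag, before and after swapping the double-buffered weights.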
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M, N, K = 8, 6, 16
model = Model(N, K, self.device)
a = torch.randn(M, K, device=self.device)
example_inputs = (a,)
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
so_path = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
def runner_call(*args, **kwargs):
import torch.fx._pytree as fx_pytree
call_spec = runner.get_call_spec()
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs)
return pytree.tree_unflatten(flat_outputs, out_spec)
test_inputs = torch.randn(M, K, device=self.device)
expected = model(test_inputs)
output = runner_call(test_inputs)
self.assertEqual(expected, output)
original_weights = {
"L__self___weight": model.weight,
"L__self___bias": model.bias,
}
new_weights = {
"L__self___weight": torch.randn(N, K, device=self.device),
"L__self___bias": torch.randn(N, device=self.device),
}
# Extract weights with use_inactive=False; this should return the current (active) weights.
extracted_original_weights = runner.extract_constants_map(False)
self.assertEqual(original_weights, extracted_original_weights)
# Update the inactive buffer with new_weights, then extract the inactive weights.
runner.update_constant_buffer(new_weights, True, False)
extracted_new_weights = runner.extract_constants_map(True)
self.assertEqual(new_weights, extracted_new_weights)
# Swap the constant buffers; this should give us the opposite weights.
runner.swap_constant_buffer()
extracted_inactive_weights = runner.extract_constants_map(True)
extracted_active_weights = runner.extract_constants_map(False)
self.assertEqual(original_weights, extracted_inactive_weights)
self.assertEqual(new_weights, extracted_active_weights)
def test_update_constant_buffer(self):
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M, N, K = 8, 6, 16
model = Model(N, K, self.device)
a = torch.randn(M, K, device=self.device)
example_inputs = (a,)
# Attribute naming has changed in the new export API, so still use the legacy API here.
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
so_path = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
# Let's check whether the model has the correct constant name mapping.
expected_original_fqns = {
"L__self___weight": "L__self___weight",
"L__self___bias": "L__self___bias",
}
self.assertEqual(
expected_original_fqns, runner.get_constant_names_to_original_fqns()
)
def runner_call(*args, **kwargs):
import torch.fx._pytree as fx_pytree
call_spec = runner.get_call_spec()
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs)
return pytree.tree_unflatten(flat_outputs, out_spec)
test_inputs = torch.randn(M, K, device=self.device)
expected = model(test_inputs)
output = runner_call(test_inputs)
self.assertEqual(expected, output)
new_weights = {
"L__self___weight": torch.randn(N, K, device=self.device),
"L__self___bias": torch.randn(N, device=self.device),
}
runner.update_constant_buffer(new_weights, False, False)
new_output = runner_call(test_inputs)
new_expected = torch.nn.functional.linear(
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
)
self.assertEqual(new_expected, new_output)
def test_update_constant_buffer_simple(self):
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn((3, 3), device=device)
def forward(self, a):
return a + self.weight
model = Model(self.device)
a = torch.randn((3, 3), device=self.device)
example_inputs = (a,)
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
so_path = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
# Let's check whether the model has the correct constant name mapping.
expected_original_fqns = {
"L__self___weight": "L__self___weight",
}
self.assertEqual(
expected_original_fqns, runner.get_constant_names_to_original_fqns()
)
test_inputs = torch.randn((3, 3), device=self.device)
new_weight = torch.randn((3, 3), device=self.device)
model.weight = new_weight
attach_weights = {"L__self___weight": new_weight}
runner.update_constant_buffer(attach_weights, False, False, False)
expected = model(test_inputs)
def runner_call(*args, **kwargs):
call_spec = runner.get_call_spec() # type: ignore[attr-defined]
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = pytree.tree_flatten((args, kwargs))[0]
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined]
return pytree.tree_unflatten(flat_outputs, out_spec)
output = runner_call(test_inputs)
self.assertEqual(expected, output)
def test_update_inactive_constant_buffer(self):
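# Update the inactive constant buffer and verify the output only changes after
# swap_constant_buffer makes it active.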
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M, N, K = 8, 6, 16
model = Model(N, K, self.device)
a = torch.randn(M, K, device=self.device)
example_inputs = (a,)
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
so_path = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
def runner_call(*args, **kwargs):
import torch.fx._pytree as fx_pytree
call_spec = runner.get_call_spec()
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs)
return pytree.tree_unflatten(flat_outputs, out_spec)
test_inputs = torch.randn(M, K, device=self.device)
expected = model(test_inputs)
output = runner_call(test_inputs)
self.assertEqual(expected, output, atol=1e-3, rtol=1e-3)
new_weights = {
"L__self___weight": torch.randn(N, K, device=self.device),
"L__self___bias": torch.randn(N, device=self.device),
}
new_expected = torch.nn.functional.linear(
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
)
runner.update_constant_buffer(new_weights, True, False)
output_before_swap = runner_call(test_inputs)
runner.swap_constant_buffer()
output_after_swap = runner_call(test_inputs)
self.assertEqual(expected, output_before_swap)
self.assertEqual(new_expected, output_after_swap)
def test_free_inactive_buffer(self):
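# Exercise free_inactive_constant_buffer: free the inactive buffer after a swap, then refill
# it with a fresh set of weights and run again.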
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M, N, K = 8, 6, 16
model = Model(N, K, self.device)
a = torch.randn(M, K, device=self.device)
example_inputs = (a,)
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
so_path = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
def runner_call(*args, **kwargs):
import torch.fx._pytree as fx_pytree
call_spec = runner.get_call_spec()
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs)
return pytree.tree_unflatten(flat_outputs, out_spec)
test_inputs = torch.randn(M, K, device=self.device)
expected = model(test_inputs)
output = runner_call(test_inputs)
# Check the outputs to make sure the model is correct here.
self.assertEqual(expected, output)
new_weights = {
"L__self___weight": torch.randn(N, K, device=self.device),
"L__self___bias": torch.randn(N, device=self.device),
}
new_expected = torch.nn.functional.linear(
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
)
runner.update_constant_buffer(new_weights, True, False)
# Make sure we have swapped the buffers
runner.swap_constant_buffer()
output_after_swap = runner_call(test_inputs)
self.assertEqual(new_expected, output_after_swap)
# Free the secondary buffer
runner.free_inactive_constant_buffer()
# Create a new set of weights to refill into the already freed buffer.
new_weights_1 = {
"L__self___weight": torch.randn(N, K, device=self.device),
"L__self___bias": torch.randn(N, device=self.device),
}
new_expected_1 = torch.nn.functional.linear(
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
)
runner.update_constant_buffer(new_weights_1, True, False)
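# Note: no swap_constant_buffer() here, so the active buffer still holds new_weights and the
# output below should match new_expected_1 (which was computed above from new_weights).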
output_after_swap_1 = runner_call(test_inputs)
self.assertEqual(new_expected_1, output_after_swap_1)
runner.free_inactive_constant_buffer()
def test_update_user_managed_buffer(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.weight = torch.randn(n, k, device=device)
self.bias = torch.randn(n, device=device)
def forward(self, a):
return torch.nn.functional.linear(a, self.weight, self.bias)
M, N, K = 1024, 4096, 4096
model = Model(N, K, self.device)
a = torch.randn(M, K, device=self.device)
example_inputs = (a,)
# Attribute naming has changed in the new export API, so still use the legacy API here.
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
so_path = AOTIRunnerUtil.legacy_compile(
model=model,
example_inputs=example_inputs,
)
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
def runner_call(*args, **kwargs):
import torch.fx._pytree as fx_pytree
call_spec = runner.get_call_spec()
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs)
return pytree.tree_unflatten(flat_outputs, out_spec)
test_inputs = torch.randn(M, K, device=self.device)
expected = model(test_inputs)
output = runner_call(test_inputs)
self.assertEqual(expected, output, atol=1e-3, rtol=1e-3)
new_weights = {
"L__self___weight": torch.randn(N, K, device=self.device),
"L__self___bias": torch.randn(N, device=self.device),
}
mem_before, _ = torch.cuda.mem_get_info(self.device)
        # Without a user-managed buffer, the update allocates device memory,
        # so there should be less free memory afterwards.
runner.update_constant_buffer(new_weights, True, False, False)
mem_after, _ = torch.cuda.mem_get_info(self.device)
self.assertGreater(mem_before, mem_after)
runner.swap_constant_buffer()
new_output = runner_call(test_inputs)
new_expected = torch.nn.functional.linear(
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
)
self.assertEqual(new_expected, new_output, atol=1e-3, rtol=1e-3)
        # Substitute the tensors in place; without a user-managed buffer the runner
        # keeps its own copy, so the result should differ from a freshly computed
        # expectation.
new_weights["L__self___weight"].add_(1)
new_weights["L__self___bias"].add_(1)
new_output = runner_call(test_inputs)
# Same as the previous result
self.assertEqual(new_expected, new_output, atol=1e-3, rtol=1e-3)
new_expected = torch.nn.functional.linear(
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
)
        # Differs from the freshly computed expectation
self.assertNotEqual(new_expected, new_output)
# Clear out all buffers
runner.free_inactive_constant_buffer()
runner.swap_constant_buffer()
runner.free_inactive_constant_buffer()
new_weights = {
"L__self___weight": torch.randn(N, K, device=self.device),
"L__self___bias": torch.randn(N, device=self.device),
}
mem_before, _ = torch.cuda.mem_get_info(self.device)
        # With a user-managed buffer, no extra device memory is allocated,
        # so free memory should stay the same.
runner.update_constant_buffer(new_weights, True, False, True)
mem_after, _ = torch.cuda.mem_get_info(self.device)
self.assertEqual(mem_before, mem_after, atol=1e-3, rtol=1e-3)
runner.swap_constant_buffer()
new_output = runner_call(test_inputs)
new_expected = torch.nn.functional.linear(
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
)
self.assertEqual(new_expected, new_output, atol=1e-3, rtol=1e-3)
        # Substitute the tensors in place; with a user-managed buffer the change is
        # visible to the runner, so the result should match the recomputed expectation.
new_weights["L__self___weight"].add_(1)
new_weights["L__self___bias"].add_(1)
new_output = runner_call(test_inputs)
new_expected = torch.nn.functional.linear(
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
)
self.assertEqual(new_expected, new_output, atol=1e-3, rtol=1e-3)
new_weights = {
"L__self___weight": torch.randn(N, K, device=self.device),
"L__self___bias": torch.randn(N, device=self.device),
}
runner.update_constant_buffer(new_weights, True, False, True)
runner.swap_constant_buffer()
model.weight = torch.nn.Parameter(new_weights["L__self___weight"])
model.bias = torch.nn.Parameter(new_weights["L__self___bias"])
updated_state_dict = {
"weight": torch.ones_like(model.weight),
"bias": torch.zeros_like(model.bias),
}
model.load_state_dict(updated_state_dict)
new_output = runner_call(test_inputs)
expected_output = model(test_inputs)
torch.testing.assert_close(new_output, expected_output, atol=1e-3, rtol=1e-3)
with self.assertRaises(AssertionError):
torch.testing.assert_close(new_expected, new_output, atol=1e-3, rtol=1e-3)
    def test_cond_share_predicate(self):
class Model(torch.nn.Module):
def forward(self, predicate, x):
y = torch.cond(
predicate,
lambda: x + 1,
lambda: x + 2,
)
z = torch.cond(
predicate,
lambda: y + 1,
lambda: y + 2,
)
return (z,)
example_inputs = (
torch.tensor([True]).to(self.device),
torch.tensor([1, 2, 3]).to(self.device),
)
self.check_model(Model(), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"To enable after the C shim FC window ends",
)
def test_misaligned_input_1(self):
if self.device != "cuda":
raise unittest.SkipTest("CUDA test only")
class Model(torch.nn.Module):
def forward(self, x):
return x.sin() + x.cos()
N = 64 * 64 * 64 + 64
arg = torch.randn(N, device=self.device)
example_inputs = (arg,)
model = Model()
expected = model(*example_inputs)
package_path = AOTIRunnerUtil.compile(model, example_inputs)
optimized = torch._inductor.aoti_load_package(package_path)
        # If the model is compiled with aligned inputs, the generated
        # code will check input alignment at runtime
self.code_check_count(
model, example_inputs, "aoti_torch_clone_preserve_strides", 1
)
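        # Build a deliberately misaligned input: allocate one extra element and
        # slice it off, which offsets the data pointer by one element so the
        # alignment assumed at compile time no longer holds.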
misaligned_arg = torch.zeros(N + 1, device=self.device)
misaligned_arg = misaligned_arg[1:]
misaligned_arg.copy_(arg)
actual = optimized(misaligned_arg)
torch.testing.assert_close(actual, expected)
def test_misaligned_input_2(self):
if self.device != "cuda":
raise unittest.SkipTest("CUDA test only")
class Model(torch.nn.Module):
def forward(self, x):
return x.sin() + x.cos()
N = 64 * 64 * 64 + 64
arg = torch.randn(N, device=self.device)
misaligned_arg = torch.zeros(N + 1, device=self.device)
misaligned_arg = misaligned_arg[1:]
misaligned_arg.copy_(arg)
example_inputs = (misaligned_arg,)
model = Model()
self.check_model(model, example_inputs)
# If the model is already compiled with a misaligned input, the
# generated code should NOT contain an alignment check for that input.
self.code_check_count(
model, example_inputs, "aoti_torch_clone_preserve_strides", 0
)
def test_autotuning_args_reuse(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y):
x_out = torch.empty_strided(
(x.size()[0], x.size()[1]), (x.size()[1], 1), device=GPU_TYPE
)
x_out = torch.permute(x_out, [0, 1])
add_kernel_autotuned[(4,)](x, x, x_out, 16)
y_out = torch.empty_strided(
(y.size()[0], y.size()[1]), (y.size()[1], 1), device=GPU_TYPE
)
y_out = torch.permute(y_out, [0, 1])
add_kernel_autotuned[(64,)](y, y, y_out, 64)
sub_kernel_autotuned[(4,)](x, x, x_out, 16)
return x_out, y_out
example_inputs = (
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(8, 8, device=GPU_TYPE),
)
dim0_x = Dim("dim0_x", min=1, max=2048)
dim0_y = Dim("dim0_y", min=1, max=2048)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
self.check_model(
Model(),
example_inputs,
dynamic_shapes=dynamic_shapes,
options={"max_autotune": True},
)
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
@unittest.skipIf(
TEST_MPS and MACOS_VERSION < 14.0,
"FFT operations are only supported on MacOS 14+",
)
def test_stft(self):
N_FFT = 400
HOP_LENGTH = 160
class Model(torch.nn.Module):
def forward(self, x):
window = torch.hann_window(N_FFT, device=x.device)
stft = torch.stft(
x, N_FFT, HOP_LENGTH, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2
return magnitudes
model = Model()
example_inputs = (torch.randn(500, device=self.device),)
self.check_model(model, example_inputs)
@skipIfXpu
def test_conv3d(self):
if self.device != GPU_TYPE or not is_big_gpu():
raise unittest.SkipTest("requires modern GPU to run max-autotune")
if not _has_sufficient_memory(self.device, 2**35):
raise unittest.SkipTest("insufficient memory")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
convert_element_type_1271,
convert_element_type_1272,
convert_element_type_1273,
):
return torch.ops.aten.convolution.default(
convert_element_type_1271,
convert_element_type_1272,
convert_element_type_1273,
[1, 1],
[1, 1],
[1, 1],
False,
[0, 0],
1,
)
example_inputs = (
torch.randn(1, 64, 5160, 5160, device=self.device),
torch.randn(3, 64, 3, 3, device=self.device),
torch.randn(3, device=self.device),
)
dynamic_shapes = {
"convert_element_type_1271": {
3: torch.export.Dim.DYNAMIC,
4: torch.export.Dim.DYNAMIC,
},
"convert_element_type_1272": None,
"convert_element_type_1273": None,
}
with config.patch(
{
"max_autotune": True,
"max_autotune_conv_backends": "TRITON",
}
):
self.check_model(
Model(),
example_inputs,
atol=0.1,
rtol=1e-3,
dynamic_shapes=dynamic_shapes,
)
def test__int_mm(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
return torch._int_mm(x, y)
example_inputs = (
torch.randint(-10, 10, (64, 32), device=self.device, dtype=torch.int8),
torch.randint(-10, 10, (32, 64), device=self.device, dtype=torch.int8),
)
self.check_model(Model(), example_inputs)
@skipIfMPS
@skipIfXpu(
msg="aten::convert_weight_to_int4pack is not currently implemented for XPU"
)
@parametrize("m", [32])
@parametrize("n", [64])
@parametrize("q_group", [32, 64])
@parametrize("num_groups", [1, 2])
def test__weight_int4pack_mm(self, m, n, q_group, num_groups):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, weight, scale_and_zeros) -> None:
super().__init__()
self.weight = weight
self.scale_and_zeros = scale_and_zeros
def forward(self, a):
return torch._weight_int4pack_mm(
a, self.weight, q_group, self.scale_and_zeros
)
def convert_weight_to_int4pack(b):
b_int32, b_scales_and_zeros = _group_quantize_tensor(
b, n_bit=4, q_group_size=q_group
)
b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)
return b_int4pack, b_scales_and_zeros
k = q_group * num_groups
a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b)
model = Model(b_int4pack, b_scales_and_zeros_f32)
self.check_model(model, (a,))
@parametrize("m", [32])
@parametrize("n", [64])
@parametrize("q_group", [32, 64])
@parametrize("num_groups", [1, 2])
def test__weight_int4pack_mm_with_scales_and_zeros(self, m, n, q_group, num_groups):
if "xpu" not in self.device:
raise unittest.SkipTest("requires Intel GPU")
class Model(torch.nn.Module):
def __init__(self, weight, scale, zeros) -> None:
super().__init__()
self.weight = weight
self.scale = scale
self.zeros = zeros
def forward(self, a):
return torch._weight_int4pack_mm_with_scales_and_zeros(
a, self.weight, q_group, self.scale, self.zeros
)
def _group_quantize_tensor_xpu(w, n_bit=4, q_group_size=16):
# w [k, n] = [32, 48]
assert w.dim() == 2
# w [n, k] = [48, 32]
w = w.transpose(0, 1).contiguous()
assert q_group_size > 1
assert w.shape[-1] % q_group_size == 0
# to_quant: [n * k / group_size, group_size]
to_quant = w.reshape(-1, q_group_size)
assert torch.isnan(to_quant).sum() == 0
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-6) / max_int
assert torch.isnan(scales).sum() == 0
zeros = min_int - min_val.div(scales).round()
zeros = torch.clamp(zeros, min_int, max_int)
zeros = zeros.to(torch.int8)
assert torch.isnan(zeros).sum() == 0
out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int)
assert torch.isnan(out).sum() == 0
# [n, k]
out = out.to(dtype=torch.int32).reshape(w.shape)
if out.device != torch.device("cpu"):
out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8)
# Scales and zeros for the same q-group should be contiguous, so we can
# load as a 32-bit word
scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous()
zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous()
return out, scales, zeros
def convert_weight_to_int4pack(b):
# b_uint8 [n, k //2]
b_uint8, scales, zeros = _group_quantize_tensor_xpu(
b, n_bit=4, q_group_size=q_group
)
# b_int4pack [k//8, n]
b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=2)
return b_int4pack, scales, zeros
k = q_group * num_groups
a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b)
model = Model(b_int4pack, b_scales, zeros_int8)
self.check_model(model, (a,))
def test_assert_tensor_meta(self):
class Module(torch.nn.Module):
def forward(self, x):
torch.ops.aten._assert_tensor_metadata.default(
x,
dtype=torch.int32,
)
return (x + 1,)
example_inputs = (torch.tensor(1, dtype=torch.int32),)
with config.patch(
{
"implicit_fallbacks": False,
}
):
self.check_model(
Module(),
example_inputs,
atol=0.1,
rtol=1e-3,
)
@runOnRocm
def test_rocm_triton_autotuning(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y, m):
_M, K = x.shape
K, N = y.shape
M = torch.abs(m)
out = torch.empty((_M, N), device=x.device, dtype=torch.float32)
grid = lambda META: ( # noqa: E731
triton.cdiv(
4096 * 2046, META["BLOCK_SIZE_M"] * META["BLOCK_SIZE_N"]
),
)
strange_config_matmul_kernel[grid](
x,
y,
out,
M,
N,
K,
)
return out
x = torch.randn(4096, 1024, device=self.device)
y = torch.randn(1024, 2048, device=self.device)
m = torch.tensor([4096], dtype=torch.int32, device=self.device)
with (
torch.no_grad(),
config.patch(
{
"triton.autotune_with_sample_inputs": True,
"aot_inductor.allow_stack_allocation": self.allow_stack_allocation,
"aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
}
),
):
torch._export.aot_compile(Model(), (x, y, m))
    @skipIfRocm  # ROCm does not support the config block size used in this test suite.
def test_triton_autotuning(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y, m):
_M, K = x.shape
K, N = y.shape
M = torch.abs(m)
out = torch.empty((_M, N), device=x.device, dtype=torch.float32)
grid = lambda META: ( # noqa: E731
triton.cdiv(
4096 * 2046, META["BLOCK_SIZE_M"] * META["BLOCK_SIZE_N"]
),
)
strange_config_matmul_kernel[grid](
x,
y,
out,
M,
N,
K,
)
return out
x = torch.randn(4096, 1024, device=self.device)
y = torch.randn(1024, 2048, device=self.device)
m = torch.tensor([4096], dtype=torch.int32, device=self.device)
with config.patch("triton.autotune_with_sample_inputs", True):
            # The best tuned config on XPU differs from the one on CUDA.
grid_0 = 32736 if GPU_TYPE == "xpu" else 1023
self.code_check_count(
Model(), (x, y, m), f"uint32_t grid_0 = {grid_0}L;", 1
)
    @skipIfRocm  # ROCm does not support the config block size used in this test suite.
def test_triton_mutated_autotuning(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.jit
def add_one_kernel(X, Y, N):
pid = tl.program_id(axis=0)
block_start = pid
offsets = block_start + tl.arange(0, 1)
x = tl.load(X + offsets, mask=offsets < N)
y = x + 1
tl.store(Y + offsets, y, mask=offsets < N)
class Model(torch.nn.Module):
def forward(self, x, y, m):
_M, K = x.shape
K, N = y.shape
M = torch.empty((1), device=x.device, dtype=torch.int32)
add_one_kernel[(1,)](m, M, 1)
out = torch.empty((_M, N), device=x.device, dtype=torch.float32)
grid = lambda META: ( # noqa: E731
triton.cdiv(
4096 * 2046, META["BLOCK_SIZE_M"] * META["BLOCK_SIZE_N"]
),
)
strange_config_matmul_kernel[grid](
x,
y,
out,
M,
N,
K,
)
return out
x = torch.randn(4096, 1024, device=self.device)
y = torch.randn(1024, 2048, device=self.device)
m = torch.tensor([4095], dtype=torch.int32, device=self.device)
with config.patch("triton.autotune_with_sample_inputs", True):
            # The best tuned config on XPU differs from the one on CUDA.
grid_0 = 32736 if GPU_TYPE == "xpu" else 1023
self.code_check_count(
Model(), (x, y, m), f"uint32_t grid_0 = {grid_0}L;", 1
)
@skipIfRocm
@patch.dict(os.environ, {"TRITON_DEBUG": "1"})
def test_triton_dynamic_launcher_grid(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
],
key=["numel"],
)
@triton.jit
def add_one_kernel(X, Y, numel, BLOCK_SIZE: "tl.constexpr"):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
tl.device_assert(block_start < numel)
offsets = block_start + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + offsets)
y = x + 1
tl.store(Y + offsets, y)
class Model(torch.nn.Module):
def forward(self, x, value):
numel = value.item()
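                # `numel` is a data-dependent value read from a tensor, so the
                # launch grid below must be computed at runtime.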
out = torch.zeros_like(x, dtype=torch.float16)
grid = lambda META: ( # noqa: E731
triton.cdiv(numel, META["BLOCK_SIZE"]),
)
add_one_kernel[grid](x, out, numel)
return out
example_inputs = (
torch.randn(1024, device=self.device),
torch.tensor([1024], dtype=torch.int32, device=self.device),
)
with config.patch("triton.autotune_with_sample_inputs", True):
dim0_x = Dim("dim0_x", min=2, max=8192)
dynamic_shapes = {"x": {0: dim0_x}, "value": {0: Dim.AUTO}}
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
@skipIfRocm
@patch.dict(os.environ, {"TRITON_DEBUG": "1"})
def test_triton_dynamic_launcher_grid_infer_from_tensor(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
],
key=["numel"],
)
@triton.jit
def add_one_kernel(X, Y, numel, BLOCK_SIZE: "tl.constexpr"):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
tl.device_assert(block_start < numel)
offsets = block_start + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + offsets)
y = x + 1
tl.store(Y + offsets, y)
class Model(torch.nn.Module):
def forward(self, x, dim_D):
numel = x.shape[1] * dim_D.item()
x = x.repeat(dim_D, 1)
out = torch.zeros_like(x, dtype=torch.float16)
grid = lambda META: ( # noqa: E731
triton.cdiv(numel, META["BLOCK_SIZE"]),
)
add_one_kernel[grid](x, out, numel)
return out
example_inputs = (
torch.randn(1, 1024, device=self.device),
torch.tensor([2], dtype=torch.int32, device=self.device),
)
with config.patch("triton.autotune_with_sample_inputs", True):
dim1_x = Dim("dim1_x", min=2, max=8192)
dynamic_shapes = {"x": {0: Dim.AUTO, 1: dim1_x}, "dim_D": {0: Dim.AUTO}}
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
def test_composed_dynamic_size(self):
class Model(torch.nn.Module):
def forward(self, x):
return x + 1
example_inputs = (torch.randn(10, device=self.device),)
dim = torch.export.Dim("dim_0")
dim_even = 2 * dim
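        # Derived dim: x's first dimension is constrained to even sizes (2 * dim_0).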
dynamic_shapes = {
"x": {0: dim_even},
}
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
def test_boolean_indexing(self):
class Model(torch.nn.Module):
def forward(self, x, y, z, x1, z1):
a = x[y]
a1 = x1[y]
b = torch.cat([a, z], dim=1)
b1 = torch.cat([a1, z1], dim=1)
return b, b1
x = torch.randn(3, 5, device=self.device)
y = torch.tensor([0, 1, 1], dtype=torch.bool, device=self.device)
z = torch.randn(2, 4, device=self.device)
x1 = torch.randn(3, 5, device=self.device)
z1 = torch.randn(2, 4, device=self.device)
example_inputs = (x, y, z, x1, z1)
s0 = Dim("s0", min=0, max=10240)
s1 = Dim("s1", min=0, max=10240)
s2 = Dim("s2", min=0, max=10240)
s3 = Dim("s3", min=0, max=10240)
dynamic_shapes = {
"x": {0: s0, 1: s1},
"y": {0: s0},
"z": {0: s2, 1: s3},
"x1": {0: s0, 1: s1},
"z1": {0: s2, 1: s3},
}
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
def test_sym_expr_indexing(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
episode_builder_position_encoding_observations_weight,
add_15,
add_16,
add_17,
add_18,
add_13,
):
arange_1 = torch.ops.aten.arange.start(
180,
181,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
add_14 = torch.ops.aten.add.Tensor(arange_1, 198)
arange_1 = None
stack_1 = torch.ops.aten.stack.default(
[add_13, add_14, add_15, add_16, add_17, add_18]
)
add_13 = add_14 = add_15 = add_16 = add_17 = add_18 = None
select_13 = torch.ops.aten.select.int(stack_1, 0, 0)
stack_1 = None
embedding_11 = torch.ops.aten.embedding.default(
episode_builder_position_encoding_observations_weight, select_13
)
episode_builder_position_encoding_observations_weight = select_13 = None
return (embedding_11,)
# Embedding weight: vocab_size x emb_dim
episode_builder_position_encoding_observations_weight = torch.randn(
100, 16, device=self.device
)
        # These six must all be 1-D (shape [1]) with the same dtype; use Long for embedding indices
add_13 = torch.tensor(
[7], dtype=torch.long, device=self.device
) # this one is used as the index
add_15 = torch.tensor([5], dtype=torch.long, device=self.device)
add_16 = torch.tensor([6], dtype=torch.long, device=self.device)
add_17 = torch.tensor([7], dtype=torch.long, device=self.device)
add_18 = torch.tensor([8], dtype=torch.long, device=self.device)
# Instantiate and run
m = Repro().to(self.device)
example_inputs = (
episode_builder_position_encoding_observations_weight,
add_15,
add_16,
add_17,
add_18,
add_13,
)
self.check_model(m, example_inputs)
def test_with_cudagraphs(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
# define CUDAGraph handling wrapper (only works with kwargs for simplicity)
def cudagraph(f):
_graphs = {}
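            # Cache one captured graph, plus its static input/output tensors,
            # per input-shape signature (hashed below).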
def f_(**kwargs):
key = hash(
tuple(
tuple(kwargs[a].shape)
for a in sorted(kwargs.keys())
if isinstance(kwargs[a], torch.Tensor)
)
)
if key in _graphs:
wrapped, *_ = _graphs[key]
return wrapped(**kwargs)
g = torch.cuda.CUDAGraph()
in_tensors = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in kwargs.items()
}
f(**in_tensors) # stream warmup
with torch.cuda.graph(g):
out_tensors = f(**in_tensors)
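                # On replay, copy the new inputs into the captured tensors, replay
                # the graph, and return clones of the captured outputs.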
def wrapped(**kwargs):
for key in kwargs:
in_tensors[key].copy_(kwargs[key])
g.replay()
if isinstance(out_tensors, torch.Tensor):
return out_tensors.clone()
elif isinstance(out_tensors, (list, tuple)):
return type(out_tensors)(o.clone() for o in out_tensors)
raise ValueError("unsupported output type encountered")
_graphs[key] = (wrapped, g, in_tensors, out_tensors)
return wrapped(**kwargs)
return f_
# define a simple model
model = torch.nn.Linear(10, 20).to(device=self.device)
# export + AOTI
model_kwargs = {
"input": torch.randn(3, 10, device=self.device),
}
ep = torch.export.export(model, args=(), kwargs=model_kwargs, strict=True)
optimized = torch._inductor.aoti_load_package(
torch._inductor.aoti_compile_and_package(
ep,
inductor_configs={"max_autotune": True},
),
# NB: this flag avoids a CUDAGraph + AOTI runtime multi-threading conflict
# "Error: operation not permitted when stream is capturing"
run_single_threaded=True,
)
# enable CUDAGraphs
optimized = cudagraph(optimized)
# warmup -> run with CUDAGraphs
for _ in range(3):
optimized(**model_kwargs)
# compare against eager
self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
def test_custom_op_in_subgraph(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define(
"mylib::foo_add1",
"(Tensor a) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl("mylib::foo_add1", "CompositeExplicitAutograd", lib=lib)
@torch.library.register_fake("mylib::foo_add1", lib=lib)
def foo_add1_impl(a: torch.Tensor) -> torch.Tensor:
return a + 1
torch.library.define(
"mylib::foo_add2",
"(Tensor a) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl("mylib::foo_add2", "CompositeExplicitAutograd", lib=lib)
@torch.library.register_fake("mylib::foo_add2", lib=lib)
def foo_add2_impl(a: torch.Tensor) -> torch.Tensor:
return a + 2
class M(torch.nn.Module):
def forward(self, x):
return torch.cond(
x.shape[0] < 5,
torch.ops.mylib.foo_add1,
torch.ops.mylib.foo_add2,
(x,),
)
list_example_inputs = [
(torch.ones(6, device=self.device),),
(torch.ones(3, device=self.device),),
]
self.check_model_with_multiple_inputs(
M(), list_example_inputs, dynamic_shapes=({0: Dim.DYNAMIC},)
)
def test_clamp_decomposition(self):
class Model1(torch.nn.Module):
def forward(self, x):
return x.clamp(min=1.5)
class Model2(torch.nn.Module):
def forward(self, x):
return x.clamp(min=2)
x = torch.randint(4, (4,))
# the output should have float32 type, not int
self.check_model(Model1(), (x,))
# the output should have int type
self.check_model(Model2(), (x,))
def test_upper_bound_i64(self):
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
inp = (
torch.randint(0, 100, (2**18,), device=self.device, dtype=torch.int8),
torch.tensor([4], device=self.device, dtype=torch.int8),
)
ep = torch.export.export(
Model(),
inp,
dynamic_shapes=({0: Dim("d", min=0, max=2**33)}, {0: Dim.STATIC}),
)
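        # The declared upper bound (2**33) exceeds int32, so the generated code
        # should use 64-bit index arithmetic.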
so_path = torch._inductor.aot_compile(ep.module(), inp)
m = torch._export.aot_load(so_path, self.device)
self.assertEqual(Model()(*inp), m(*inp))
del inp
inp = (
torch.randint(0, 100, (3 * 2**30,), device=self.device, dtype=torch.int8),
torch.tensor([4], device=self.device, dtype=torch.int8),
)
        # Don't check the accuracy of the result, to reduce memory usage;
        # this test mostly ensures there's no IMA (illegal memory access).
m(*inp)
def test_using_model_name_for_files(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
model = Model().to(self.device)
with torch.no_grad():
package_path: str = AOTIRunnerUtil.compile(
model,
example_inputs,
inductor_configs={
"aot_inductor.model_name_for_generated_files": "test_model"
},
)
with zipfile.ZipFile(package_path, "r") as zip_ref:
all_files = zip_ref.namelist()
base_dir = "test_model.wrapper/data/aotinductor/model/test_model"
ext_type = get_module_ext_type()
self.assertTrue(f"{base_dir}.wrapper.cpp" in all_files)
self.assertTrue(f"{base_dir}.kernel.cpp" in all_files)
self.assertTrue(f"{base_dir}.wrapper.{ext_type}" in all_files)
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
self.assertEqual(aot_inductor_module(*example_inputs), model(*example_inputs))
def test_copy_non_blocking_is_pinned(self):
if self.device == "cpu" or self.device == "mps":
raise unittest.SkipTest("only matters for device-to-cpu copy")
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b):
a_cpu = a.to(device="cpu", non_blocking=True)
b_cpu = b.to(device="cpu", non_blocking=True)
a_to_cpu_event = torch.Event()
a_to_cpu_event.record()
a_to_cpu_event.synchronize()
return torch.cat([a_cpu, b_cpu])
model = Model()
a = torch.randn(2, 2, device=self.device)
b = torch.randn(2, 2, device=self.device)
example_inputs = (a, b)
outputs = model(*example_inputs)
package_path, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
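        # The non-blocking device-to-CPU copies are expected to go through pinned
        # host memory; check that the generated code reflects this.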
FileCheck().check("pinned").run(code)
model_aoti = torch._inductor.aoti_load_package(package_path)
outputs_aoti = model_aoti(*example_inputs)
self.assertEqual(outputs, outputs_aoti)
@unittest.skipIf(config.triton.native_matmul, "different code generated")
def test_pad_non_zero_memory_leak(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("test is only for GPU_TYPE")
class Model(torch.nn.Module):
def forward(self, x):
x = x + 1
x = torch.ops.aten.constant_pad_nd(x, (0, 1, 0, 0), 12345.0)
return x @ x
model = Model()
example_inputs = (torch.randn(2048, 2047, device=self.device),)
package_path, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
outputs = model(*example_inputs)
model_aoti = torch._inductor.aoti_load_package(package_path)
outputs_aoti = model_aoti(*example_inputs)
self.assertEqual(outputs, outputs_aoti, atol=1e-2, rtol=1e-2)
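        # The padded intermediate should be re-strided and wrapped in an RAII
        # handle so its buffer is released and not leaked.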
FileCheck().check_regex(
r"aoti_torch_as_strided\(buf0_handle, .*, &buf0_handle_restrided\)"
).check("wrap_with_raii_handle_if_needed(buf0_handle);").check(
"RAIIAtenTensorHandle buf0(buf0_handle_restrided);"
).run(code)
@unittest.skipIf(IS_MACOS, "might have no readelf on Mac")
def test_libtorch_free_so(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
model = Model().to(self.device)
ep = torch.export.export(model, example_inputs)
package_path = torch._inductor.aoti_compile_and_package(
ep,
inductor_configs={
"aot_inductor.link_libtorch": False,
},
)
torch_libs = {
"libtorch.so",
"libc10.so",
"libtorch_cuda.so",
"libc10_cuda.so",
"libtorch_cpu.so",
"libtorch_xpu.so",
"libc10_xpu.so",
}
with tempfile.TemporaryDirectory() as tmpdir:
# Unpack
with zipfile.ZipFile(package_path, "r") as zf:
zf.extractall(tmpdir)
so_files = list(pathlib.Path(tmpdir).rglob("*.so"))
self.assertTrue(len(so_files) > 0)
for so_file in so_files:
so_copy = pathlib.Path(tmpdir) / f"{so_file.name}.checkcopy"
so_copy.write_bytes(so_file.read_bytes())
result = subprocess.run(
["readelf", "-d", str(so_copy)],
check=True,
capture_output=True,
text=True,
)
for line in result.stdout.splitlines():
if "NEEDED" in line:
for lib in torch_libs:
self.assertTrue(lib not in line)
def test_unbounded_expr_substitutions(self):
class Model(torch.nn.Module):
def forward(self, x, y, a, b):
u0, s0 = a.item(), b.item()
u_max = max(u0, 15)
# construct the equality rule Max(15, u0) == s0 * Max(15, u0)
torch._check(u_max == s0 * u_max)
# size x - [Max(u0, 15), 64]
x = x.expand(u_max, *x.shape).clone()
return x @ y
model = Model()
example_inputs = (
torch.randn((64,), dtype=torch.bfloat16, device=self.device),
torch.randn((64, 16), dtype=torch.bfloat16, device=self.device),
torch.tensor(19, device=self.device),
torch.tensor(1, device=self.device),
)
torch._dynamo.mark_dynamic(example_inputs[-1], 0)
so_path, code = run_and_get_cpp_code(
AOTIRunnerUtil.legacy_compile, model, example_inputs
)
compiled = AOTIRunnerUtil.legacy_load(self.device, so_path)
compiled_outputs = compiled(*example_inputs)
eager_outputs = model(*example_inputs)
torch.testing.assert_close(eager_outputs, compiled_outputs)
class AOTInductorLoggingTest(LoggingTestCase):
@make_logging_test(dynamic=logging.DEBUG)
def test_shape_env_reuse(self, records):
# make sure ShapeEnv is only created once and reused afterwards
class Foo(torch.nn.Module):
def forward(self, x):
return x + 2
inputs = (torch.randn(4, 4),)
dynamic_shapes = {
"x": {0: Dim.AUTO, 1: Dim.AUTO},
}
ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes, strict=False)
with torch.no_grad():
torch._inductor.aot_compile(ep.module(), inputs)
self.assertEqual([r.msg == "create_env" for r in records].count(True), 1)
@make_logging_test(dynamic=logging.DEBUG)
def test_shape_env_reuse_zero_consts_use_consts_asm_false(self, records):
# make sure ShapeEnv is only created once and reused afterwards
class Foo(torch.nn.Module):
def forward(self, x):
return x + 2
inputs = (torch.randn(4, 4),)
dynamic_shapes = {
"x": {0: Dim.AUTO, 1: Dim.AUTO},
}
ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes, strict=False)
with (
torch.no_grad(),
config.patch({"aot_inductor.use_consts_asm_build": False}),
):
torch._inductor.aot_compile(ep.module(), inputs)
self.assertEqual([r.msg == "create_env" for r in records].count(True), 1)
class TestAOTInductorConfig(TestCase):
def test_no_compile_standalone(self):
with config.patch({"aot_inductor_mode.compile_standalone": False}):
result = maybe_aoti_standalone_config({})
self.assertEqual(result, {})
def test_compile_standalone_sets_package_cpp(self):
result = maybe_aoti_standalone_config(
{"aot_inductor_mode.compile_standalone": True}
)
self.assertEqual(result["aot_inductor.package_cpp_only"], True)
self.assertEqual(result["aot_inductor_mode.compile_standalone"], True)
self.assertEqual(result["aot_inductor.embed_kernel_binary"], True)
self.assertEqual(
result["aot_inductor.emit_multi_arch_kernel"], not torch.version.hip
)
self.assertEqual(
result["aot_inductor.model_name_for_generated_files"], "aoti_model"
)
self.assertEqual(result["aot_inductor.dynamic_linkage"], False)
def test_compile_standalone_explicit_set(self):
patches = {
"aot_inductor_mode.compile_standalone": True,
"aot_inductor.package_cpp_only": True,
"aot_inductor.embed_kernel_binary": True,
"aot_inductor.dynamic_linkage": False,
"aot_inductor.link_libtorch": False,
"aot_inductor.emit_multi_arch_kernel": not torch.version.hip,
"aot_inductor.model_name_for_generated_files": "aoti_model",
}
result = maybe_aoti_standalone_config(patches)
self.assertEqual(result, patches)
def test_compile_standalone_package_cpp_false_raises(self):
patches = {
"aot_inductor_mode.compile_standalone": True,
"aot_inductor.package_cpp_only": False,
}
with self.assertRaises(RuntimeError):
maybe_aoti_standalone_config(patches)
with config.patch({"aot_inductor.package_cpp_only": False}):
patches = {
"aot_inductor_mode.compile_standalone": True,
}
with self.assertRaises(RuntimeError):
maybe_aoti_standalone_config(patches)
def test_compile_standalone_cross_compile_windows_package_format(self):
patches = {
"aot_inductor.cross_target_platform": "windows",
"aot_inductor.package_constants_in_so": True,
}
with self.assertRaises(RuntimeError):
maybe_aoti_standalone_config(patches)
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
def fail_cpu(is_skip=False):
return TestFailure(
("cpu",),
is_skip=is_skip,
)
def fail_mps(is_skip=False):
return TestFailure(
("mps",),
is_skip=is_skip,
)
def fail_gpu(suffixes: tuple[str, ...], is_skip=False):
return TestFailure(
suffixes,
is_skip=is_skip,
)
# test_failures, xfail by default, set is_skip=True to skip
CPU_TEST_FAILURES = {
# TODO: failed internally
"test_multiple_output_alias": fail_cpu(is_skip=True),
}
# test_failures, xfail by default, set is_skip=True to skip
GPU_TEST_FAILURES = {
# quantized unsupported for GPU
"test_quantized_linear": fail_gpu(("cuda", "xpu")),
"test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")),
"test_quantized_linear_bias_none": fail_gpu(("cuda", "xpu")),
# No scaled_dot_product_efficient_attention implementation for XPU yet.
"test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)),
# No fft implementation for XPU yet.
"test_fft_c2c": fail_gpu(("xpu",), is_skip=True),
}
MPS_TEST_FAILURES = {
# aten::_scaled_dot_product_efficient_attention is not currently implemented for the MPS device.
"test_scaled_dot_product_efficient_attention": fail_mps(),
# aten::_int_mm is not implemented for MPS backend
"test__int_mm": fail_mps(),
# MPS doesn't support float64
"test_while_loop_with_conv_dynamic_True": fail_mps(),
"test_while_loop_with_conv_dynamic_False": fail_mps(),
# MPS doesn't support float8
"test_fp8": fail_mps(),
"test_fp8_view_of_param": fail_mps(),
# cannot initialize a parameter of type 'double' with an rvalue of type 'std::nullptr_t'
"test_fallback_kernel_with_symexpr_output": fail_mps(),
# correctness issue
"test_index_put_with_none_index": fail_mps(),
# Error device may not be nil
"test_zero_size_weight": fail_mps(is_skip=True),
# MPSGraph does not support tensor dims > INT_MAX
"test_upper_bound_i64": fail_mps(is_skip=True),
# MPS doesn't support triton
"test_autotuning_args_reuse": fail_mps(),
"test_triton_autotuning": fail_mps(),
"test_triton_dynamic_launcher_grid": fail_mps(),
"test_triton_dynamic_launcher_grid_infer_from_tensor": fail_mps(),
"test_triton_kernel_on_device_tma_dynamic_False_tma_version_new": fail_mps(),
"test_triton_kernel_on_device_tma_dynamic_False_tma_version_old": fail_mps(),
"test_triton_kernel_on_device_tma_dynamic_True_tma_version_new": fail_mps(),
"test_triton_kernel_on_device_tma_dynamic_True_tma_version_old": fail_mps(),
"test_size_with_unbacked_add_expr_transitive": fail_mps(),
"test_size_with_unbacked_add_and_mul_expr": fail_mps(),
"test_triton_next_power_of_2": fail_mps(),
"test_sympy_cpp_printer_min_max_minmax0": fail_mps(),
"test_sympy_cpp_printer_min_max_minmax1": fail_mps(),
"test_triton_kernel_dynamic_shape_with_div": fail_mps(),
"test_triton_kernel_reinterpret_view": fail_mps(),
"test_triton_kernel_tma_descriptor_1d_dynamic_False_tma_version_new_mps": fail_mps(),
"test_triton_kernel_tma_descriptor_1d_dynamic_False_tma_version_old_mps": fail_mps(),
"test_triton_kernel_tma_descriptor_1d_dynamic_True_tma_version_new_mps": fail_mps(),
"test_triton_kernel_tma_descriptor_1d_dynamic_True_tma_version_old_mps": fail_mps(),
"test_triton_kernel_tma_descriptor_2d_dynamic_False_tma_version_new_mps": fail_mps(),
"test_triton_kernel_tma_descriptor_2d_dynamic_False_tma_version_old_mps": fail_mps(),
"test_triton_kernel_tma_descriptor_2d_dynamic_True_tma_version_new_mps": fail_mps(),
"test_triton_kernel_tma_descriptor_2d_dynamic_True_tma_version_old_mps": fail_mps(),
"test_triton_kernel_sympy_expr_arg": fail_mps(),
"test_triton_kernel_sympy_fn_like_arg": fail_mps(),
"test_triton_kernel_with_none_input": fail_mps(),
"test_triton_kernel_equal_to_1_arg": fail_mps(),
"test_triton_kernel_with_none_inputs_and_equal_to_1_arg": fail_mps(),
"test_triton_kernel_equal_to_1_float_arg_dynamic_True": fail_mps(),
"test_triton_kernel_equal_to_1_float_arg_dynamic_False": fail_mps(),
"test_triton_kernel_weird_param_order": fail_mps(),
"test_triton_kernel_dynamic_grid": fail_mps(),
"test_repeated_user_defined_triton_kernel_embed_kernel_binary_False": fail_mps(),
"test_repeated_user_defined_triton_kernel_embed_kernel_binary_True": fail_mps(),
"test_triton_kernel_extern_kernel_arg": fail_mps(),
"test_triton_kernel_multi_output_arg": fail_mps(),
"test_triton_kernel_reinterpret_view_mem_leak": fail_mps(),
"test_triton_mutated_autotuning": fail_mps(),
"test_sym_i64_input_codegen": fail_mps(),
"test_none_args_aot_codegen": fail_mps(),
"test_aoti_debug_printer_sym_inputs": fail_mps(),
"test_aoti_debug_printer_user_defined_triton_kernel": fail_mps(),
"test_autotune_int64_user_defined_triton_kernel": fail_mps(),
}
class AOTInductorTestABICompatibleCpu(TestCase):
device = "cpu"
device_type = "cpu"
check_model = check_model
check_model_with_multiple_inputs = check_model_with_multiple_inputs
code_check_count = code_check_count
allow_stack_allocation = False
use_minimal_arrayref_interface = False
copy_tests(
AOTInductorTestsTemplate,
AOTInductorTestABICompatibleCpu,
"cpu",
CPU_TEST_FAILURES,
)
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
class AOTInductorTestABICompatibleGpu(TestCase):
device = GPU_TYPE
device_type = GPU_TYPE
check_model = check_model
check_model_with_multiple_inputs = check_model_with_multiple_inputs
code_check_count = code_check_count
allow_stack_allocation = False
use_minimal_arrayref_interface = False
copy_tests(
AOTInductorTestsTemplate,
AOTInductorTestABICompatibleGpu,
GPU_TYPE,
GPU_TEST_FAILURES,
)
@unittest.skipIf(not torch.backends.mps.is_available(), "No MPS backend available")
class AOTInductorTestABICompatibleMps(TestCase):
device = "mps"
device_type = "mps"
check_model = check_model
check_model_with_multiple_inputs = check_model_with_multiple_inputs
code_check_count = code_check_count
allow_stack_allocation = False
use_minimal_arrayref_interface = False
copy_tests(
AOTInductorTestsTemplate,
AOTInductorTestABICompatibleMps,
"mps",
MPS_TEST_FAILURES,
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
# cpp_extension N/A in fbcode
if HAS_GPU or sys.platform == "darwin":
run_tests(needs="filelock")