mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156496 Approved by: https://github.com/albanD
6546 lines
235 KiB
Python
6546 lines
235 KiB
Python
# Owner(s): ["module: inductor"]
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
import zipfile
|
|
from unittest import skip
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch._export
|
|
import torch._inductor
|
|
import torch._inductor.config
|
|
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
|
|
import torch.nn as nn
|
|
from torch._dynamo import config as dynamo_config
|
|
from torch._dynamo.device_interface import get_interface_for_device
|
|
from torch._dynamo.testing import rand_strided, same
|
|
from torch._dynamo.utils import counters
|
|
from torch._inductor import config
|
|
from torch._inductor.package import package_aoti
|
|
from torch._inductor.runtime.runtime_utils import cache_dir
|
|
from torch._inductor.test_case import TestCase
|
|
from torch._inductor.utils import is_big_gpu, run_and_get_cpp_code
|
|
from torch._utils_internal import full_aoti_runtime_assert
|
|
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
|
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
|
|
from torch.export import Dim, export, export_for_training
|
|
from torch.export.pt2_archive._package import load_pt2
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal import common_utils
|
|
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
|
|
from torch.testing._internal.common_device_type import (
|
|
_has_sufficient_memory,
|
|
skipCUDAIf,
|
|
)
|
|
from torch.testing._internal.common_quantization import (
|
|
_group_quantize_tensor,
|
|
skip_if_no_torchvision,
|
|
skipIfNoFBGEMM,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
DeterministicGuard,
|
|
IS_CI,
|
|
IS_FBCODE,
|
|
IS_MACOS,
|
|
IS_WINDOWS,
|
|
parametrize,
|
|
skipIfRocm,
|
|
skipIfXpu,
|
|
TEST_WITH_ROCM,
|
|
)
|
|
from torch.testing._internal.custom_tensor import CustomTensorPlainOut
|
|
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU
|
|
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
|
from torch.testing._internal.triton_utils import requires_gpu
|
|
from torch.utils import _pytree as pytree
|
|
from torch.utils._triton import (
|
|
has_triton_experimental_host_tma,
|
|
has_triton_tensor_descriptor_host_tma,
|
|
)
|
|
|
|
|
|
if HAS_GPU:
|
|
import triton # @manual
|
|
from triton import language as tl
|
|
|
|
from torch.testing._internal.triton_utils import (
|
|
add_kernel,
|
|
add_kernel_2d_autotuned,
|
|
add_kernel_autotuned,
|
|
add_kernel_autotuned_weird_param_order,
|
|
add_kernel_on_device_tma_new_api,
|
|
add_kernel_on_device_tma_old_api,
|
|
add_kernel_with_none_param_and_equal_to_1_arg,
|
|
add_kernel_with_optional_param,
|
|
add_kernel_with_scaling,
|
|
add_kernel_with_tma_1d_new_api,
|
|
add_kernel_with_tma_1d_old_api,
|
|
add_kernel_with_tma_2d_new_api,
|
|
add_kernel_with_tma_2d_old_api,
|
|
create_tensor_descriptor_shim,
|
|
mul2_inplace_kernel,
|
|
strange_config_matmul_kernel,
|
|
sub_kernel_autotuned,
|
|
)
|
|
|
|
if IS_WINDOWS and IS_CI:
|
|
sys.stderr.write(
|
|
"Windows CI does not have necessary dependencies for test_torchinductor yet\n"
|
|
)
|
|
if __name__ == "__main__":
|
|
sys.exit(0)
|
|
raise unittest.SkipTest("requires sympy/functorch/filelock")
|
|
|
|
try:
|
|
try:
|
|
from .test_aot_inductor_utils import (
|
|
AOTIRunnerUtil,
|
|
check_model,
|
|
check_model_with_multiple_inputs,
|
|
code_check_count,
|
|
)
|
|
from .test_control_flow import (
|
|
CondModels,
|
|
prepend_counters,
|
|
prepend_predicates,
|
|
WhileLoopModels,
|
|
)
|
|
from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
|
|
except ImportError:
|
|
from test_aot_inductor_utils import ( # @manual=fbcode//caffe2/test/inductor:aot_inductor_utils-library
|
|
AOTIRunnerUtil,
|
|
check_model,
|
|
check_model_with_multiple_inputs,
|
|
code_check_count,
|
|
)
|
|
from test_control_flow import ( # @manual=fbcode//caffe2/test/inductor:control_flow-library
|
|
CondModels,
|
|
prepend_counters,
|
|
prepend_predicates,
|
|
WhileLoopModels,
|
|
)
|
|
from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
|
copy_tests,
|
|
requires_multigpu,
|
|
TestFailure,
|
|
)
|
|
except (unittest.SkipTest, ImportError):
|
|
if __name__ == "__main__":
|
|
sys.exit(0)
|
|
raise
|
|
|
|
|
|
class AOTInductorTestsTemplate:
|
|
@common_utils.parametrize("embed_kernel_binary", [False, True])
|
|
@common_utils.parametrize("max_autotune", [False, True])
|
|
def test_simple(self, embed_kernel_binary, max_autotune):
|
|
if self.device == "cpu" and IS_MACOS and max_autotune:
|
|
raise unittest.SkipTest("max_autotune not supported on macos")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
model = Model()
|
|
with config.patch(
|
|
{
|
|
"aot_inductor.embed_kernel_binary": embed_kernel_binary,
|
|
"max_autotune": max_autotune,
|
|
}
|
|
):
|
|
self.check_model(model, example_inputs)
|
|
|
|
_, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.compile, model, example_inputs
|
|
)
|
|
if self.device == GPU_TYPE:
|
|
FileCheck().check("launchKernel(").run(code)
|
|
if config.aot_inductor.embed_kernel_binary:
|
|
# Not expect to see launchKernel("CUBIN_FILE_NAME"
|
|
FileCheck().check_not('launchKernel("').run(code)
|
|
|
|
if self.use_minimal_arrayref_interface:
|
|
self.code_check_count(
|
|
model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
|
|
)
|
|
|
|
@unittest.skipIf(
|
|
IS_FBCODE,
|
|
"toolchain doesn't support ptx to fatbin",
|
|
)
|
|
@skipIfRocm
|
|
@common_utils.parametrize("embed_kernel_binary", [True, False])
|
|
def test_simple_multi_arch(self, embed_kernel_binary):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU_TYPE")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 16)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 16, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
model = Model()
|
|
with config.patch(
|
|
{
|
|
"aot_inductor.embed_kernel_binary": embed_kernel_binary,
|
|
"aot_inductor.emit_multi_arch_kernel": True,
|
|
}
|
|
):
|
|
self.check_model(model, example_inputs)
|
|
if not embed_kernel_binary:
|
|
_, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.compile, model, example_inputs
|
|
)
|
|
file_extension = ".spv" if self.device == "xpu" else ".fatbin"
|
|
FileCheck().check(file_extension).run(code)
|
|
|
|
def test_small_constant(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
example_inputs = (torch.randn(4, 4, device=self.device),)
|
|
with config.patch({"always_keep_tensor_constants": True}):
|
|
self.check_model(Model().to(self.device), example_inputs)
|
|
|
|
def test_output_path_1(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
with config.patch("aot_inductor.output_path", "tmp_output_"):
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_output_path_2(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
model = Model().to(device=self.device)
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
expected_path = os.path.join(tempfile.mkdtemp(dir=cache_dir()), "model.so")
|
|
actual_path = AOTIRunnerUtil.legacy_compile(
|
|
model, example_inputs, options={"aot_inductor.output_path": expected_path}
|
|
)
|
|
self.assertTrue(actual_path == expected_path)
|
|
|
|
def test_empty_constant_folding(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.w = torch.randn(4, 4, device=device)
|
|
self.b = torch.randn(4, device=device)
|
|
|
|
def forward(self, x):
|
|
return torch.matmul(x, self.w) + self.b
|
|
|
|
model = Model(self.device)
|
|
example_inputs = (torch.randn(4, 4, device=self.device),)
|
|
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
|
so_path, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.legacy_compile, model, example_inputs
|
|
)
|
|
# We should have 1 input, 1 output, 2 constants for the model.
|
|
FileCheck().check_count("AOTInductorModelBase(1,", 1).check_next(
|
|
"1,"
|
|
).check_next("2,").run(code)
|
|
|
|
def test_constant_folding(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.w_pre = torch.randn(4, 4, device=device)
|
|
self.b = torch.randn(4, device=device)
|
|
|
|
def forward(self, x):
|
|
w_transpose = torch.transpose(self.w_pre, 0, 1)
|
|
w_relu = torch.nn.functional.relu(w_transpose)
|
|
w = w_relu + self.b
|
|
return torch.matmul(x, w)
|
|
|
|
example_inputs = (torch.randn(4, 4, device=self.device),)
|
|
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
def test_constant_folding_with_update(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.w_pre = torch.randn(4, 4, device=device)
|
|
self.b = torch.randn(4, device=device)
|
|
|
|
def forward(self, x):
|
|
w_transpose = torch.transpose(self.w_pre, 0, 1)
|
|
w_relu = torch.nn.functional.relu(w_transpose)
|
|
w = w_relu + self.b
|
|
return torch.matmul(x, w)
|
|
|
|
example_inputs = (torch.randn(4, 4, device=self.device),)
|
|
with torch.no_grad(), config.patch(
|
|
{
|
|
"always_keep_tensor_constants": True,
|
|
"aot_inductor.use_runtime_constant_folding": True,
|
|
}
|
|
):
|
|
model = Model(self.device)
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
|
|
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
|
|
|
|
def runner_call(*args, **kwargs):
|
|
import torch.fx._pytree as fx_pytree
|
|
|
|
call_spec = runner.get_call_spec()
|
|
in_spec = pytree.treespec_loads(call_spec[0])
|
|
out_spec = pytree.treespec_loads(call_spec[1])
|
|
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
|
|
flat_outputs = runner.run(flat_inputs)
|
|
return pytree.tree_unflatten(flat_outputs, out_spec)
|
|
|
|
test_inputs = torch.randn(4, 4, device=self.device)
|
|
expected = model(test_inputs)
|
|
output = runner_call(test_inputs)
|
|
self.assertEqual(expected, output)
|
|
|
|
# Update with new weights on active buffer
|
|
new_weights = {
|
|
"L__self___b": torch.randn(4, device=self.device),
|
|
"L__self___w_pre": torch.randn(4, 4, device=self.device),
|
|
}
|
|
model.w_pre = new_weights["L__self___w_pre"]
|
|
model.b = new_weights["L__self___b"]
|
|
expected = model(test_inputs)
|
|
runner.update_constant_buffer(new_weights, False, False)
|
|
output = runner_call(test_inputs)
|
|
self.assertEqual(expected, output)
|
|
|
|
# Update with new weights on inactive buffer
|
|
new_weights = {
|
|
"L__self___b": torch.randn(4, device=self.device),
|
|
"L__self___w_pre": torch.randn(4, 4, device=self.device),
|
|
}
|
|
model.w_pre = new_weights["L__self___w_pre"]
|
|
model.b = new_weights["L__self___b"]
|
|
expected = model(test_inputs)
|
|
runner.update_constant_buffer(new_weights, True, False)
|
|
new_output = runner_call(test_inputs)
|
|
# We have not yet swapped the buffer, new_output should be the same as the old one.
|
|
self.assertEqual(output, new_output)
|
|
# Swap the buffer, should get the correct result now.
|
|
runner.swap_constant_buffer()
|
|
new_output = runner_call(test_inputs)
|
|
self.assertEqual(expected, new_output)
|
|
|
|
@requires_gpu
|
|
def test_duplicate_constant_folding(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.w1 = torch.randn(4, 4, device=device)
|
|
self.w2 = torch.randn(4, 4, device=device)
|
|
self.w3 = torch.randn(4, 4, device=device)
|
|
self.w4 = torch.randn(4, 4, device=device)
|
|
|
|
def forward(self, x):
|
|
w_concat = torch.cat((self.w1, self.w2, self.w3, self.w4))
|
|
return torch.cat((x, w_concat))
|
|
|
|
example_inputs = (torch.randn(4, 4, device=self.device),)
|
|
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
def test_autotune_with_constant_folding(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device) -> None:
|
|
super().__init__()
|
|
self.x = torch.randn(2048, 2048, dtype=torch.float16, device=device)
|
|
|
|
def _quantize(self, input):
|
|
return torch.abs(input)
|
|
|
|
def forward(self, y):
|
|
abs_weight = self._quantize(self.x)
|
|
abs_y = self._quantize(y)
|
|
|
|
return abs_weight, abs_y
|
|
|
|
input1 = (torch.rand(2048, 2048, dtype=torch.float16, device=self.device),)
|
|
model = Model(self.device).to(self.device)
|
|
|
|
_ = model(*input1)
|
|
|
|
ep = torch.export.export(model, input1, dynamic_shapes=None, strict=False)
|
|
torch._inductor.aoti_compile_and_package(
|
|
ep, inductor_configs={"aot_inductor.use_runtime_constant_folding": True}
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
@common_utils.parametrize("tma_version", ["new", "old"])
|
|
def test_triton_kernel_on_device_tma(self, dynamic, tma_version):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
if tma_version == "new" and not has_triton_tensor_descriptor_host_tma():
|
|
self.skipTest("requires triton.tools.tensor_descriptor TMA support")
|
|
if tma_version == "old" and not has_triton_experimental_host_tma():
|
|
self.skipTest("requires triton.tools.experimental_descriptor TMA support")
|
|
|
|
kernel = (
|
|
add_kernel_on_device_tma_new_api
|
|
if tma_version == "new"
|
|
else add_kernel_on_device_tma_old_api
|
|
)
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, a, b):
|
|
BLOCK_SIZE = 32
|
|
out = torch.zeros_like(a)
|
|
m, n = out.size()
|
|
|
|
# Allocate workspace for on-device TMA descriptors
|
|
# Need 128 bytes per descriptor, 3 descriptors total
|
|
workspace = torch.zeros(3 * 128, dtype=torch.uint8, device=a.device)
|
|
|
|
kernel[(1,)](
|
|
a,
|
|
b,
|
|
out,
|
|
m,
|
|
n,
|
|
workspace,
|
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
)
|
|
|
|
return out
|
|
|
|
a = torch.randn((32, 32), device=self.device)
|
|
b = torch.randn((32, 32), device=self.device)
|
|
example_inputs = (a, b)
|
|
|
|
triton.set_allocator(
|
|
lambda size, align, stream: torch.empty(
|
|
size, dtype=torch.int8, device="cuda"
|
|
)
|
|
)
|
|
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
|
|
self.check_model(
|
|
Model(),
|
|
example_inputs=example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_multi_device(self):
|
|
if self.device == "cpu" and GPU_TYPE == "xpu":
|
|
raise unittest.SkipTest(
|
|
"In this scenario, the test case will run XPU code in "
|
|
"AOTIModelContainerRunnerCpu, which is not reasonable,"
|
|
"See issue #140805"
|
|
)
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
x = x + 1
|
|
x = x.cpu()
|
|
x = x + 2
|
|
x = x.to(GPU_TYPE)
|
|
return x
|
|
|
|
example_inputs = (torch.randn(32, 64, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_large_weight(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(2048, 262144)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
example_inputs = (
|
|
torch.randn(1, 262144, device=self.device),
|
|
torch.randn(1, 2048, device=self.device),
|
|
)
|
|
|
|
# We only test compilation since we often get OOM running in CI.
|
|
model = Model()
|
|
model = model.to(self.device)
|
|
AOTIRunnerUtil.compile(model, example_inputs)
|
|
|
|
def test_constant_type_propagation(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.w_pre = torch.randn(4, 4, device=device)
|
|
self.b = torch.randn(4, device=device)
|
|
|
|
def forward(self, x):
|
|
w_transpose = torch.transpose(self.w_pre, 0, 1)
|
|
w_relu = torch.nn.functional.relu(w_transpose)
|
|
w = w_relu + self.b
|
|
return torch.matmul(x, w)
|
|
|
|
model = Model(self.device)
|
|
example_inputs = (torch.randn(4, 4, device=self.device),)
|
|
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
|
so_path, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.legacy_compile, model, example_inputs
|
|
)
|
|
FileCheck().check_not("torch::aot_inductor::ConstantType::Unknown").run(
|
|
code
|
|
)
|
|
|
|
def test_subclasses(self):
|
|
device_to_init = self.device
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.p1 = torch.nn.Parameter(torch.ones(3, 4, device=device_to_init))
|
|
self.p2 = torch.nn.Parameter(
|
|
CustomTensorPlainOut(
|
|
torch.ones(3, 4, device=device_to_init),
|
|
torch.ones(3, 4, device=device_to_init),
|
|
)
|
|
)
|
|
|
|
def forward(self, x):
|
|
a = (2 * self.p1 + self.p2).sum()
|
|
return x + a
|
|
|
|
m = Foo()
|
|
ref_x = torch.randn(3, 4, device=device_to_init)
|
|
|
|
with torch.no_grad():
|
|
result = AOTIRunnerUtil.run(
|
|
m,
|
|
(ref_x,),
|
|
)
|
|
actual = m(ref_x)
|
|
self.assertTrue(same(result, actual))
|
|
|
|
def test_large_mmaped_weights(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(512, 250112)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
example_inputs = (
|
|
torch.randn(1, 250112, device=self.device),
|
|
torch.randn(1, 512, device=self.device),
|
|
)
|
|
with config.patch({"aot_inductor.force_mmap_weights": True}):
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_with_offset(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.orig_tensor = torch.randn(2, 15, 10, device=device)[0]
|
|
self.tensor = self.orig_tensor[5:, :]
|
|
|
|
def forward(self, x, y):
|
|
return (
|
|
x
|
|
+ torch.nn.functional.linear(y, self.orig_tensor[:10, :])
|
|
+ self.tensor
|
|
)
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
@unittest.skipIf(
|
|
IS_FBCODE,
|
|
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
|
|
)
|
|
def test_freezing(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(9, 10, device=device)
|
|
self.padding = torch.randn(1, 10, device=device)
|
|
|
|
def forward(self, x, y):
|
|
padded_weight = torch.cat((self.weight, self.padding), dim=0)
|
|
return x + torch.nn.functional.linear(y, padded_weight)
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
|
|
with config.patch({"freezing": True}):
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
@unittest.skipIf(
|
|
IS_FBCODE,
|
|
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
|
|
)
|
|
def test_conv_freezing(self):
|
|
dtypes = [torch.bfloat16, torch.float] if SM80OrLater else [torch.float]
|
|
for dtype, groups in itertools.product(dtypes, [1, 2]):
|
|
iC = 2
|
|
oC = 3
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
|
|
dtype
|
|
)
|
|
|
|
def forward(self, y):
|
|
return torch.nn.functional.conv2d(y, self.weight, groups=groups)
|
|
|
|
example_inputs = (
|
|
torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
|
|
)
|
|
|
|
with config.patch({"freezing": True}):
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
@unittest.skipIf(
|
|
IS_FBCODE,
|
|
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
|
|
)
|
|
def test_deconv_freezing(self):
|
|
dtypes = [torch.float]
|
|
if torch._C._has_mkldnn and torch.ops.mkldnn._is_mkldnn_bf16_supported():
|
|
dtypes.append(torch.bfloat16)
|
|
for dtype, groups in itertools.product(dtypes, [2, 1]):
|
|
iC = 4
|
|
oC = 2
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
|
|
dtype
|
|
)
|
|
|
|
def forward(self, y):
|
|
return torch.nn.functional.conv_transpose2d(
|
|
y, self.weight, groups=groups
|
|
)
|
|
|
|
example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
|
|
with config.patch({"freezing": True}):
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
@unittest.skipIf(
|
|
IS_FBCODE,
|
|
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
|
|
)
|
|
def test_linear_freezing(self):
|
|
dtypes = [torch.bfloat16, torch.float] if SM80OrLater else [torch.float]
|
|
for dtype in dtypes:
|
|
|
|
class LinearModel(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(10, 10, device=device).to(dtype)
|
|
self.bias = torch.randn(10, device=device).to(dtype)
|
|
|
|
def forward(self, y):
|
|
return torch.nn.functional.linear(y, self.weight, self.bias)
|
|
|
|
example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)
|
|
|
|
with config.patch({"freezing": True}):
|
|
model = LinearModel(device=self.device)
|
|
self.check_model(model, example_inputs)
|
|
|
|
def test_same_backing(self):
|
|
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
|
torch.library.define(
|
|
"mylib::foo2",
|
|
"(Tensor a, Tensor b) -> Tensor",
|
|
tags=torch.Tag.pt2_compliant_tag,
|
|
lib=lib,
|
|
)
|
|
|
|
@torch.library.impl("mylib::foo2", "CompositeExplicitAutograd", lib=lib)
|
|
def foo_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
|
return a + b
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, a, b):
|
|
x = a.shape[0]
|
|
y = b.shape[0]
|
|
a = torch.cat([a, a])
|
|
a = torch.ops.mylib.foo2(a, a)
|
|
a = a * x
|
|
b = torch.cat([b, b])
|
|
b = torch.ops.mylib.foo2(b, b)
|
|
b = b * y
|
|
return a, b
|
|
|
|
inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
|
|
self.check_model(M(), inp)
|
|
|
|
def test_empty_cat_dtype_promotion(self):
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
z = torch.cat([x, y], dim=1)
|
|
z = z.to(dtype=torch.bfloat16)
|
|
return z * 2
|
|
|
|
model = Foo()
|
|
inps = (torch.randn(4, 10, dtype=torch.bfloat16), torch.randn(4, 0))
|
|
self.check_model(model, inps)
|
|
|
|
@unittest.skipIf(
|
|
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
|
)
|
|
def test_linear_dynamic_maxautotune(self):
|
|
if self.device == "cpu":
|
|
raise unittest.SkipTest("using triton backend only is not supported on CPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(1, 1)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
model = Model().to(device=self.device)
|
|
compile_inputs = (torch.randn(2048, 1, device=self.device),)
|
|
dim0_x = Dim("dim0_x", min=2, max=2048)
|
|
dynamic_shapes = {"x": {0: dim0_x}}
|
|
ep = torch.export.export(
|
|
model, compile_inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
)
|
|
optimized = torch._inductor.aoti_load_package(
|
|
torch._inductor.aoti_compile_and_package(
|
|
ep,
|
|
inductor_configs={
|
|
"max_autotune": True,
|
|
"max_autotune_gemm_backends": "TRITON",
|
|
},
|
|
)
|
|
)
|
|
runtime_input = torch.randn(10, 1, device=self.device)
|
|
self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
|
|
runtime_input = torch.randn(16, 1, device=self.device)
|
|
self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
|
|
runtime_input = torch.randn(100, 1, device=self.device)
|
|
self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
|
|
|
|
@torch._inductor.config.patch(
|
|
pre_grad_fusion_options={
|
|
"normalization_pass": {},
|
|
"remove_split_with_size_one_pass": {},
|
|
"merge_getitem_cat_pass": {},
|
|
"merge_stack_tahn_unbind_pass": {},
|
|
"merge_splits_pass": {},
|
|
"mutate_cat_pass": {},
|
|
"split_cat_pass": {},
|
|
"unbind_stack_pass": {},
|
|
},
|
|
post_grad_fusion_options={},
|
|
)
|
|
def test_simple_split(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)
|
|
|
|
example_inputs = (torch.randn(2, 8, device=self.device),)
|
|
counters.clear()
|
|
model = Model().to(device=self.device)
|
|
actual = AOTIRunnerUtil.legacy_run(self.device, model, example_inputs)
|
|
self.assertTrue(same(model(*example_inputs), actual))
|
|
self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
|
|
self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
|
|
self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)
|
|
|
|
def test_amp_fallback_random(self):
|
|
def fn(x, w):
|
|
return torch.functional.F.linear(x, w)
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
with config.patch({"fallback_random": True}):
|
|
with torch.amp.autocast(device_type=self.device):
|
|
self.check_model(fn, example_inputs)
|
|
|
|
def test_missing_output(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
a = torch.sin(x)
|
|
b = torch.mm(a, y)
|
|
c = torch.cos(b)
|
|
return c
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_output_misaligned(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
x_unsqueeze = torch.unsqueeze(x, dim=0)
|
|
y_unsqueeze = torch.unsqueeze(y, dim=0)
|
|
cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
|
|
x_getitem = cat[0]
|
|
y_getitem = cat[1]
|
|
x_sigmoid = torch.sigmoid(x_getitem)
|
|
return x_sigmoid, y_getitem
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@unittest.skipIf(
|
|
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
|
)
|
|
@skip("Test was marked as expected failure, but does not fail always anymore.")
|
|
def test_dynamic_smem_above_default_limit(self):
|
|
if self.device == "cpu":
|
|
raise unittest.SkipTest("using triton backend only is not supported on CPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x @ y
|
|
|
|
model = Model().to(self.device)
|
|
# on A100, the generated Triton kernel for this MM
|
|
# requires 55296 bytes of dynamic SMEM which is above
|
|
# the A100's default dynamic SMEM limit of 49152 bytes.
|
|
example_inputs = (
|
|
torch.randn(10285, 96, device=self.device),
|
|
torch.randn(96, 1, device=self.device),
|
|
)
|
|
self.check_model(
|
|
model,
|
|
example_inputs,
|
|
options={
|
|
"max_autotune": True,
|
|
"max_autotune_gemm_backends": "TRITON",
|
|
},
|
|
)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
|
|
def test_seq(self):
|
|
layernorm = torch.nn.LayerNorm(10)
|
|
net = torch.nn.Sequential(
|
|
layernorm,
|
|
torch.nn.ReLU(),
|
|
layernorm,
|
|
torch.nn.ReLU(),
|
|
)
|
|
|
|
example_inputs = (torch.randn(10, device=self.device),)
|
|
self.check_model(net.eval(), example_inputs)
|
|
|
|
def test_addmm(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M = 8
|
|
N = 6
|
|
K = 16
|
|
model = Model(N, K, self.device)
|
|
batch = 2
|
|
a = torch.randn(batch, M, K, device=self.device)
|
|
# We should be able to call self.check_model here, but torch.export.export
|
|
# constants (non-parameter, non-buffer) doesn't work today.
|
|
example_inputs = (a,)
|
|
self.check_model(model, example_inputs)
|
|
|
|
def test_aliased_buffer_reuse(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
x = 2 * x
|
|
y = 2 * y
|
|
c = torch.cat([x, y], dim=-1)
|
|
d = 1 + c
|
|
m = torch.mm(d, d)
|
|
return m[:, :2] + x
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 2, device=self.device),
|
|
torch.randn(4, 2, device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_buffer_reuse(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
a = torch.sin(x)
|
|
b = torch.cos(y)
|
|
c = torch.mm(a, b)
|
|
d = torch.relu(c)
|
|
e = torch.sigmoid(d)
|
|
f = torch.mm(x, y)
|
|
g = e + f
|
|
return g
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 4, device=self.device),
|
|
torch.randn(4, 4, device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_duplicated_params(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.p = torch.nn.Parameter(torch.rand(6))
|
|
self.q = self.p
|
|
|
|
def forward(self, x):
|
|
return self.p * x + self.q
|
|
|
|
example_inputs = (torch.rand(6, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
|
|
def test_inf(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
x = torch.randn(10, 10, device=self.device)
|
|
x[0][0] = float("Inf")
|
|
example_inputs = (
|
|
x,
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
self.check_model(
|
|
Model().to(self.device),
|
|
example_inputs,
|
|
options={"debug_check_inf_and_nan": True},
|
|
)
|
|
|
|
@unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
|
|
def test_nan(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
x = torch.randn(10, 10, device=self.device)
|
|
x[0][0] = float("nan")
|
|
example_inputs = (
|
|
x,
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
self.check_model(
|
|
Model().to(self.device),
|
|
example_inputs,
|
|
options={"debug_check_inf_and_nan": True},
|
|
)
|
|
|
|
def test_assert_async(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU_TYPE")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
u0 = x.item()
|
|
torch._check(u0 > 3)
|
|
return torch.ones(u0)[0]
|
|
|
|
x = torch.tensor(23, device=self.device)
|
|
example_inputs = (x,)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_simple_dynamic(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
add_0 = x + y
|
|
return torch.nn.functional.relu(input=add_0, inplace=False)
|
|
|
|
x = torch.randn(128, 2048, device=self.device)
|
|
y = torch.randn(128, 2048, device=self.device)
|
|
dim0_x = Dim("dim0_x", min=1, max=2048)
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
|
|
example_inputs = (x, y)
|
|
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FP8,
|
|
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
|
|
)
|
|
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
|
@skipIfXpu
|
|
def test_fp8(self):
|
|
# cuda only
|
|
if self.device != "cuda":
|
|
return
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, dtype):
|
|
super().__init__()
|
|
self.out_dtype = dtype
|
|
|
|
def forward(self, x, weight, bias, scale_a, scale_b):
|
|
weight = weight.to(torch.float8_e4m3fn)
|
|
output = torch._scaled_mm(
|
|
x,
|
|
weight,
|
|
bias=input_bias,
|
|
out_dtype=self.out_dtype,
|
|
scale_a=scale_a,
|
|
scale_b=scale_b,
|
|
)
|
|
return output
|
|
|
|
dtype = torch.float16
|
|
|
|
a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
|
b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
|
input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
|
|
weight_shape = (32, 16)
|
|
weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
|
|
a_inverse_scale = 1 / a_scale
|
|
b_inverse_scale = 1 / b_scale
|
|
|
|
x_shape = (16, 16)
|
|
x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn)
|
|
dim0_x = Dim("dim0_x", min=1, max=2048)
|
|
dynamic_shapes = ({0: dim0_x}, None, None, None, None)
|
|
self.check_model(
|
|
Model(dtype),
|
|
(x, weight, input_bias, a_inverse_scale, b_inverse_scale),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FP8,
|
|
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
|
|
)
|
|
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
|
@skipIfXpu
|
|
def test_fp8_view_of_param(self):
|
|
# cuda only
|
|
if self.device != GPU_TYPE:
|
|
return
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, dtype, weight):
|
|
super().__init__()
|
|
self.out_dtype = dtype
|
|
self.weight = weight
|
|
|
|
def forward(self, x, bias, scale_a, scale_b):
|
|
# test: do the view inside of the graph,
|
|
# AOTI needs to materialize this view before passing
|
|
# it into the scaled_mm extern kernel
|
|
weight = self.weight.T
|
|
output = torch._scaled_mm(
|
|
x,
|
|
weight,
|
|
bias=input_bias,
|
|
out_dtype=self.out_dtype,
|
|
scale_a=scale_a,
|
|
scale_b=scale_b,
|
|
)
|
|
return output
|
|
|
|
dtype = torch.float16
|
|
|
|
a_scale = torch.Tensor([1.0]).to(device=self.device)
|
|
b_scale = torch.Tensor([1.0]).to(device=self.device)
|
|
input_bias = torch.rand(32, device=self.device, dtype=dtype)
|
|
weight_shape = (32, 16)
|
|
weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).to(
|
|
torch.float8_e4m3fn
|
|
)
|
|
a_inverse_scale = 1 / a_scale
|
|
b_inverse_scale = 1 / b_scale
|
|
|
|
x_shape = (16, 16)
|
|
x = torch.rand(*x_shape, device=self.device, dtype=dtype).to(
|
|
torch.float8_e4m3fn
|
|
)
|
|
dim0_x = Dim("dim0_x", min=1, max=2048)
|
|
dynamic_shapes = ({0: dim0_x}, None, None, None)
|
|
self.check_model(
|
|
Model(dtype, weight),
|
|
(x, input_bias, a_inverse_scale, b_inverse_scale),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_poi_multiple_dynamic(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
add_0 = x + y
|
|
return torch.nn.functional.relu(input=add_0, inplace=False)
|
|
|
|
x = torch.randn(128, 2048, device=self.device)
|
|
y = torch.randn(128, 2048, device=self.device)
|
|
dim0_x = Dim("dim0_x", min=1, max=2048)
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
|
|
list_example_inputs = [(x, y)]
|
|
list_example_inputs.append(
|
|
(
|
|
torch.randn(64, 2048, device=self.device),
|
|
torch.randn(64, 2048, device=self.device),
|
|
),
|
|
)
|
|
list_example_inputs.append(
|
|
(
|
|
torch.randn(211, 2048, device=self.device),
|
|
torch.randn(211, 2048, device=self.device),
|
|
),
|
|
)
|
|
self.check_model_with_multiple_inputs(
|
|
Model(), list_example_inputs, dynamic_shapes=dynamic_shapes
|
|
)
|
|
|
|
@unittest.skipIf(
|
|
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
|
)
|
|
def test_addmm_multiple_dynamic(self):
|
|
if self.device == "cpu":
|
|
raise unittest.SkipTest("using triton backend only is not supported on CPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M = 8
|
|
N = 6
|
|
K = 16
|
|
model = Model(N, K, self.device)
|
|
batch = 2
|
|
a = torch.randn(batch, M, K, device=self.device)
|
|
dim0_a = Dim("dim0_a", min=1, max=2048)
|
|
dynamic_shapes = {"a": {0: dim0_a}}
|
|
list_example_inputs = [(a,)]
|
|
batch = 2048
|
|
list_example_inputs.append(
|
|
(torch.randn(batch, M, K, device=self.device),),
|
|
)
|
|
batch = 128
|
|
list_example_inputs.append(
|
|
(torch.randn(batch, M, K, device=self.device),),
|
|
)
|
|
self.check_model_with_multiple_inputs(
|
|
model,
|
|
list_example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
options={
|
|
"max_autotune": True,
|
|
"max_autotune_gemm_backends": "TRITON",
|
|
},
|
|
)
|
|
|
|
@unittest.skipIf(
|
|
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
|
)
|
|
def test_bmm_multiple_dynamic(self):
|
|
if self.device == "cpu":
|
|
raise unittest.SkipTest("using triton backend only is not supported on CPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, a, b):
|
|
return torch.bmm(a, b)
|
|
|
|
M = 8
|
|
N = 6
|
|
K = 16
|
|
model = Model()
|
|
batch = 1024
|
|
a = torch.randn(batch, M, K, device=self.device)
|
|
b = torch.randn(batch, K, N, device=self.device)
|
|
dim0_a = Dim("dim0_a", min=1, max=2048)
|
|
dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_a}}
|
|
list_example_inputs = [(a, b)]
|
|
batch = 2048
|
|
list_example_inputs.append(
|
|
(
|
|
torch.randn(batch, M, K, device=self.device),
|
|
torch.randn(batch, K, N, device=self.device),
|
|
),
|
|
)
|
|
batch = 128
|
|
list_example_inputs.append(
|
|
(
|
|
torch.randn(batch, M, K, device=self.device),
|
|
torch.randn(batch, K, N, device=self.device),
|
|
),
|
|
)
|
|
self.check_model_with_multiple_inputs(
|
|
model,
|
|
list_example_inputs,
|
|
options={
|
|
"max_autotune": True,
|
|
"max_autotune_gemm_backends": "TRITON",
|
|
},
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_foreach_multiple_dynamic(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
x_unsqueeze = torch.unsqueeze(x, dim=0)
|
|
y_unsqueeze = torch.unsqueeze(y, dim=0)
|
|
cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
|
|
return cat
|
|
|
|
model = Model()
|
|
x = torch.randn(128, 2048, device=self.device)
|
|
y = torch.randn(128, 2048, device=self.device)
|
|
dim0_x = Dim("dim0_x", min=1, max=2048)
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
|
|
list_example_inputs = [(x, y)]
|
|
list_example_inputs.append(
|
|
(
|
|
torch.randn(64, 2048, device=self.device),
|
|
torch.randn(64, 2048, device=self.device),
|
|
),
|
|
)
|
|
list_example_inputs.append(
|
|
(
|
|
torch.randn(211, 2048, device=self.device),
|
|
torch.randn(211, 2048, device=self.device),
|
|
),
|
|
)
|
|
self.check_model_with_multiple_inputs(
|
|
model,
|
|
list_example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
# scaled_dot_product_flash_attention
|
|
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
|
|
def test_sdpa(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, q, k, v):
|
|
return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]
|
|
|
|
example_inputs = (
|
|
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
|
|
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
|
|
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
|
|
def test_sdpa_2(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, q, k, v, x):
|
|
t = torch.nn.functional.scaled_dot_product_attention(
|
|
q, k, v, is_causal=True
|
|
)[0]
|
|
return x + t
|
|
|
|
example_inputs = (
|
|
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
|
|
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
|
|
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
|
|
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_quantized_linear(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(10, 10, device=device)
|
|
self.bias = torch.randn(10, device=device)
|
|
|
|
def forward(self, x):
|
|
return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
|
|
x, self.weight, self.bias
|
|
)
|
|
|
|
example_inputs = (torch.randn(10, 10, device=self.device),)
|
|
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_quanatized_int8_linear(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(10, 10, device=device)
|
|
self.bias = torch.randn(10, device=device)
|
|
self.input_scale = torch.tensor(0.1)
|
|
self.input_zero_point = torch.tensor(0)
|
|
self.weight_scale = torch.tensor(0.1)
|
|
self.weight_zero_point = torch.tensor(0)
|
|
self.output_scale = torch.tensor(0.1)
|
|
self.output_zero_point = torch.tensor(0)
|
|
self.out_channel = 10
|
|
|
|
def forward(self, x):
|
|
return torch.ops._quantized.wrapped_quantized_linear(
|
|
x,
|
|
self.input_scale,
|
|
self.input_zero_point,
|
|
self.weight,
|
|
self.weight_scale,
|
|
self.weight_zero_point,
|
|
self.bias,
|
|
self.output_scale,
|
|
self.output_zero_point,
|
|
self.out_channel,
|
|
)
|
|
|
|
example_inputs = (torch.randn(10, 10, device=self.device),)
|
|
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
def test_zero_grid_with_unbacked_symbols(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones_like(nz, dtype=torch.float16)
|
|
c = torch.zeros_like(nz, dtype=torch.float16)
|
|
d = (b + c) @ y
|
|
return d.sum()
|
|
|
|
example_inputs = (
|
|
torch.tensor([1, 1, 1], device=self.device),
|
|
torch.randn((1, 32), dtype=torch.float16, device=self.device),
|
|
)
|
|
self.check_model(Repro(), example_inputs)
|
|
|
|
@config.patch({"triton.autotune_at_compile_time": None})
|
|
def test_stride_with_unbacked_expr(self):
|
|
class Repro(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
u0 = x.item()
|
|
torch._check(u0 >= 1)
|
|
s0 = y.size(0)
|
|
expr = u0 * s0
|
|
sevens = torch.empty_strided(
|
|
size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device
|
|
).fill_(7)
|
|
return sevens * 3
|
|
|
|
example_inputs = (
|
|
torch.scalar_tensor(2, dtype=torch.int, device=self.device),
|
|
torch.ones(8, device=self.device),
|
|
)
|
|
self.check_model(Repro(), example_inputs)
|
|
|
|
def test_size_with_unbacked_add_expr(self):
|
|
# Tests AOTI autotuning to make sure the correct input tensor sizes
|
|
# are generated for sizes that include an expr such as s0 + u0.
|
|
|
|
class Repro(torch.nn.Module):
|
|
def forward(self, values, repeats, mask, embeddings, x, z, scalar):
|
|
repeat_interleave = torch.repeat_interleave(values, repeats)
|
|
index = torch.clamp(repeat_interleave, min=0, max=400).int()
|
|
index_select = torch.index_select(embeddings, 0, index)
|
|
|
|
backed = z.size(0)
|
|
unbacked = scalar.item()
|
|
torch._check_is_size(unbacked)
|
|
|
|
unbacked_add_expr = backed + unbacked
|
|
repeated = x.repeat(unbacked_add_expr, 1)
|
|
return torch.cat([repeated, index_select], dim=1)
|
|
|
|
example_inputs = (
|
|
torch.ones(64, dtype=torch.int64, device=self.device),
|
|
torch.ones(64, dtype=torch.int64, device=self.device) * 12,
|
|
torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
|
|
torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
|
|
torch.randn((1, 256), dtype=torch.bfloat16, device=self.device),
|
|
torch.ones(758, 127, dtype=torch.int64, device=self.device),
|
|
torch.scalar_tensor(10, dtype=torch.int32, device=self.device),
|
|
)
|
|
spec = {
|
|
"values": (Dim.DYNAMIC,),
|
|
"repeats": (Dim.DYNAMIC,),
|
|
"mask": (Dim.DYNAMIC,),
|
|
"embeddings": (Dim.DYNAMIC, Dim.STATIC),
|
|
"x": (Dim.STATIC, Dim.STATIC),
|
|
"z": (Dim.DYNAMIC, Dim.STATIC),
|
|
"scalar": (),
|
|
}
|
|
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
|
|
|
|
def test_size_with_unbacked_add_expr_transitive(self):
|
|
# Edge case with torch._check(expr1, expr2) + torch._check(expr2, unbacked).
|
|
# When generating example input sizes for autotuning, it should coalesce
|
|
# expr1, expr2, unbacked into a single size.
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Repro(torch.nn.Module):
|
|
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
|
|
index = torch.repeat_interleave(values, repeats)
|
|
index_select = torch.index_select(embeddings, 0, index)
|
|
|
|
u0, u1 = lst.tolist()
|
|
torch._check_is_size(u0)
|
|
torch._check_is_size(u1)
|
|
backed0, backed1 = z.size(0), z.size(1)
|
|
|
|
repeated0 = y.repeat(backed0 + u0, 1)
|
|
repeated1 = x.repeat(backed1 + u1, 1)
|
|
out1 = torch.empty_like(repeated1)
|
|
add_kernel[(out1.numel(),)](
|
|
repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
|
|
)
|
|
|
|
# Implicitly add torch._check(expr2, unbacked)
|
|
cat = torch.cat([out1, index_select], dim=1)
|
|
add = repeated0 + repeated1
|
|
|
|
# Explicitly add torch._check(expr1, expr2)
|
|
torch._check(repeated0.size(0) == out1.size(0))
|
|
return cat, add
|
|
|
|
example_inputs = (
|
|
torch.ones(64, dtype=torch.int64, device=self.device),
|
|
torch.ones(64, dtype=torch.int64, device=self.device) * 24,
|
|
torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
|
|
torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
|
|
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
|
|
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
|
|
torch.ones(758, 758, dtype=torch.int64, device=self.device),
|
|
torch.tensor([10, 10], dtype=torch.int32, device=self.device),
|
|
)
|
|
spec = {
|
|
"values": (Dim.DYNAMIC,),
|
|
"repeats": (Dim.DYNAMIC,),
|
|
"mask": (Dim.DYNAMIC,),
|
|
"embeddings": (Dim.DYNAMIC, Dim.STATIC),
|
|
"x": (Dim.DYNAMIC, Dim.STATIC),
|
|
"y": (Dim.DYNAMIC, Dim.STATIC),
|
|
"z": (Dim.DYNAMIC, Dim.DYNAMIC),
|
|
"lst": (Dim.STATIC,),
|
|
}
|
|
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
|
|
|
|
@config.patch({"unbacked_symint_fallback": 128})
|
|
def test_size_with_unbacked_add_and_mul_expr(self):
|
|
# Edge case with torch._check(add_expr, mul_expr). When generating example
|
|
# input sizes for autotuning, make sure they coalesce into a single size.
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Repro(torch.nn.Module):
|
|
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
|
|
u0, u1, u2 = lst.tolist()
|
|
torch._check_is_size(u0)
|
|
torch._check_is_size(u1)
|
|
torch._check_is_size(u2)
|
|
backed = z.size(0)
|
|
backed1 = z.size(1)
|
|
|
|
unbacked_add_expr = backed + u0
|
|
unbacked_mul_expr = backed1 + (u1 * u2)
|
|
repeated0 = x.repeat(unbacked_add_expr, 1)
|
|
repeated1 = y.repeat(unbacked_mul_expr, 1)
|
|
out0 = torch.empty_like(repeated0)
|
|
out1 = torch.empty_like(repeated1)
|
|
add_kernel[(out0.numel(),)](
|
|
repeated0, repeated0, out0, out0.numel(), BLOCK_SIZE=2
|
|
)
|
|
add_kernel[(out1.numel(),)](
|
|
repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
|
|
)
|
|
|
|
return torch.cat([out1, out0], dim=1)
|
|
|
|
example_inputs = (
|
|
torch.ones(64, dtype=torch.int64, device=self.device),
|
|
torch.ones(64, dtype=torch.int64, device=self.device) * 24,
|
|
torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
|
|
torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
|
|
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
|
|
torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
|
|
torch.ones(758, 758, dtype=torch.int64, device=self.device),
|
|
torch.tensor([10, 5, 2], dtype=torch.int32, device=self.device),
|
|
)
|
|
spec = {
|
|
"values": (Dim.DYNAMIC,),
|
|
"repeats": (Dim.DYNAMIC,),
|
|
"mask": (Dim.DYNAMIC,),
|
|
"embeddings": (Dim.DYNAMIC, Dim.STATIC),
|
|
"x": (Dim.DYNAMIC, Dim.STATIC),
|
|
"y": (Dim.DYNAMIC, Dim.STATIC),
|
|
"z": (Dim.DYNAMIC, Dim.DYNAMIC),
|
|
"lst": (Dim.STATIC,),
|
|
}
|
|
self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
|
|
|
|
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
|
|
def test_fallback_kernel_with_symexpr_output(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Module(torch.nn.Module):
|
|
def forward(self, q, k, v):
|
|
q = q.reshape(
|
|
q.shape[0],
|
|
2,
|
|
q.shape[2] * q.shape[3],
|
|
q.shape[1] // 2,
|
|
)
|
|
k = k.reshape(
|
|
k.shape[0],
|
|
2,
|
|
k.shape[2] * k.shape[3],
|
|
k.shape[1] // 2,
|
|
)
|
|
v = v.reshape(
|
|
v.shape[0],
|
|
2,
|
|
v.shape[2] * v.shape[3],
|
|
v.shape[1] // 2,
|
|
)
|
|
|
|
res = torch.ops.aten._scaled_dot_product_flash_attention.default(
|
|
q,
|
|
k,
|
|
v,
|
|
)
|
|
return res[0]
|
|
|
|
m = Module().to(device=self.device)
|
|
tensor_shape = (4, 32, 4, 4)
|
|
inputs = (
|
|
torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
|
|
torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
|
|
torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
|
|
)
|
|
|
|
dynamic_shapes = {
|
|
"q": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
|
|
"k": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
|
|
"v": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
|
|
}
|
|
ep = torch.export.export(m, inputs, dynamic_shapes=dynamic_shapes, strict=False)
|
|
path = torch._inductor.aot_compile(ep.module(), inputs)
|
|
aot_model = torch._export.aot_load(path, device=self.device)
|
|
torch.testing.assert_close(m(*inputs), aot_model(*inputs))
|
|
|
|
def test_aoti_constant_tensor(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.a = torch.ones(4, 4, device=device)
|
|
self.b = torch.ones(4, 4, device=device)
|
|
|
|
def forward(self, x):
|
|
return torch.ops.aten.linear.default(x, self.a, self.b)
|
|
|
|
example_inputs = (torch.ones(4, 4, device=self.device),)
|
|
self.check_model(Foo(self.device), example_inputs)
|
|
|
|
def test_aoti_constant_tensor_name_collision(self):
|
|
class SubModule(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.register_buffer(
|
|
"_tensor_constant1",
|
|
torch.ones(1, device=device, dtype=torch.float32),
|
|
persistent=True,
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self, user_float_feature_idx, device):
|
|
super().__init__()
|
|
self.user_float_feature_idx = user_float_feature_idx
|
|
self.register_buffer(
|
|
"_tensor_constant0",
|
|
torch.ones(1, device=device, dtype=torch.float32),
|
|
persistent=True,
|
|
)
|
|
self.register_buffer(
|
|
"_tensor_constant1",
|
|
torch.ones(1, device=device, dtype=torch.float32),
|
|
persistent=True,
|
|
)
|
|
self.sub_mod = SubModule(device)
|
|
|
|
def forward(self, x):
|
|
return (
|
|
torch.index_select(
|
|
x, 1, torch.tensor(self.user_float_feature_idx, device=x.device)
|
|
),
|
|
self._tensor_constant0,
|
|
self._tensor_constant1,
|
|
self.sub_mod._tensor_constant1,
|
|
)
|
|
|
|
example_inputs = (torch.ones(4, 4, device=self.device),)
|
|
user_float_feature_idx = [1]
|
|
# we have to have run_decomposition first to trigger the name collision
|
|
ep = torch.export.export(
|
|
Foo(user_float_feature_idx, self.device), example_inputs, strict=False
|
|
).run_decompositions()
|
|
gm = ep.module()
|
|
self.check_model(gm, example_inputs)
|
|
|
|
def test_large_grid(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, primals_5):
|
|
view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
|
|
primals_5 = None
|
|
permute = torch.ops.aten.permute.default(view, [0, 2, 1])
|
|
clone = torch.ops.aten.clone.default(
|
|
permute, memory_format=torch.contiguous_format
|
|
)
|
|
return clone
|
|
|
|
# let y_grid = 65537
|
|
s0 = 16777472
|
|
s1 = 8
|
|
example_inputs = (torch.rand(s0, s1, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_cond_simple(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.Simple(),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_cond_nested(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dim0_abc = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"p0": {},
|
|
"p1": {},
|
|
"p2": {},
|
|
"a": {0: dim0_abc, 1: None},
|
|
"b": {0: dim0_abc, 1: None},
|
|
"c": {0: dim0_abc, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.Nested(),
|
|
prepend_predicates(inputs, num_predicates=3),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_cond_with_parameters(self):
|
|
inputs = (torch.randn((10, 20), device=self.device),)
|
|
dim0_abc = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"a": {0: dim0_abc, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.Parameters(self.device),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_cond_with_reinterpret_view_inputs_outputs(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
# TODO: the min value need to be 5 because in the body_fn, we're slicing over z1[2:],
|
|
# since the output size is [dim0_ab-3], when we extract tensor metadata out of the output
|
|
# we call guard_size_oblivious, which assumes the dim0_ab-3 != 0 or 1. So we have to set
|
|
# the minimum to 5 for now. We need to relax this restriction either by writing a less
|
|
# constrained shape checking in fake impl of cond.
|
|
dim0_ab = Dim("s0", min=5, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.ReinterpretView(),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_cond_with_multiple_outputs(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((30, 40), device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dim0_c = Dim("s1", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
"c": {0: dim0_c, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.MultipleOutputs(),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_cond_with_outer_code_before_after(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.OuterCode(),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_cond_use_buffers_from_outer_scope(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dim0_abc = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"a": {0: dim0_abc, 1: None},
|
|
"b": {0: dim0_abc, 1: None},
|
|
"c": {0: dim0_abc, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.OuterBuffers(),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_cond_non_tensor_predicates(self, dynamic):
|
|
inputs1 = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((15, 20), device=self.device),
|
|
)
|
|
inputs2 = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((5, 20), device=self.device),
|
|
)
|
|
inputs = (inputs1,)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
inputs = (inputs1, inputs2)
|
|
dim0_a = Dim("s0", min=2, max=1024)
|
|
dim0_b = Dim("s1", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"a": {0: dim0_a, 1: None},
|
|
"b": {0: dim0_b, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.WithNonTensorPredicate(),
|
|
inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_cond_unbacked_symint_closure(self, dynamic):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((15, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dim0_a = Dim("s0", min=2, max=1024)
|
|
dim0_b = Dim("s1", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"x": {0: dim0_a, 1: None},
|
|
"y": {0: dim0_b, 1: None},
|
|
"z": {0: dim0_a, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.UnbackedSymIntClosure(),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_cond_mismatched_branch_output(self, dynamic):
|
|
inputs = (
|
|
torch.randn(10, 20, device=self.device),
|
|
torch.randn(10, 20, device=self.device),
|
|
torch.randn(10, 20, device=self.device),
|
|
)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
# Note the minimum has to be 4 because the model
|
|
# is slicing over the first dim with [2:], if first
|
|
# dim is 2 or 3, the slicing will be 0/1 specialized,
|
|
# causing a constraint violation eror.
|
|
dim0_a = Dim("s0", min=4, max=1024)
|
|
dim0_b = Dim("s1", min=4, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"x": {0: dim0_a, 1: None},
|
|
"y": {0: dim0_b, 1: None},
|
|
"z": {0: dim0_a, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.MismatchedOutputSize(),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_cond_symint_input(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b * z
|
|
|
|
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
|
|
|
|
input1 = (
|
|
torch.ones(3, 3, device=self.device),
|
|
torch.ones(5, device=self.device),
|
|
torch.ones(3, 3, device=self.device),
|
|
)
|
|
input2 = (
|
|
torch.ones(10, 3, device=self.device),
|
|
torch.ones(6, device=self.device),
|
|
torch.ones(10, 3, device=self.device),
|
|
)
|
|
inputs = (input1, input2)
|
|
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
|
|
self.check_model_with_multiple_inputs(
|
|
M(),
|
|
inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_while_loop_simple(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"ci": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.Simple(),
|
|
prepend_counters(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_while_loop_nested(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"ci": {},
|
|
"cj": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.Nested(),
|
|
prepend_counters(inputs, num_counters=2),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_while_loop_with_outer_code(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"c": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.OuterCode(),
|
|
prepend_counters(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_while_loop_with_parameters(self):
|
|
inputs = (torch.randn((10, 20), device=self.device),)
|
|
dim0_a = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"c": {},
|
|
"a": {0: dim0_a, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.Parameters(self.device),
|
|
prepend_counters(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_while_loop_with_outer_buffers(self):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
# dynamic shapes don't work now due to
|
|
# https://github.com/pytorch/pytorch/issues/123596
|
|
# dim0_ab = Dim("s0", min=2, max=1024)
|
|
# dynamic_shapes = {
|
|
# "c": {},
|
|
# "a": {0: dim0_ab, 1: None},
|
|
# "b": {0: dim0_ab, 1: None},
|
|
# }
|
|
dynamic_shapes = None
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.OuterBuffers(),
|
|
prepend_counters(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_while_loop_with_pytree_inputs(self):
|
|
inputs = (
|
|
torch.tensor(0, device=self.device),
|
|
(
|
|
[torch.randn(10, 20, device=self.device)],
|
|
{
|
|
"x": torch.randn(10, 20, device=self.device),
|
|
"y": torch.randn(10, 20, device=self.device),
|
|
},
|
|
),
|
|
)
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.PytreeCarry(),
|
|
[inputs],
|
|
dynamic_shapes=None,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_while_loop_with_unbacked_symint_closure(self, dynamic):
|
|
inputs = (
|
|
torch.randn(10, 20, device=self.device),
|
|
torch.randn(10, 20, device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dynamic_shapes = {
|
|
"c": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.UnbackedSymIntClosure(),
|
|
prepend_counters(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_while_loop_with_mixed_device(self, dynamic):
|
|
inputs = (
|
|
torch.randn(10, 20, device=self.device),
|
|
torch.randn(10, 20, device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dynamic_shapes = {
|
|
"c": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.MixedDevice(),
|
|
prepend_counters(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_while_loop_with_sym_expr_cond(self, dynamic):
|
|
inputs = (
|
|
torch.randn(10, 20, device=self.device),
|
|
torch.randn(10, 20, device=self.device),
|
|
)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dynamic_shapes = {
|
|
"c": {},
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.SymExprCond(),
|
|
prepend_counters(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_while_loop_with_conv(self, dynamic):
|
|
inputs = (torch.randn(2, 4, 4, 4, device=self.device, dtype=torch.float64),)
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dynamic_shapes = {
|
|
"c": {},
|
|
"x": {0: dim0_ab, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
WhileLoopModels.Conv(self.device),
|
|
prepend_counters(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@config.patch({"is_predispatch": True})
|
|
def test_constant(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.device = device
|
|
|
|
def forward(self, x):
|
|
t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
|
|
t = torch.sqrt(t * 3)
|
|
return x * t
|
|
|
|
self.check_model(M(self.device), (torch.randn(5, 5, device=self.device),))
|
|
|
|
@unittest.skipIf(IS_MACOS, "no CUDA on Mac")
|
|
def test_zero_grid_with_backed_symbols(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, b):
|
|
return x + b
|
|
|
|
example_inputs = (
|
|
torch.randn((3, 2), device=self.device),
|
|
torch.randn((1, 2), device=self.device),
|
|
)
|
|
dynamic_shapes = {
|
|
"x": {0: Dim("dx"), 1: Dim.STATIC},
|
|
"b": None,
|
|
}
|
|
|
|
# Compile & run model where dynamic dim size > 0.
|
|
package_path: str = AOTIRunnerUtil.compile(
|
|
Repro(),
|
|
example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
|
|
aot_inductor_module(*example_inputs)
|
|
|
|
# Re-run where dynamic dim size is 0.
|
|
example_inputs = (
|
|
torch.randn((0, 2), device=self.device),
|
|
torch.randn((1, 2), device=self.device),
|
|
)
|
|
actual = aot_inductor_module(*example_inputs)
|
|
expected = Repro()(*example_inputs)
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
def test_repeat_interleave(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.ops.aten.repeat_interleave.Tensor(x, output_size=12)
|
|
|
|
example_inputs = (torch.ones((1,), dtype=torch.int32, device=self.device) * 12,)
|
|
self.check_model(Repro(), example_inputs)
|
|
|
|
def test_dynamic_cat(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, a, b):
|
|
return torch.cat([a, b], dim=0)
|
|
|
|
a = torch.randn(2, 4, device=self.device)
|
|
b = torch.randn(3, 4, device=self.device)
|
|
dim0_a = Dim("dim0_a", min=1, max=10)
|
|
dim0_b = Dim("dim0_b", min=1, max=20)
|
|
dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
|
|
example_inputs = (a, b)
|
|
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_buffer_mutation_1(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.foo = torch.nn.Buffer(torch.randn(4, 4, device=device))
|
|
|
|
def forward(self, x):
|
|
self.foo.add_(1)
|
|
return self.foo + x
|
|
|
|
example_inputs = (torch.rand(4, 4, device=self.device),)
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
def test_non_tensor_input(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, a, b, alpha=1.0):
|
|
return torch.add(a, b, alpha=alpha)
|
|
|
|
a = torch.randn(10, device=self.device)
|
|
b = torch.randn(10, device=self.device)
|
|
|
|
for simdlen in [0, None]:
|
|
with torch._inductor.config.patch({"cpp.simdlen": simdlen}):
|
|
so_path = torch._export.aot_compile(
|
|
torch.ops.aten.add,
|
|
args=(a, b),
|
|
kwargs={"alpha": 2.0},
|
|
)
|
|
kernel_runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
|
|
res = kernel_runner.run([a, b])
|
|
self.assertTrue(isinstance(res, list))
|
|
self.assertTrue(len(res) == 1)
|
|
self.assertEqual(Model()(a, b, alpha=2.0), res[0])
|
|
|
|
def test_buffer_mutation_2(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.foo = torch.nn.Buffer(torch.arange(10, device=device))
|
|
self.bar = torch.nn.Buffer(torch.arange(10, device=device))
|
|
|
|
def forward(self, x):
|
|
self.bar.mul_(2)
|
|
self.foo[5] = self.bar[0]
|
|
return x + self.bar, x * self.foo
|
|
|
|
example_inputs = (torch.randn(10, device=self.device),)
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
def test_buffer_mutation_3(self):
|
|
class KVCache(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
max_batch_size,
|
|
max_seq_length,
|
|
n_heads,
|
|
head_dim,
|
|
dtype=torch.float,
|
|
):
|
|
super().__init__()
|
|
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
|
|
self.k_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))
|
|
self.v_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))
|
|
|
|
def update(self, input_pos, k_val, v_val):
|
|
# input_pos: [S], k_val: [B, H, S, D]
|
|
k_out = self.k_cache
|
|
v_out = self.v_cache
|
|
k_out[:, :, input_pos] = k_val
|
|
v_out[:, :, input_pos] = v_val
|
|
|
|
return k_out, v_out
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.kv_cache = KVCache(1, 256, 6, 48)
|
|
|
|
def forward(self, inp_pos, k, v):
|
|
self.kv_cache.update(inp_pos, k, v)
|
|
return self.kv_cache.k_cache + 1, self.kv_cache.v_cache / 2
|
|
|
|
example_inputs = (
|
|
torch.tensor([0], device=self.device),
|
|
torch.randn(1, 6, 1, 48, device=self.device),
|
|
torch.randn(1, 6, 1, 48, device=self.device),
|
|
)
|
|
model = Model(self.device)
|
|
self.check_model(model, example_inputs)
|
|
self.code_check_count(model, example_inputs, "empty_strided", 2)
|
|
|
|
def test_buffer_mutation_4(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.register_buffer(
|
|
"_tensor_constant0",
|
|
torch.randint(1, size=[38], dtype=torch.int64, device="cpu"),
|
|
)
|
|
|
|
def forward(self, x):
|
|
return x + self._tensor_constant0.to(
|
|
torch.device(type=GPU_TYPE, index=0)
|
|
)
|
|
|
|
example_inputs = (
|
|
torch.randint(1, size=[38], dtype=torch.int64, device=GPU_TYPE),
|
|
)
|
|
torch._export.aot_compile(Model(), example_inputs)
|
|
|
|
@skipCUDAIf(True, "Test for x86 backend")
|
|
@skipIfXpu
|
|
@unittest.skipIf(IS_FBCODE, "Need newer ideep")
|
|
def test_buffer_mutation_and_force_mmap_weights(self):
|
|
class Model(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(16, 15)
|
|
self.linear2 = torch.nn.Linear(15, 14)
|
|
|
|
def forward(self, x):
|
|
x = self.linear1(x)
|
|
out = self.linear2(x)
|
|
return out
|
|
|
|
example_inputs = (torch.randn(32, 16),)
|
|
model = Model().eval()
|
|
with config.patch(
|
|
{"freezing": True, "aot_inductor.force_mmap_weights": True}
|
|
), torch.no_grad():
|
|
exported_model = export_for_training(
|
|
model, example_inputs, strict=True
|
|
).module()
|
|
quantizer = X86InductorQuantizer()
|
|
quantizer.set_global(
|
|
xiq.get_default_x86_inductor_quantization_config(reduce_range=True)
|
|
)
|
|
prepared_model = prepare_pt2e(exported_model, quantizer)
|
|
prepared_model(*example_inputs)
|
|
converted_model = convert_pt2e(prepared_model)
|
|
torch.ao.quantization.move_exported_model_to_eval(converted_model)
|
|
|
|
self.check_model(converted_model, example_inputs)
|
|
|
|
def test_fallback_mem_leak_fix(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y, idx):
|
|
tmp = x + y
|
|
w = torch.ops.aten.as_strided(tmp, x.shape, x.stride())
|
|
out = torch.ops.aten.index.Tensor(w, [idx])
|
|
return w, out
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 1, 4, device=GPU_TYPE),
|
|
torch.randn(4, 1, 4, device=GPU_TYPE),
|
|
torch.randn(4, device=GPU_TYPE) > 0,
|
|
)
|
|
|
|
dim0 = Dim("dim0", min=1, max=2048)
|
|
dynamic_shapes = {
|
|
"x": {0: dim0},
|
|
"y": {0: dim0},
|
|
"idx": {0: dim0},
|
|
}
|
|
package_path: str = AOTIRunnerUtil.compile(
|
|
Model(),
|
|
example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
|
|
device_interface = get_interface_for_device(GPU_TYPE)
|
|
device: int = device_interface.current_device()
|
|
mem_before = device_interface.memory_allocated(device)
|
|
aot_inductor_module(*example_inputs)
|
|
mem_after = device_interface.memory_allocated(device)
|
|
self.assertEqual(mem_before, mem_after)
|
|
|
|
actual = aot_inductor_module(*example_inputs)
|
|
expected = Model()(*example_inputs)
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
@requires_multigpu()
|
|
def test_replicate_on_devices(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, w1, w2):
|
|
super().__init__()
|
|
self.w1 = w1
|
|
self.w2 = w2
|
|
|
|
def forward(self, x, y):
|
|
a = x * self.w1
|
|
b = y * self.w2
|
|
return a + b
|
|
|
|
w1 = torch.randn(10, 10)
|
|
w2 = torch.randn(10, 10)
|
|
inputs = (torch.randn(10, 10), torch.randn(10, 10))
|
|
result_cpu = Model(w1, w2)(*inputs)
|
|
|
|
# Compile model with AOTInductor
|
|
device_interface = get_interface_for_device(GPU_TYPE)
|
|
with device_interface.device(0):
|
|
package_path = AOTIRunnerUtil.compile(
|
|
model=Model(
|
|
w1.to(torch.device(GPU_TYPE, 0)), w2.to(torch.device(GPU_TYPE, 0))
|
|
),
|
|
example_inputs=tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
|
|
)
|
|
|
|
# Run model on gpu:N
|
|
for i in range(device_interface.device_count()):
|
|
with device_interface.device(i):
|
|
example_inputs = tuple(t.to(torch.device(GPU_TYPE, i)) for t in inputs)
|
|
optimized = torch._inductor.aoti_load_package(package_path)
|
|
result_gpu = optimized(*example_inputs)
|
|
self.assertTrue(same(result_cpu, result_gpu.cpu()))
|
|
|
|
@requires_multigpu()
|
|
def test_on_gpu_device1(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
device_interface = get_interface_for_device(GPU_TYPE)
|
|
try:
|
|
device_interface.get_device_properties(1)
|
|
except AssertionError:
|
|
raise unittest.SkipTest("GPU device 1 is not available") from None
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fc1 = torch.nn.Linear(10, 16)
|
|
self.relu = torch.nn.ReLU()
|
|
self.fc2 = torch.nn.Linear(16, 1)
|
|
self.sigmoid = torch.nn.Sigmoid()
|
|
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
x = self.relu(x)
|
|
x = self.fc2(x)
|
|
x = self.sigmoid(x)
|
|
return x
|
|
|
|
device = f"{GPU_TYPE}:1"
|
|
model = Model().to(device)
|
|
example_inputs = (torch.randn(8, 10, device=device),)
|
|
expected = model(*example_inputs)
|
|
|
|
so_path = AOTIRunnerUtil.legacy_compile(model, example_inputs)
|
|
optimized = AOTIRunnerUtil.legacy_load(device, so_path)
|
|
actual = optimized(*example_inputs)
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
def test_pytree_inputs(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x: dict[str, torch.Tensor]):
|
|
device = next(iter(x.values())).device
|
|
add_ = torch.zeros(5, device=device)
|
|
mul_ = torch.ones(5, device=device)
|
|
for v in x.values():
|
|
add_ += v
|
|
mul_ *= v
|
|
|
|
return [add_, mul_]
|
|
|
|
self.check_model(
|
|
M(),
|
|
(
|
|
{
|
|
"x": torch.ones(5, device=self.device),
|
|
"y": torch.ones(5, device=self.device),
|
|
},
|
|
),
|
|
)
|
|
|
|
@requires_multigpu()
|
|
def test_non_default_gpu_device(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = weight
|
|
|
|
def forward(self, x, y):
|
|
return x + torch.nn.functional.linear(y, self.weight)
|
|
|
|
weight = torch.randn(10, 10)
|
|
inputs = (torch.randn(10, 10), torch.randn(10, 10))
|
|
result_cpu = Model(weight)(*inputs)
|
|
|
|
device_interface = get_interface_for_device(GPU_TYPE)
|
|
with device_interface.device(0), torch.no_grad():
|
|
result_gpu_0 = AOTIRunnerUtil.run(
|
|
Model(weight.to(torch.device(GPU_TYPE, 0))),
|
|
tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
|
|
)
|
|
|
|
with device_interface.device(1), torch.no_grad():
|
|
result_gpu_1 = AOTIRunnerUtil.run(
|
|
Model(weight.to(torch.device(GPU_TYPE, 1))),
|
|
tuple(t.to(torch.device(GPU_TYPE, 1)) for t in inputs),
|
|
)
|
|
|
|
self.assertTrue(same(result_cpu, result_gpu_0.cpu()))
|
|
self.assertTrue(same(result_cpu, result_gpu_1.cpu()))
|
|
|
|
@requires_multigpu()
|
|
def test_load_package_multiple_gpus(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = weight
|
|
|
|
def forward(self, x, y):
|
|
return x + torch.nn.functional.linear(y, self.weight)
|
|
|
|
weight = torch.randn(10, 10, device=self.device)
|
|
inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
model = Model(weight).to(device=self.device)
|
|
result_ref = model(*inputs)
|
|
|
|
package_path = AOTIRunnerUtil.compile(model, inputs)
|
|
|
|
# Load AOT package on gpu:N
|
|
device_interface = get_interface_for_device(GPU_TYPE)
|
|
for i in range(device_interface.device_count()):
|
|
device = torch.device(GPU_TYPE, i)
|
|
with device_interface.device(i), torch.no_grad():
|
|
model_package = torch._inductor.aoti_load_package(
|
|
package_path, device_index=i
|
|
)
|
|
inputs_on_device = [input.to(device=device) for input in inputs]
|
|
result_package = model_package(*inputs_on_device)
|
|
self.assertTrue(same(result_ref.cpu(), result_package.cpu()))
|
|
|
|
def test_reuse_kernel(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
a = torch.sin(x)
|
|
b = torch.mm(a, y)
|
|
c = torch.sin(b)
|
|
d = torch.mm(b, c)
|
|
return d
|
|
|
|
example_inputs = (
|
|
torch.randn(87, 87, device=self.device),
|
|
torch.randn(87, 87, device=self.device),
|
|
)
|
|
model = Model()
|
|
self.check_model(
|
|
model, example_inputs, atol=1e-4, rtol=1e-4
|
|
) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
|
|
|
|
if self.device == GPU_TYPE:
|
|
self.code_check_count(
|
|
model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1
|
|
)
|
|
|
|
def test_reuse_kernel_dynamic(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.cst = torch.randn(48, device=device, dtype=torch.float)
|
|
self.weights = torch.randn(6, 48, 48, device=device, dtype=torch.float)
|
|
self.cst_1 = torch.randn(48, device=device, dtype=torch.float)
|
|
self.weights_1 = torch.randn(
|
|
6, 48, 48, device=device, dtype=torch.float
|
|
)
|
|
|
|
def forward(self, x, y, z):
|
|
dim0 = x.size(1)
|
|
add_0 = z + z
|
|
expand_2 = add_0.expand(-1, -1, 48)
|
|
# [s0, 6, 48]
|
|
mul_3 = add_0 * expand_2
|
|
# [6, s0, 48]
|
|
permute_4 = torch.permute(mul_3, (1, 0, 2))
|
|
# [6, s0, 48]
|
|
bmm_5 = torch.bmm(permute_4, self.weights)
|
|
add_6 = bmm_5 + self.cst
|
|
reshape_7 = torch.reshape(add_6, [6, dim0 * 6, 8])
|
|
# [6*s0, 6, 8]
|
|
permute_8 = torch.permute(reshape_7, (1, 0, 2))
|
|
mul_9 = permute_8 * 0.123
|
|
reshape_10 = torch.reshape(y, [8, dim0 * 6, 4])
|
|
# [6*s0, 8, 4]
|
|
permute_11 = torch.permute(reshape_10, (1, 0, 2))
|
|
bmm_12 = torch.bmm(mul_9, permute_11)
|
|
|
|
add_0_1 = z + z
|
|
expand_2_1 = add_0_1.expand(-1, -1, 48)
|
|
# [s0, 6, 48]
|
|
mul_3_1 = add_0_1 * expand_2_1
|
|
# [6, s0, 48]
|
|
permute_4_1 = torch.permute(mul_3_1, (1, 0, 2))
|
|
# [6, s0, 48]
|
|
bmm_5_1 = torch.bmm(permute_4_1, self.weights_1)
|
|
add_6_1 = bmm_5_1 + self.cst_1
|
|
reshape_7_1 = torch.reshape(add_6_1, [6, dim0 * 6, 8])
|
|
# [6*s0, 6, 8]
|
|
permute_8_1 = torch.permute(reshape_7_1, (1, 0, 2))
|
|
mul_9_1 = permute_8_1 * 0.123
|
|
reshape_10_1 = torch.reshape(y, [8, dim0 * 6, 4])
|
|
# [6*s0, 8, 4]
|
|
permute_11_1 = torch.permute(reshape_10_1, (1, 0, 2))
|
|
bmm_12_1 = torch.bmm(mul_9_1, permute_11_1)
|
|
return bmm_12 + bmm_12_1
|
|
|
|
x = torch.randn(6, 2, 48, device=self.device, dtype=torch.float)
|
|
y = torch.randn(48, 2, 4, device=self.device, dtype=torch.float)
|
|
z = torch.randn(2, 6, 1, device=self.device, dtype=torch.float)
|
|
dim0 = Dim("dim0", min=1, max=2048)
|
|
dynamic_shapes = {
|
|
"x": {1: dim0},
|
|
"y": {1: dim0},
|
|
"z": {0: dim0},
|
|
}
|
|
|
|
example_inputs = (x, y, z)
|
|
model = Model(self.device).to(dtype=torch.float)
|
|
self.check_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_fake_tensor_device_validation(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
return x + y
|
|
|
|
example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
|
|
|
|
# Export on CPU
|
|
exported_program = export(Model(), example_inputs, strict=True)
|
|
|
|
# Compile exported model on GPU
|
|
gm = exported_program.graph_module.to(self.device)
|
|
with self.assertRaisesRegex(ValueError, "Device mismatch between fake input"):
|
|
torch._inductor.aot_compile(
|
|
gm, tuple(i.to(self.device) for i in example_inputs)
|
|
)
|
|
|
|
def test_fx_gm_return_tuple_validation(self):
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
return x + y
|
|
|
|
example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
|
|
|
|
gm = make_fx(Model(), tracing_mode="symbolic")(*example_inputs)
|
|
with self.assertRaisesRegex(
|
|
AssertionError,
|
|
r"Graph output must be a tuple\(\). This is so that we can avoid "
|
|
"pytree processing of the outputs.",
|
|
):
|
|
torch._inductor.aot_compile(gm, example_inputs)
|
|
|
|
def test_consecutive_compiles(self):
|
|
"""Test that compilation behaves correctly with cache hits"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x + 1
|
|
|
|
mod = TestModule()
|
|
inp = torch.rand(1)
|
|
mod(inp)
|
|
mod2 = torch.fx.symbolic_trace(mod, concrete_args=[inp])
|
|
so = torch._export.aot_compile(mod2, (inp,))
|
|
assert so is not None
|
|
# compile the 2nd time with cache hit
|
|
so = torch._export.aot_compile(mod2, (inp,))
|
|
assert so is not None
|
|
|
|
def test_normal_functional(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.ops.aten.normal_functional.default(x)
|
|
|
|
self.check_model(Model(), (torch.empty(4, 1, 4, 4, device=self.device),))
|
|
|
|
def test_empty_graph(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
example_inputs = (torch.randn(8, 4, 4, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@patch("torch._dynamo.utils.CompileEventLogger.log_instant_event")
|
|
def test_backward_no_op_logging(self, mock_log_instant_event):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
model = Model()
|
|
dummy_input = torch.randn(1, 5)
|
|
|
|
from torch._dynamo.utils import CompileEventLogLevel
|
|
from torch._inductor import compile_fx
|
|
|
|
graph_module = torch.fx.symbolic_trace(model)
|
|
compile_fx._compile_fx_inner(graph_module, (dummy_input,))
|
|
mock_log_instant_event.assert_called_once_with(
|
|
"backward no-op",
|
|
metadata={"compile_id": None},
|
|
log_level=CompileEventLogLevel.PT2_COMPILE,
|
|
)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
|
|
def test_dup_unbacked_sym_decl(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
abs_1 = torch.ops.aten.abs.default(x)
|
|
lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
|
|
eq = torch.ops.aten.eq.Scalar(lt, 0)
|
|
index_1 = torch.ops.aten.index.Tensor(x, [eq])
|
|
sin = torch.ops.aten.sin.default(index_1)
|
|
index_2 = torch.ops.aten.index.Tensor(x, [eq])
|
|
div_3 = torch.ops.aten.div.Tensor(sin, index_2)
|
|
return div_3
|
|
|
|
example_inputs = (torch.randn(4, 4, 4, 4).to(self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
# This exercises _eliminate_unbacked path in ShapeEnv
|
|
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
|
|
def test_dup_unbacked_sym_decl_with_refinement(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
abs_1 = torch.ops.aten.abs.default(x)
|
|
lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
|
|
eq = torch.ops.aten.eq.Scalar(lt, 0)
|
|
index_1 = torch.ops.aten.index.Tensor(x, [eq])
|
|
torch._check(index_1.size(0) == 4**4)
|
|
sin = torch.ops.aten.sin.default(index_1)
|
|
index_2 = torch.ops.aten.index.Tensor(x, [eq])
|
|
div_3 = torch.ops.aten.div.Tensor(sin, index_2)
|
|
return div_3
|
|
|
|
example_inputs = (torch.ones(4, 4, 4, 4).to(self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_run_with_grad_enabled(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, weight, bias):
|
|
return torch.ops.aten.addmm(bias, weight, x)
|
|
|
|
m = Model().to(device=self.device)
|
|
x = torch.rand(8, 8, device=self.device, requires_grad=True)
|
|
weight = torch.rand(8, 8, device=self.device, requires_grad=True)
|
|
bias = torch.rand(8, device=self.device, requires_grad=True)
|
|
example_inputs = (x, weight, bias)
|
|
|
|
expected = m(*example_inputs)
|
|
expected = pytree.tree_leaves(expected)
|
|
|
|
# compiler under no_grad
|
|
with torch.no_grad():
|
|
package_path = AOTIRunnerUtil.compile(m, example_inputs)
|
|
|
|
# run under grad enabled
|
|
self.assertTrue(torch.is_grad_enabled())
|
|
|
|
optimized = torch._inductor.aoti_load_package(package_path)
|
|
actual = optimized(*example_inputs)
|
|
actual = pytree.tree_leaves(actual)
|
|
|
|
self.assertTrue(same(actual, expected))
|
|
|
|
def test_return_constant(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.cst = torch.randn(5, 5, device=device)
|
|
|
|
def forward(self, x):
|
|
a = self.cst.clone()
|
|
return (x, a)
|
|
|
|
x = torch.randn(5, device=self.device)
|
|
self.check_model(Model(self.device), (x,))
|
|
|
|
def test_return_view_constant(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.cst = torch.randn(5, 5, device=device)
|
|
|
|
def forward(self, x):
|
|
a = torch.transpose(self.cst, 0, 1)
|
|
return (x, a)
|
|
|
|
x = torch.randn(5, device=self.device)
|
|
self.check_model(Model(self.device), (x,))
|
|
|
|
def test_profile_benchmark_harness(self):
|
|
batch_size = 32
|
|
seq_length = 50
|
|
hidden_size = 768
|
|
|
|
def create_test_fn():
|
|
def test_fn():
|
|
inp = torch.randn(
|
|
batch_size, seq_length, hidden_size, device=self.device
|
|
)
|
|
weight = torch.randn(hidden_size, hidden_size, device=self.device)
|
|
matmul_output = inp @ weight
|
|
torch.nn.LayerNorm(hidden_size, device=self.device)(matmul_output)
|
|
return True
|
|
|
|
return test_fn
|
|
|
|
fn = torch.compile(
|
|
options={"profile_bandwidth_output": "foo", "benchmark_harness": False}
|
|
)(create_test_fn())
|
|
fn()
|
|
|
|
def test_with_profiler(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_with_no_triton_profiler(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.permute(x, (1, 0))
|
|
|
|
example_inputs = (torch.randn(10, 10, device=self.device),)
|
|
with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_repeat_output(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
y = torch.sin(x)
|
|
return y, y
|
|
|
|
example_inputs = (torch.randn(3, 10, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_repeated_calling(self):
|
|
if self.device != "cuda":
|
|
raise unittest.SkipTest("requires CUDA")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.sin(x)
|
|
|
|
example_inputs = (torch.randn(10, 10, device=self.device),)
|
|
optimized = torch._inductor.aoti_load_package(
|
|
torch._inductor.aoti_compile_and_package(
|
|
torch.export.export(Model(), example_inputs, strict=True)
|
|
)
|
|
)
|
|
try:
|
|
torch.cuda.memory.empty_cache()
|
|
torch.cuda.memory._record_memory_history(context=None)
|
|
for _ in range(10):
|
|
optimized(*example_inputs)
|
|
finally:
|
|
torch.cuda.memory._record_memory_history(False)
|
|
segments = torch.cuda.memory._snapshot()["segments"]
|
|
self.assertEqual(segments[0]["requested_size"], 400)
|
|
|
|
def test_view_outputs(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
y = torch.sin(x)
|
|
y_same_size = y.view(*y.shape)
|
|
y_diff_size = y.view(1, *y.shape)
|
|
return y, y_same_size, y_diff_size
|
|
|
|
example_inputs = (torch.randn(3, 10, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@skip_if_no_torchvision
|
|
def test_missing_cubin(self):
|
|
from torchvision.models.resnet import Bottleneck, ResNet
|
|
|
|
class Model(ResNet):
|
|
def __init__(self) -> None:
|
|
super().__init__(
|
|
block=Bottleneck,
|
|
layers=[3, 4, 6, 3],
|
|
replace_stride_with_dilation=[False, False, True],
|
|
norm_layer=None,
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.bn1(x)
|
|
x = self.relu(x)
|
|
f1 = x
|
|
x = self.maxpool(x)
|
|
x = self.layer1(x)
|
|
f2 = x
|
|
x = self.layer2(x)
|
|
f3 = x
|
|
x = self.layer3(x)
|
|
x = self.layer4(x)
|
|
f4 = x
|
|
return [f1, f2, f3, f4]
|
|
|
|
# Call eval() here so that batch_norm won't update the running stats
|
|
# Use float64 to avoid numeric difference failure
|
|
model = Model().to(device=self.device, dtype=torch.float64).eval()
|
|
example_inputs = (
|
|
torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
|
|
)
|
|
self.check_model(model, example_inputs)
|
|
|
|
def test_triton_next_power_of_2(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, a, b, lengths):
|
|
n_elements = a.numel()
|
|
out = torch.empty_like(a)
|
|
max_len = int(lengths.max())
|
|
scaling_factor = triton.next_power_of_2(max_len)
|
|
add_kernel_with_scaling[(n_elements,)](
|
|
a,
|
|
b,
|
|
out,
|
|
n_elements,
|
|
scaling_factor,
|
|
BLOCK_SIZE=16,
|
|
)
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(2, device=self.device),
|
|
torch.randn(2, device=self.device),
|
|
torch.arange(end=4, device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@common_utils.parametrize("minmax", [min, max])
|
|
def test_sympy_cpp_printer_min_max(self, minmax):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, a, b, ranks):
|
|
n_elements = a.numel()
|
|
out = torch.empty_like(a)
|
|
backed = a.size(0)
|
|
unbacked = int(ranks.max())
|
|
scaling_factor = minmax(backed, unbacked, 100)
|
|
add_kernel_with_scaling[(n_elements,)](
|
|
a,
|
|
b,
|
|
out,
|
|
n_elements,
|
|
scaling_factor,
|
|
BLOCK_SIZE=16,
|
|
)
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(16, device=self.device),
|
|
torch.randn(16, device=self.device),
|
|
torch.arange(end=4, device=self.device, dtype=torch.int16),
|
|
)
|
|
torch._dynamo.mark_dynamic(example_inputs[0], 0)
|
|
torch._dynamo.mark_dynamic(example_inputs[1], 0)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@common_utils.parametrize("grid_type", [1, 2, 3])
|
|
@common_utils.parametrize("num_dims", [1, 2])
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
@common_utils.parametrize("autotune", [False, True])
|
|
def test_triton_kernel(self, grid_type, num_dims, dynamic, autotune):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
output = torch.zeros_like(x)
|
|
if autotune and num_dims == 2:
|
|
x_elements = output.size()[0]
|
|
y_elements = output.size()[1]
|
|
else:
|
|
n_elements = output.numel()
|
|
|
|
# Select grid
|
|
if autotune and num_dims == 2:
|
|
if grid_type == 1:
|
|
grid = (x_elements, y_elements)
|
|
elif grid_type == 2:
|
|
grid = lambda meta: ( # noqa: E731
|
|
triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
|
|
triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
|
|
)
|
|
else:
|
|
|
|
def grid_fn(meta):
|
|
return (
|
|
triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
|
|
triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
|
|
)
|
|
|
|
grid = grid_fn
|
|
else:
|
|
if grid_type == 1:
|
|
grid = (n_elements,)
|
|
elif grid_type == 2:
|
|
grid = lambda meta: ( # noqa: E731
|
|
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
|
|
)
|
|
else:
|
|
|
|
def grid_fn(meta):
|
|
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
|
|
|
grid = grid_fn
|
|
|
|
# Select kernel
|
|
if autotune:
|
|
if num_dims == 1:
|
|
add_kernel_autotuned[grid](x, y, output, n_elements)
|
|
else:
|
|
add_kernel_2d_autotuned[grid](
|
|
x, y, output, x_elements, y_elements
|
|
)
|
|
else:
|
|
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
|
|
return output
|
|
|
|
dims = [10] * num_dims
|
|
x = torch.randn(*dims, device=self.device)
|
|
y = torch.randn(*dims, device=self.device)
|
|
dynamic_shapes = []
|
|
if dynamic:
|
|
dim0_x = Dim("dim0_x", min=1, max=10)
|
|
dim0_y = Dim("dim0_y", min=1, max=10)
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
|
|
self.check_model(Model(), (x, y), dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_triton_kernel_dynamic_shape_with_div(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
@triton.jit
|
|
def pass_kernel(x, num):
|
|
pass
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
num = x.numel() // 4
|
|
|
|
grid = lambda meta: (triton.cdiv(num, 16),) # noqa: E731
|
|
pass_kernel[grid](x, num)
|
|
return x
|
|
|
|
x = torch.randn(10, device=self.device)
|
|
dim0_x = Dim("dim0_x", min=1, max=10)
|
|
dynamic_shapes = {"x": {0: dim0_x}}
|
|
self.check_model(Model(), (x,), dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_triton_kernel_reinterpret_view(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
@triton.jit
|
|
def pass_kernel(x, y):
|
|
pass
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
out = torch.zeros_like(x[:, 4:])
|
|
# the slicing below creates two ReinterpretView
|
|
# instances: with offset=3 and offset=4
|
|
add_kernel[(10,)](
|
|
in_ptr0=x[:, 3:-1],
|
|
in_ptr1=x[:, 4:],
|
|
out_ptr=out,
|
|
n_elements=160,
|
|
BLOCK_SIZE=16,
|
|
)
|
|
return out
|
|
|
|
example_inputs = (torch.randn(10, 20, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
@common_utils.parametrize("tma_version", ["new", "old"])
|
|
def test_triton_kernel_tma_descriptor_1d(self, dynamic, tma_version):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
if tma_version == "new" and not has_triton_tensor_descriptor_host_tma():
|
|
self.skipTest("requires triton.tools.tensor_descriptor TMA support")
|
|
if tma_version == "old" and not has_triton_experimental_host_tma():
|
|
self.skipTest("requires triton.tools.experimental_descriptor TMA support")
|
|
|
|
kernel = (
|
|
add_kernel_with_tma_1d_new_api
|
|
if tma_version == "new"
|
|
else add_kernel_with_tma_1d_old_api
|
|
)
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, a, b):
|
|
BLOCK_SIZE = 256
|
|
out = torch.zeros_like(a)
|
|
n_elements = out.numel()
|
|
|
|
desc_a, desc_b, desc_out = (
|
|
create_tensor_descriptor_shim(
|
|
t, [BLOCK_SIZE], new_api=(tma_version == "new")
|
|
)
|
|
for t in (a, b, out)
|
|
)
|
|
|
|
grid = lambda meta: ( # noqa: E731
|
|
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
|
|
)
|
|
kernel[grid](
|
|
desc_a,
|
|
desc_b,
|
|
desc_out,
|
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
)
|
|
|
|
return out
|
|
|
|
a = torch.randn(301, device=self.device)
|
|
b = torch.randn(301, device=self.device)
|
|
example_inputs = (a, b)
|
|
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
|
|
self.check_model(
|
|
Model(),
|
|
example_inputs=example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
@common_utils.parametrize("tma_version", ["new", "old"])
|
|
def test_triton_kernel_tma_descriptor_2d(self, dynamic, tma_version):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
if tma_version == "new" and not has_triton_tensor_descriptor_host_tma():
|
|
self.skipTest("requires triton.tools.tensor_descriptor TMA support")
|
|
if tma_version == "old" and not has_triton_experimental_host_tma():
|
|
self.skipTest("requires triton.tools.experimental_descriptor TMA support")
|
|
|
|
kernel = (
|
|
add_kernel_with_tma_2d_new_api
|
|
if tma_version == "new"
|
|
else add_kernel_with_tma_2d_old_api
|
|
)
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, a, b):
|
|
BLOCK_SIZE_X = 16
|
|
BLOCK_SIZE_Y = 32
|
|
out = torch.zeros_like(a)
|
|
x_size, y_size = out.size()
|
|
|
|
desc_a, desc_b, desc_out = (
|
|
create_tensor_descriptor_shim(
|
|
t,
|
|
[BLOCK_SIZE_X, BLOCK_SIZE_Y],
|
|
new_api=(tma_version == "new"),
|
|
)
|
|
for t in (a, b, out)
|
|
)
|
|
|
|
grid = lambda meta: ( # noqa: E731
|
|
triton.cdiv(x_size, meta["BLOCK_SIZE_X"]),
|
|
triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]),
|
|
)
|
|
kernel[grid](
|
|
desc_a,
|
|
desc_b,
|
|
desc_out,
|
|
BLOCK_SIZE_X=BLOCK_SIZE_X,
|
|
BLOCK_SIZE_Y=BLOCK_SIZE_Y,
|
|
)
|
|
|
|
return out
|
|
|
|
a = torch.randn((25, 16), device=self.device)
|
|
b = torch.randn((25, 16), device=self.device)
|
|
example_inputs = (a, b)
|
|
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dim0_ab = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"a": {0: dim0_ab, 1: None},
|
|
"b": {0: dim0_ab, 1: None},
|
|
}
|
|
|
|
self.check_model(
|
|
Model(),
|
|
example_inputs=example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_triton_kernel_sympy_expr_arg(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, e):
|
|
sympy_expr = max(1, e.item())
|
|
out = torch.zeros_like(x)
|
|
add_kernel[(1,)](
|
|
in_ptr0=x,
|
|
in_ptr1=x,
|
|
out_ptr=out,
|
|
n_elements=sympy_expr,
|
|
BLOCK_SIZE=1,
|
|
)
|
|
return out
|
|
|
|
NUMEL = 64
|
|
inputs = (
|
|
torch.randn(NUMEL, device=self.device),
|
|
torch.tensor(NUMEL, device=self.device),
|
|
)
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_triton_kernel_sympy_fn_like_arg(self):
|
|
# This test should hit sympy.expand("sqrt") which crashes with
|
|
# AttributeError: 'function' object has no attribute 'expand'.
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
out = torch.zeros_like(x)
|
|
add_kernel_with_optional_param[1,](
|
|
in_ptr0=x,
|
|
in_ptr1=x,
|
|
out_ptr=out,
|
|
n_elements=x.numel(),
|
|
BLOCK_SIZE=1,
|
|
ARGS_PASSED="sqrt", # sqrt is a valid sympy fn
|
|
)
|
|
return out
|
|
|
|
inputs = (torch.randn(4, device=self.device),)
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_triton_kernel_with_none_input(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
n_elements = x.size()[0]
|
|
BLOCK_SIZE = 1024
|
|
|
|
output_wo_y = torch.empty_like(x)
|
|
output_with_y = torch.empty_like(x)
|
|
|
|
add_kernel_with_optional_param[(1,)](
|
|
x,
|
|
None,
|
|
output_wo_y,
|
|
n_elements,
|
|
ARGS_PASSED="one",
|
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
)
|
|
add_kernel_with_optional_param[(1,)](
|
|
x,
|
|
y,
|
|
output_with_y,
|
|
n_elements,
|
|
ARGS_PASSED="two",
|
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
)
|
|
|
|
return 2.71 * output_wo_y + 3.14 * output_with_y
|
|
|
|
example_inputs = (
|
|
torch.randn(1023, device=self.device),
|
|
torch.randn(1023, device=self.device),
|
|
)
|
|
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_triton_kernel_equal_to_1_arg(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
out = torch.empty_like(x)
|
|
n_elements = x.numel()
|
|
add_kernel[(n_elements,)](x, y, out, n_elements, BLOCK_SIZE=16)
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(1, device=self.device),
|
|
torch.randn(1, device=self.device),
|
|
)
|
|
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_triton_kernel_with_none_inputs_and_equal_to_1_arg(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
n_elements = x.size()[0]
|
|
BLOCK_SIZE = 1024
|
|
out1 = torch.empty_like(x)
|
|
out2 = torch.empty_like(x)
|
|
# Run the same kernel multiple times to test the optimization
|
|
# of removing None arguments and then update the indices of
|
|
# equal_to_1 arguments. The None arguments need to be before
|
|
# the equal_to_1 arguments
|
|
add_kernel_with_none_param_and_equal_to_1_arg[(1,)](
|
|
x,
|
|
None,
|
|
out1,
|
|
n_elements,
|
|
x.stride(0), # equal to 1
|
|
ARGS_PASSED="one",
|
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
)
|
|
add_kernel_with_none_param_and_equal_to_1_arg[(1,)](
|
|
2.71 * out1,
|
|
None,
|
|
out2,
|
|
n_elements,
|
|
x.stride(0), # equal to 1
|
|
ARGS_PASSED="one",
|
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
)
|
|
return out2
|
|
|
|
example_inputs = (torch.randn(1023, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
out = torch.empty_like(x)
|
|
n_elements = x.numel()
|
|
scaling_factor = (n_elements**0) / 1.0
|
|
add_kernel_with_scaling[(n_elements,)](
|
|
x,
|
|
y,
|
|
out,
|
|
n_elements,
|
|
scaling_factor,
|
|
BLOCK_SIZE=16,
|
|
)
|
|
return out
|
|
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dim0_xy = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"x": {0: dim0_xy},
|
|
"y": {0: dim0_xy},
|
|
}
|
|
example_inputs = (
|
|
torch.randn(2, device=self.device),
|
|
torch.randn(2, device=self.device),
|
|
)
|
|
self.check_model(
|
|
Model(),
|
|
example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_triton_kernel_weird_param_order(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
out = torch.empty_like(x)
|
|
add_kernel_autotuned_weird_param_order[16,](
|
|
in_ptr0=x,
|
|
in_ptr1=x,
|
|
n_elements=x.numel(),
|
|
out_ptr=out,
|
|
)
|
|
return out
|
|
|
|
x = torch.randn(16, 16, device=self.device)
|
|
self.check_model(Model(), (x,))
|
|
|
|
def test_triton_kernel_dynamic_grid(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
import math
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y, n_elements_tensor):
|
|
output = torch.zeros_like(x)
|
|
n_elements_symint = n_elements_tensor.item()
|
|
n_elements = x.numel()
|
|
|
|
def grid(meta):
|
|
n_elements_complicated = n_elements_symint // 1.0
|
|
return (math.trunc(n_elements_complicated / meta["BLOCK_SIZE"]),)
|
|
|
|
add_kernel_autotuned[grid](
|
|
x,
|
|
y,
|
|
output,
|
|
n_elements,
|
|
)
|
|
|
|
return output
|
|
|
|
x = torch.randn(128, device=self.device)
|
|
y = torch.randn(128, device=self.device)
|
|
n_elem = torch.tensor(128)
|
|
dim0_x = Dim("dim0_x", min=8, max=256)
|
|
dim0_y = Dim("dim0_y", min=8, max=256)
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}, "n_elements_tensor": {}}
|
|
self.check_model(Model(), (x, y, n_elem), dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_shifted_constraint_ranges(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.Tensor,
|
|
y: torch.Tensor,
|
|
):
|
|
torch._check(y.size(0) == x.size(0) + 1)
|
|
return x.sum(0) + y.sum(0)
|
|
|
|
a = torch.randn((4, 5), device=self.device)
|
|
b = torch.randn((5, 5), device=self.device)
|
|
dim0_x = Dim("dim0_x", min=2, max=1024)
|
|
dim0_y = dim0_x + 1
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
|
|
self.check_model(
|
|
Model(),
|
|
(a, b),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_scatter_fallback(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(
|
|
self,
|
|
inp: torch.Tensor,
|
|
index: torch.Tensor,
|
|
src: torch.Tensor,
|
|
):
|
|
return torch.scatter(inp, 1, index, src)
|
|
|
|
inputs = (
|
|
torch.ones((3, 5), device=self.device, dtype=torch.int64),
|
|
torch.tensor([[0, 1, 2, 0]], device=self.device, dtype=torch.int64),
|
|
torch.zeros((2, 5), device=self.device, dtype=torch.int64),
|
|
)
|
|
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_scatter_reduce_fallback(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(
|
|
self,
|
|
inp: torch.Tensor,
|
|
index: torch.Tensor,
|
|
src: torch.Tensor,
|
|
):
|
|
return torch.scatter_reduce(inp, 0, index, src, reduce="sum")
|
|
|
|
inputs = (
|
|
torch.tensor([1, 10, 100, 1000], device=self.device, dtype=torch.int64),
|
|
torch.tensor([0, 1, 0, 1, 2, 1], device=self.device, dtype=torch.int64),
|
|
torch.tensor([1, 2, 3, 4, 5, 6], device=self.device, dtype=torch.int64),
|
|
)
|
|
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_index_put_fallback(self):
|
|
# index_put falls back in the deterministic mode
|
|
with DeterministicGuard(True):
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(
|
|
self,
|
|
self_tensor: torch.Tensor,
|
|
indices: tuple[torch.Tensor],
|
|
values: torch.Tensor,
|
|
):
|
|
return torch.index_put(
|
|
self_tensor, indices, values, accumulate=True
|
|
)
|
|
|
|
inputs = (
|
|
torch.ones(4, device=self.device, dtype=torch.int64),
|
|
(torch.tensor([1, 1, 2, 2], device=self.device, dtype=torch.bool),),
|
|
torch.ones(4, device=self.device, dtype=torch.int64),
|
|
)
|
|
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_narrow_fallback(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, inp: torch.Tensor, dim: int, start: int, length: int):
|
|
return torch.ops.aten.narrow(inp, dim, start, length)
|
|
|
|
inputs = (torch.rand((3, 4), device=self.device), 0, 0, 2)
|
|
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_pad_fallback(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(
|
|
self,
|
|
inp: torch.Tensor,
|
|
pad: tuple[int, ...],
|
|
):
|
|
return torch.ops.aten.pad(inp, pad)
|
|
|
|
inputs = (torch.rand((3, 3, 4, 2), device=self.device), (0, 1, 2, 1, 3, 3))
|
|
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_fill__fallback(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, inp: torch.Tensor, scalar: float):
|
|
torch.ops.aten.fill_(inp, scalar)
|
|
return inp
|
|
|
|
inputs = (torch.rand((3, 3, 4, 2), device=self.device), 0.5)
|
|
self.check_model(Model(), inputs)
|
|
|
|
@common_utils.parametrize("embed_kernel_binary", [False, True])
|
|
def test_repeated_user_defined_triton_kernel(self, embed_kernel_binary):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
for _ in range(3):
|
|
mul2_inplace_kernel[4,](x, n_elements=4, BLOCK_SIZE=16)
|
|
return x
|
|
|
|
inputs = (torch.randn(4, 4, device=self.device),)
|
|
with config.patch({"aot_inductor.embed_kernel_binary": embed_kernel_binary}):
|
|
model = Model()
|
|
self.check_model(model, inputs)
|
|
_, code = run_and_get_cpp_code(AOTIRunnerUtil.compile, model, inputs)
|
|
FileCheck().check("launchKernel(").run(code)
|
|
if config.aot_inductor.embed_kernel_binary:
|
|
# Not expect to see launchKernel("CUBIN_FILE_NAME"
|
|
FileCheck().check_not('launchKernel("').run(code)
|
|
|
|
@unittest.skipIf(
|
|
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
|
)
|
|
def test_convolution(self):
|
|
if self.device == "cpu":
|
|
raise unittest.SkipTest("using triton backend only is not supported on CPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, w, b):
|
|
return torch.ops.aten.convolution(x, w, b, [4], [0], [1], True, [0], 1)
|
|
|
|
example_inputs = (
|
|
torch.randn([2, 32, 90], device=self.device),
|
|
torch.randn([32, 16, 8], device=self.device),
|
|
torch.randn([16], device=self.device),
|
|
)
|
|
with config.patch(
|
|
{
|
|
"max_autotune": True,
|
|
"max_autotune_gemm_backends": "Triton",
|
|
}
|
|
):
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_zero_size_weight(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, channel, r=8):
|
|
super().__init__()
|
|
self.pool = torch.nn.AdaptiveAvgPool2d(1)
|
|
self.net = torch.nn.Sequential(
|
|
torch.nn.Linear(channel, channel // r, bias=False),
|
|
torch.nn.ReLU(inplace=True),
|
|
torch.nn.Linear(channel // r, channel, bias=False),
|
|
torch.nn.Sigmoid(),
|
|
)
|
|
|
|
def forward(self, inp):
|
|
b, c, _, _ = inp.shape
|
|
x = self.pool(inp).view(b, c)
|
|
x = self.net(x).view(b, c, 1, 1)
|
|
x = inp * x
|
|
return x
|
|
|
|
inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
|
|
self.check_model(Model(4), inputs)
|
|
|
|
def test_zero_size_buffer(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.foo = torch.nn.Buffer(torch.zeros((0, 0), device=device))
|
|
|
|
def forward(self, x):
|
|
return x + 1, self.foo
|
|
|
|
example_inputs = (torch.rand(4, 4, device=self.device),)
|
|
self.check_model(Model(self.device), example_inputs)
|
|
|
|
def test_no_args(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, m, n):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(
|
|
torch.randn(m, n),
|
|
)
|
|
self.alpha = torch.nn.Parameter(torch.randn(m, n))
|
|
|
|
def forward(self):
|
|
return self.weight * self.alpha
|
|
|
|
self.check_model(Model(6, 4), ())
|
|
|
|
def test_dynamic_scalar(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.criterion_ce = torch.nn.CrossEntropyLoss(reduction="none")
|
|
|
|
def forward(self, inputs, targets, split_index=None):
|
|
statistics = {}
|
|
total_loss = self.criterion_ce(inputs, targets).sum()
|
|
statistics["dl"] = total_loss.item()
|
|
return total_loss, statistics
|
|
|
|
inputs = (
|
|
torch.rand(4, 4, 4, 4, device=self.device),
|
|
torch.rand(4, 4, 4, 4, device=self.device),
|
|
)
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_symint_item(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, tensor):
|
|
return tensor.item()
|
|
|
|
inputs = (torch.tensor([1], dtype=torch.int, device=self.device),)
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_symbool_item(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, tensor):
|
|
return tensor.item()
|
|
|
|
inputs = (torch.tensor([0], dtype=torch.bool, device=self.device),)
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_symfloat_item(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, tensor):
|
|
return tensor.item()
|
|
|
|
inputs = (torch.tensor([3.14], dtype=torch.float, device=self.device),)
|
|
self.check_model(Model(), inputs)
|
|
|
|
def test_constant_original_fqn_and_dtype(self):
|
|
class FooBarModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4)))
|
|
self.test_buf = torch.nn.Buffer(torch.randn(3, 4))
|
|
self.register_parameter(
|
|
"test_param", torch.nn.Parameter(torch.randn(3, 4))
|
|
)
|
|
|
|
def forward(self, x):
|
|
return ((x + self.test_buf) * getattr(self, "0")) / self.test_param
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.foo_bar = FooBarModule()
|
|
self.register_parameter(
|
|
"test_param", torch.nn.Parameter(torch.randn(3, 4))
|
|
)
|
|
self.test_buf = torch.nn.Buffer(torch.randn(3, 4))
|
|
|
|
def forward(self, x):
|
|
return (self.foo_bar(x) + self.test_param) * self.test_buf
|
|
|
|
with torch.no_grad():
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
model=TestModule().to(device=self.device),
|
|
example_inputs=(torch.rand(3, 4, device=self.device),),
|
|
)
|
|
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
|
|
|
|
expected_original_fqns = {
|
|
"L__self___test_param": "test_param",
|
|
"L__self___test_buf": "test_buf",
|
|
"getattr_L__self___foo_bar___0__": "foo_bar.0",
|
|
"L__self___foo_bar_test_param": "foo_bar.test_param",
|
|
"L__self___foo_bar_test_buf": "foo_bar.test_buf",
|
|
}
|
|
self.assertEqual(
|
|
expected_original_fqns, runner.get_constant_names_to_original_fqns()
|
|
)
|
|
|
|
expected_dtypes = {
|
|
"L__self___test_param": 6,
|
|
"L__self___test_buf": 6,
|
|
"getattr_L__self___foo_bar___0__": 6,
|
|
"L__self___foo_bar_test_param": 6,
|
|
"L__self___foo_bar_test_buf": 6,
|
|
}
|
|
self.assertEqual(expected_dtypes, runner.get_constant_names_to_dtypes())
|
|
|
|
def test_masked_select_dynamic(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
mask = x.ge(0.5)
|
|
return torch.masked_select(x, mask)
|
|
|
|
example_args = (torch.randn(3, 4, 5, device=self.device),)
|
|
dim0_x_max, dim1_x_max = 100, 7
|
|
dynamic_shapes = {
|
|
"x": {
|
|
0: Dim("dim0_x", max=dim0_x_max),
|
|
1: Dim("dim1_x_max", max=dim1_x_max),
|
|
}
|
|
}
|
|
m = M()
|
|
self.check_model(m, example_args, dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_proxy_executor_permute(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.ops.aten.permute.default(x, [0, 2, 1])
|
|
|
|
example_args = (torch.randn((1, 3001, 201), dtype=torch.complex64),)
|
|
m = M()
|
|
self.check_model(m, example_args)
|
|
|
|
def test_proxy_executor_abs(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.ops.aten.abs.default(x)
|
|
|
|
example_args = (torch.randn((1, 3001, 201), dtype=torch.complex64),)
|
|
m = M()
|
|
self.check_model(m, example_args)
|
|
|
|
def test_proxy_executor_squeeze(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.ops.aten.squeeze.dim(x, 0)
|
|
|
|
example_args = (torch.randn((1, 300, 201), dtype=torch.complex64),)
|
|
m = M()
|
|
self.check_model(m, example_args)
|
|
|
|
def test_proxy_executor_hann(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self):
|
|
return torch.ops.aten.hann_window.default(400)
|
|
|
|
example_args = ()
|
|
m = M()
|
|
self.check_model(m, example_args)
|
|
|
|
def test_fqn(self):
|
|
class NestedChild(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.nestedchild3buffer = torch.nn.Buffer(torch.ones(2, 3) * 3)
|
|
|
|
def forward(self, x):
|
|
return x / self.nestedchild3buffer
|
|
|
|
class Child1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.nested = NestedChild()
|
|
self.register_parameter(
|
|
"child1param", torch.nn.Parameter(torch.ones(2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.nested(x)
|
|
return x + self.child1param
|
|
|
|
class Child2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3) * 2)
|
|
|
|
def forward(self, x):
|
|
return x - self.child2buffer
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.foo = Child1()
|
|
self.bar = Child2()
|
|
self.register_parameter(
|
|
"rootparam", torch.nn.Parameter(torch.ones(2, 3) * 4)
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = x * self.rootparam
|
|
x = self.foo(x)
|
|
x = self.bar(x)
|
|
return x
|
|
|
|
self.check_model(MyModule(), (torch.randn(2, 3, device=self.device),))
|
|
|
|
def test_model_modified_weights(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M = 16
|
|
N = 10
|
|
K = 128
|
|
example_inputs = (torch.randn(2, M, K, device=self.device),)
|
|
model = Model(N, K, self.device)
|
|
self.check_model(model, example_inputs)
|
|
|
|
# Update model weights, after this AOTInductor should re-generate model.so
|
|
# if weights are stored in the model.so
|
|
model.weight += 1
|
|
self.check_model(model, example_inputs)
|
|
|
|
def test_triton_kernel_extern_kernel_arg(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
out = torch.zeros_like(x)
|
|
# torch.mm is ExternKernelOut
|
|
add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 4, device=GPU_TYPE),
|
|
torch.randn(4, 4, device=GPU_TYPE),
|
|
)
|
|
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_triton_kernel_multi_output_arg(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
out = torch.zeros_like(x)
|
|
# torch.sort creates fallback kernel and hence MultiOutput
|
|
add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 4, device=GPU_TYPE),
|
|
torch.randn(4, 4, device=GPU_TYPE),
|
|
)
|
|
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
# @skipIfXpu(msg="torch.xpu.memory_allocated not supported yet")
|
|
def test_triton_kernel_reinterpret_view_mem_leak(self):
|
|
# Check for memory leak when using user-defined Triton Kernel + AOTI.
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
out = torch.zeros_like(x)
|
|
yy = y * y
|
|
# reshape creates a ReinterpretView
|
|
add_kernel[(4,)](x, yy.reshape_as(x), out, 4, 16)
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 4, device=GPU_TYPE),
|
|
torch.randn(1, 16, device=GPU_TYPE),
|
|
)
|
|
|
|
package_path: str = AOTIRunnerUtil.compile(
|
|
Model(),
|
|
example_inputs,
|
|
)
|
|
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
|
|
# Don't assign outputs to a variable b/c it will allocate GPU memory.
|
|
device_interface = get_interface_for_device(GPU_TYPE)
|
|
device: int = device_interface.current_device()
|
|
mem_before = device_interface.memory_allocated(device)
|
|
aot_inductor_module(*example_inputs)
|
|
aot_inductor_module(*example_inputs)
|
|
mem_after = device_interface.memory_allocated(device)
|
|
self.assertEqual(mem_before, mem_after)
|
|
|
|
actual = aot_inductor_module(*example_inputs)
|
|
expected = Model()(*example_inputs)
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
@common_utils.parametrize("autotuning", [False, True])
|
|
def test_triton_kernel_unbacked_symint_in_grid(self, dynamic, autotuning):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y, n_elements_tensor):
|
|
output = torch.zeros_like(x)
|
|
n_elements_symint = n_elements_tensor.item()
|
|
n_elements = x.numel()
|
|
|
|
def grid(meta):
|
|
return (triton.cdiv(n_elements_symint, meta["BLOCK_SIZE"]),)
|
|
|
|
if autotuning:
|
|
add_kernel_autotuned[grid](
|
|
x,
|
|
y,
|
|
output,
|
|
n_elements,
|
|
)
|
|
else:
|
|
add_kernel[grid](
|
|
x,
|
|
y,
|
|
output,
|
|
n_elements,
|
|
BLOCK_SIZE=16,
|
|
)
|
|
|
|
return output
|
|
|
|
example_inputs = (
|
|
torch.randn(123, device=GPU_TYPE),
|
|
torch.randn(123, device=GPU_TYPE),
|
|
torch.tensor(123),
|
|
)
|
|
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dim0 = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"x": {0: dim0},
|
|
"y": {0: dim0},
|
|
"n_elements_tensor": {},
|
|
}
|
|
|
|
self.check_model(
|
|
Model(),
|
|
example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_scaled_dot_product_efficient_attention(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, q, k, v, attn_bias):
|
|
return torch.ops.aten._scaled_dot_product_efficient_attention(
|
|
q, k, v, attn_bias, False
|
|
)[0]
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
|
|
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
|
|
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
|
|
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_aoti_runtime_asserts(self):
|
|
from torch.export._draft_export import draft_export, FailureType
|
|
|
|
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
|
torch.library.define(
|
|
"mylib::foo",
|
|
"(Tensor a, Tensor b) -> Tensor",
|
|
tags=torch.Tag.pt2_compliant_tag,
|
|
lib=lib,
|
|
)
|
|
|
|
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
|
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
|
return a[: b.item()]
|
|
|
|
@torch.library.impl_abstract("mylib::foo", lib=lib)
|
|
def foo_fake_impl(a, b):
|
|
ctx = torch.library.get_ctx()
|
|
u = ctx.new_dynamic_size()
|
|
return torch.empty(u)
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, a, b):
|
|
res = torch.ops.mylib.foo(a, b)
|
|
s = res.shape[0]
|
|
torch._check(s > 3)
|
|
torch._check(s < a.shape[0])
|
|
return a[s - 3]
|
|
|
|
example_inputs = (torch.randn(100), torch.tensor(10))
|
|
ep = draft_export(M(), example_inputs)
|
|
report = ep._report
|
|
need_config_patch = any(
|
|
not f.xfail and f.failure_type == FailureType.MISMATCHED_FAKE_KERNEL
|
|
for f in report.failures
|
|
)
|
|
m = ep.module()
|
|
|
|
# This should no longer be needed after #150093
|
|
from torch._functorch import config as functorch_config
|
|
|
|
with functorch_config.patch(
|
|
{"generate_fake_kernels_from_real_mismatches": need_config_patch}
|
|
):
|
|
pt2_file = torch._inductor.aoti_compile_and_package(ep)
|
|
optimized = torch._inductor.aoti_load_package(pt2_file)
|
|
|
|
self.assertTrue(same(optimized(*example_inputs), m(*example_inputs)))
|
|
|
|
with self.assertRaisesRegex(Exception, "run_func_(.*) API call failed "):
|
|
optimized(torch.randn(100), torch.tensor(2))
|
|
|
|
@patch.dict(os.environ, {"TORCHINDUCTOR_SCALAR_ASSERTS_FULL": "1"})
|
|
def test_aoti_runtime_asserts_backed_symint(self):
|
|
if not full_aoti_runtime_assert():
|
|
raise unittest.SkipTest("full runtime assert not turned on")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
y = x.reshape(100, -1).clone()
|
|
y = y + 1
|
|
return y
|
|
|
|
model = Model().to(self.device)
|
|
input1 = (torch.rand(100, device=self.device),)
|
|
input2 = (torch.rand(2099, device=self.device),)
|
|
dynamic_shapes = {
|
|
"x": {0: torch.export.Dim.DYNAMIC},
|
|
}
|
|
package_path = AOTIRunnerUtil.compile(
|
|
model,
|
|
input1,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
optimized = torch._inductor.aoti_load_package(package_path)
|
|
self.assertEqual(model(*input1), optimized(*input1))
|
|
with self.assertRaisesRegex(Exception, "run_func_(.*) API call failed "):
|
|
optimized(*input2)
|
|
|
|
def test_index_put_with_none_index(self):
|
|
# index_put falls back in the deterministic mode
|
|
with DeterministicGuard(True):
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, i1, i2, y):
|
|
return torch.ops.aten.index_put(
|
|
x,
|
|
(None, None, i1, i2.transpose(0, 1)),
|
|
y,
|
|
accumulate=True,
|
|
)
|
|
|
|
example_inputs = (
|
|
torch.rand(8, 192, 30, 30, device=self.device),
|
|
torch.zeros(3, 14, 1, 1, dtype=torch.int64, device=self.device),
|
|
torch.ones(14, 3, dtype=torch.int64, device=self.device),
|
|
torch.randn(8, 192, 3, 14, 3, 14, device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
|
|
def test_runtime_checks(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
if SM80OrLater:
|
|
|
|
def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9):
|
|
return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9)
|
|
|
|
else:
|
|
|
|
def forward(self, x0, x1, x2, x4, x5, x6, x7, x8, x9):
|
|
return (x0, x1, x2, x4, x5, x6, x7, x8, x9)
|
|
|
|
inputs = []
|
|
dtypes = [
|
|
torch.float16,
|
|
torch.float32,
|
|
torch.float64,
|
|
torch.bool,
|
|
torch.int8,
|
|
torch.int16,
|
|
torch.int32,
|
|
torch.int64,
|
|
torch.uint8,
|
|
]
|
|
if SM80OrLater:
|
|
dtypes.append(torch.bfloat16)
|
|
for dtype in dtypes:
|
|
inputs.append(torch.ones(4, 8, 10, dtype=dtype, device=self.device))
|
|
|
|
dim0 = Dim("s0", min=2, max=1024)
|
|
dim1 = Dim("s1", min=2, max=512)
|
|
dim2 = Dim("s2", min=2, max=128)
|
|
dynamic_shapes = {
|
|
"x0": {0: dim0},
|
|
"x1": {0: dim0},
|
|
"x2": {0: dim0},
|
|
"x4": {1: dim1},
|
|
"x5": {1: dim1},
|
|
"x6": {},
|
|
"x7": {2: dim2},
|
|
"x8": {2: dim2},
|
|
"x9": {2: dim2},
|
|
}
|
|
if SM80OrLater:
|
|
dynamic_shapes["x3"] = {1: dim1}
|
|
|
|
m = Model()
|
|
inputs = tuple(inputs)
|
|
with torch.no_grad():
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
m, inputs, dynamic_shapes=dynamic_shapes
|
|
)
|
|
with open(os.path.splitext(so_path)[0] + ".cpp") as cpp:
|
|
src_code = cpp.read()
|
|
FileCheck().check_count(
|
|
"unmatched dtype",
|
|
10 if SM80OrLater else 9,
|
|
exactly=True,
|
|
).run(src_code)
|
|
FileCheck().check_count(
|
|
"unmatched dim value at",
|
|
21
|
|
if SM80OrLater
|
|
else 19, # we have 9 dynamic dims for which we generate different checks
|
|
exactly=True,
|
|
).run(src_code)
|
|
FileCheck().check_count(
|
|
"dim value is too",
|
|
18
|
|
if SM80OrLater
|
|
else 16, # we have 9 dynamic dims for which we generate two checks
|
|
exactly=True,
|
|
).run(src_code)
|
|
FileCheck().check_count(
|
|
"unmatched stride value at",
|
|
21
|
|
if SM80OrLater
|
|
else 19, # we have 9 symbolic strides for which we don't generate checks
|
|
exactly=True,
|
|
).run(src_code)
|
|
|
|
self.check_model(m, inputs)
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FP8,
|
|
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
|
|
)
|
|
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
|
|
def test_runtime_checks_fp8(self):
|
|
# cuda only
|
|
if self.device != "cuda":
|
|
return
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x0, x1):
|
|
t = x0.to(torch.float) + x1.to(torch.float)
|
|
return t
|
|
|
|
inputs = []
|
|
for dtype in (
|
|
torch.float8_e4m3fn,
|
|
torch.float8_e5m2,
|
|
# FP8 funz are for AMD
|
|
# see https://github.com/pytorch/pytorch/issues/126734
|
|
# torch.float8_e4m3fnuz,
|
|
# torch.float8_e5m2fnuz,
|
|
):
|
|
inputs.append(torch.ones(8, 8, 8, dtype=dtype, device=self.device))
|
|
dim0 = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"x0": {0: dim0},
|
|
"x1": {0: dim0},
|
|
}
|
|
with torch.no_grad():
|
|
self.check_model(
|
|
Model(),
|
|
tuple(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@skipIfXpu(msg="Total size of kernel arguments exceeds driver limit on XPU")
|
|
def test_runtime_checks_large(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, *inputs):
|
|
result = inputs[0]
|
|
for i in range(1, len(inputs)):
|
|
result = result + inputs[i]
|
|
return result
|
|
|
|
inputs = []
|
|
for i in range(1000):
|
|
inputs.append(torch.ones(8, 8, 8, dtype=torch.float16, device=self.device))
|
|
inputs = tuple(inputs)
|
|
model = Model()
|
|
with torch.no_grad():
|
|
AOTIRunnerUtil.compile(
|
|
model,
|
|
inputs,
|
|
)
|
|
|
|
def test_runtime_checks_complex(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x0, x1, x2):
|
|
return (x0, x1, x2)
|
|
|
|
inputs = []
|
|
x0 = torch.tensor([1, -1], dtype=torch.complex32, device=self.device)
|
|
x1 = torch.tensor(
|
|
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1],
|
|
dtype=torch.complex64,
|
|
device=self.device,
|
|
)
|
|
x2 = torch.tensor(128, dtype=torch.complex128, device=self.device)
|
|
inputs.append(x0)
|
|
inputs.append(x1)
|
|
inputs.append(x2)
|
|
dim0 = Dim("s0", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"x0": {0: dim0},
|
|
"x1": {},
|
|
"x2": {},
|
|
}
|
|
with torch.no_grad():
|
|
self.check_model(
|
|
Model(),
|
|
tuple(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
|
|
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
|
|
def test_runtime_checks_dtype_failed(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
y = x.type(torch.float)
|
|
return y
|
|
|
|
x = torch.randn(1, 4, dtype=torch.float16, device=self.device)
|
|
model = Model()
|
|
with torch.no_grad():
|
|
package_path: str = AOTIRunnerUtil.compile(
|
|
model,
|
|
(x,),
|
|
)
|
|
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
|
|
x_casted = x.float()
|
|
with self.assertRaisesRegex(Exception, ""):
|
|
aot_inductor_module(x_casted)
|
|
|
|
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
|
|
def test_runtime_checks_device_type_failed(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x + 1
|
|
|
|
x = torch.randn(1, 4, dtype=torch.float16, device="cpu")
|
|
model = Model()
|
|
with torch.no_grad():
|
|
package_path: str = AOTIRunnerUtil.compile(
|
|
model,
|
|
(x,),
|
|
)
|
|
|
|
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
|
|
aot_inductor_module(x)
|
|
x_casted = x.to(GPU_TYPE)
|
|
with self.assertRaisesRegex(Exception, ""):
|
|
aot_inductor_module(x_casted)
|
|
|
|
def test_non_contiguous_output_alias(self):
|
|
# Test return x, x.contiguous() where x is non-contiguous.
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
squared = x * x
|
|
transposed = squared.t() # non-contiguous
|
|
contig = transposed.contiguous()
|
|
return transposed, contig
|
|
|
|
x = torch.randn(3, 4, dtype=torch.float16, device=self.device)
|
|
model = Model()
|
|
with torch.no_grad():
|
|
result = AOTIRunnerUtil.run(
|
|
model,
|
|
(x,),
|
|
)
|
|
actual = model(x)
|
|
self.assertTrue(same(result, actual))
|
|
|
|
# contiguous() should create a new tensor
|
|
self.assertTrue(result[0].data_ptr() != result[1].data_ptr())
|
|
|
|
def test_multiple_output_alias(self):
|
|
# Test when mutliple outputs alias the same tensor
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
squared = x * x
|
|
contig = squared.contiguous() # alias
|
|
reshaped = squared.reshape(squared.shape) # alias
|
|
cubed = squared * x
|
|
return squared, contig, reshaped, cubed
|
|
|
|
x = torch.randn(3, 4, dtype=torch.float32, device=self.device)
|
|
model = Model()
|
|
|
|
with torch.no_grad():
|
|
result = AOTIRunnerUtil.run(
|
|
model,
|
|
(x,),
|
|
)
|
|
actual = model(x)
|
|
self.assertTrue(same(result, actual))
|
|
|
|
# squared, contig and reshaped alias the same tensor.
|
|
self.assertTrue(result[0].data_ptr() == result[1].data_ptr())
|
|
self.assertTrue(result[0].data_ptr() == result[2].data_ptr())
|
|
# cubed shouldn't be an alias.
|
|
self.assertTrue(result[0].data_ptr() != result[3].data_ptr())
|
|
|
|
@patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"})
|
|
def test_runtime_checks_shape_failed(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
x = torch.randn(4, 4, 4, dtype=torch.float16, device=self.device)
|
|
y0 = torch.randn(8, 4, 4, dtype=torch.float16, device=self.device)
|
|
y1 = torch.randn(4, 8, 4, dtype=torch.float16, device=self.device)
|
|
y2 = rand_strided(
|
|
(4, 4, 4), (16, 1, 4), dtype=torch.float16, device=self.device
|
|
)
|
|
# batch size is outside of the range
|
|
y3 = torch.randn(2048, 3, 4, dtype=torch.float16, device=self.device)
|
|
y4 = torch.randn(2048, 4, 4, dtype=torch.float16, device=self.device)
|
|
dim0 = Dim("s0", min=4, max=1024)
|
|
dynamic_shapes = {
|
|
"x": {0: dim0},
|
|
}
|
|
model = Model()
|
|
with torch.no_grad():
|
|
package_path: str = AOTIRunnerUtil.compile(
|
|
model, (x,), dynamic_shapes=dynamic_shapes
|
|
)
|
|
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
|
|
# dynamic dim works fine
|
|
_ = aot_inductor_module(y0)
|
|
with self.assertRaisesRegex(Exception, ""):
|
|
aot_inductor_module(y1)
|
|
with self.assertRaisesRegex(Exception, ""):
|
|
aot_inductor_module(y2)
|
|
with self.assertRaisesRegex(Exception, ""):
|
|
aot_inductor_module(y3)
|
|
with self.assertRaisesRegex(Exception, ""):
|
|
aot_inductor_module(y4)
|
|
|
|
def test_add_complex(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, a, b):
|
|
return torch.add(a, b)
|
|
|
|
x = torch.tensor(
|
|
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device
|
|
)
|
|
y = torch.tensor(
|
|
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device
|
|
)
|
|
self.check_model(Model(), (x, y))
|
|
|
|
def test_embedding_bag(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, w, i, o):
|
|
return torch.ops.aten._embedding_bag(w, i, o, False, 0, False, None)
|
|
|
|
example_inputs = (
|
|
torch.randn([10, 4], device=self.device),
|
|
torch.randint(10, [8], device=self.device),
|
|
torch.tensor([0, 2, 6], device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_fft_c2c(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.fft.fftn(x), torch.fft.fftn(x).real
|
|
|
|
example_inputs = (torch.randn(16, 16, 16, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_bool_input(self):
|
|
# Specialize on whichever branch the example input for b is
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, b):
|
|
if b:
|
|
return x * x
|
|
else:
|
|
return x + x
|
|
|
|
example_inputs = (torch.randn(3, 3, device=self.device), True)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_int_list_input(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, i):
|
|
return x * i[0] * i[1]
|
|
|
|
example_inputs = (torch.randn(3, 3, device=self.device), [3, 4])
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_nested_tensor_from_jagged(self):
|
|
class Model(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mlp = nn.Sequential(
|
|
nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid()
|
|
)
|
|
|
|
def forward(self, values, offsets):
|
|
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
|
|
res = self.mlp(nt)
|
|
return res.values()
|
|
|
|
model = Model().to(device=self.device)
|
|
|
|
example_inputs_1 = (
|
|
torch.randn((15, 128), device=self.device),
|
|
torch.tensor([0, 3, 4, 10, 15], device=self.device),
|
|
)
|
|
|
|
# same "NT batch size", different actual amount of data
|
|
example_inputs_2 = (
|
|
torch.randn((31, 128), device=self.device),
|
|
torch.tensor([0, 1, 20, 25, 31], device=self.device),
|
|
)
|
|
|
|
# same actual amount of data, different "NT batch size"
|
|
example_inputs_3 = (
|
|
torch.randn((15, 128), device=self.device),
|
|
torch.tensor([0, 3, 10, 15], device=self.device),
|
|
)
|
|
|
|
# different "NT batch size"
|
|
example_inputs_4 = (
|
|
torch.randn((37, 128), device=self.device),
|
|
torch.tensor([0, 5, 16, 25, 29, 37], device=self.device),
|
|
)
|
|
|
|
dim0_values = Dim("dim0_values", min=1, max=128)
|
|
dim0_offsets = Dim("dim0_offsets", min=1, max=9)
|
|
dynamic_shapes = {"values": {0: dim0_values}, "offsets": {0: dim0_offsets}}
|
|
example_inputs_list = [
|
|
example_inputs_1,
|
|
example_inputs_2,
|
|
example_inputs_3,
|
|
example_inputs_4,
|
|
]
|
|
for example_input in example_inputs_list:
|
|
actual = AOTIRunnerUtil.legacy_run(
|
|
self.device,
|
|
model,
|
|
example_input,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
self.assertTrue(same(model(*example_input), actual))
|
|
|
|
@common_utils.parametrize("max_autotune", [True, False])
|
|
def test_misc_1(self, max_autotune):
|
|
if self.device == "cpu" and IS_MACOS and max_autotune:
|
|
raise unittest.SkipTest("max_autotune not supported on macos")
|
|
|
|
class Model(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mlp = nn.Sequential(
|
|
nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid()
|
|
)
|
|
self.emb = nn.EmbeddingBag(num_embeddings=128, embedding_dim=32)
|
|
self.over_arch = nn.Sequential(
|
|
nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32), nn.Sigmoid()
|
|
)
|
|
|
|
def forward(self, x, y):
|
|
mlp_output = self.mlp(x)
|
|
emb_output = self.emb(y)
|
|
return self.over_arch(torch.concat([mlp_output, emb_output], dim=1))
|
|
|
|
example_inputs = (
|
|
torch.randn(16, 128, device=self.device),
|
|
torch.randint(0, 128, (16, 10), device=self.device),
|
|
)
|
|
self.check_model(
|
|
Model(), example_inputs, options=dict(max_autotune=max_autotune)
|
|
)
|
|
|
|
@skip_if_no_torchvision
|
|
def test_torchvision_transforms_functional_tensor_resize(self):
|
|
import torchvision
|
|
|
|
# https://fb.workplace.com/groups/1075192433118967/permalink/1501860707118802/
|
|
class A(torch.nn.Module):
|
|
def forward(self, image: torch.Tensor, target_size: torch.Tensor):
|
|
target_h, target_w = target_size.tolist()
|
|
torch._check(target_h > 0)
|
|
torch._check(target_w > 0)
|
|
torch._check(target_h <= 4000)
|
|
torch._check(target_w <= 4000)
|
|
|
|
return torchvision.transforms._functional_tensor.resize(
|
|
image,
|
|
size=[target_h, target_w],
|
|
interpolation="bilinear",
|
|
antialias=False,
|
|
)
|
|
|
|
model = A()
|
|
example_inputs = (
|
|
torch.ones([3, 800, 600], device=self.device),
|
|
torch.tensor([448, 336], device=self.device),
|
|
)
|
|
dynamic_shapes = {
|
|
"image": {
|
|
1: torch.export.Dim("height", min=1, max=4000),
|
|
2: torch.export.Dim("width", min=1, max=4000),
|
|
},
|
|
"target_size": None,
|
|
}
|
|
self.check_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_aoti_debug_printer_codegen(self):
|
|
# basic addmm model to test codegen for aoti intermediate debug printer
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M = 8
|
|
N = 6
|
|
K = 16
|
|
model = Model(N, K, self.device)
|
|
batch = 2
|
|
a = torch.randn(batch, M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
|
|
kernel_calls = (
|
|
[
|
|
("triton_poi_fused_0", 1),
|
|
(f"aoti_torch_{GPU_TYPE}_addmm_out", 2),
|
|
]
|
|
if self.device == GPU_TYPE
|
|
else [
|
|
("aoti_torch_cpu_addmm_out", 2),
|
|
]
|
|
)
|
|
|
|
# test default debug printing all tensor values codegen
|
|
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
|
|
result, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.legacy_compile, model, example_inputs
|
|
)
|
|
|
|
# check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected
|
|
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
|
|
|
|
# check the codegen for debug printing around the actual kernel call is expected
|
|
|
|
for kernel_call, count in kernel_calls:
|
|
FileCheck().check_count(
|
|
f"before_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
FileCheck().check_count(
|
|
f"after_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
|
|
# test printing selected kernel's tensor values codegen
|
|
filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out"
|
|
with config.patch(
|
|
{
|
|
"aot_inductor.debug_intermediate_value_printer": "2",
|
|
"aot_inductor.filtered_kernel_names": filtered_kernel_name,
|
|
}
|
|
):
|
|
result, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.legacy_compile, model, example_inputs
|
|
)
|
|
filtered_kernel_calls = [
|
|
(filtered_kernel_name, 2),
|
|
]
|
|
for kernel_call, count in filtered_kernel_calls:
|
|
FileCheck().check_count(
|
|
f"before_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
FileCheck().check_count(
|
|
f"after_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
|
|
kernel_calls_not_to_print = [
|
|
kernel_call
|
|
for kernel_call in kernel_calls
|
|
if kernel_call[0] != filtered_kernel_name
|
|
]
|
|
for kernel_name, _ in kernel_calls_not_to_print:
|
|
FileCheck().check_not(f"before_launch - {kernel_name}").run(code)
|
|
FileCheck().check_not(f"after_launch - {kernel_name}").run(code)
|
|
|
|
@common_utils.parametrize("enable_kernel_profile", (True, False))
|
|
def test_aoti_profiler(self, enable_kernel_profile):
|
|
# basic addmm model
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
if sys.platform not in ["linux", "win32"]:
|
|
raise unittest.SkipTest(
|
|
"enable_kernel_profile only supported on linux and win32"
|
|
)
|
|
|
|
M = 8
|
|
N = 6
|
|
K = 16
|
|
model = Model(N, K, self.device)
|
|
batch = 2
|
|
a = torch.randn(batch, M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
kernel_calls = (
|
|
f"aoti_torch_{GPU_TYPE}_addmm_out"
|
|
if self.device == GPU_TYPE
|
|
else "aoti_torch_cpu_addmm_out"
|
|
)
|
|
with config.patch({"cpp.enable_kernel_profile": enable_kernel_profile}):
|
|
_, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.compile, model, example_inputs
|
|
)
|
|
shim_fn_codes = (
|
|
f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef<c10::IValue>());'
|
|
)
|
|
if enable_kernel_profile:
|
|
FileCheck().check(shim_fn_codes).run(code)
|
|
else:
|
|
FileCheck().check_not(shim_fn_codes).run(code)
|
|
|
|
def test_aoti_debug_printer_user_defined_triton_kernel(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
out = torch.zeros_like(x)
|
|
add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16)
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 4, device=self.device),
|
|
torch.randn(4, 4, device=self.device),
|
|
)
|
|
|
|
kernel_calls = [
|
|
("add_kernel_0", 3),
|
|
]
|
|
|
|
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
|
|
result, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.compile, Model(), example_inputs
|
|
)
|
|
# check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected
|
|
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
|
|
# check the codegen for debug printing around the actual kernel call is expected
|
|
for kernel_call, count in kernel_calls:
|
|
FileCheck().check_count(
|
|
f"before_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
FileCheck().check_count(
|
|
f"after_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
|
|
def test_aoti_debug_printer_cpp_kernel(self):
|
|
if self.device != "cpu":
|
|
raise unittest.SkipTest("cpu test case only")
|
|
|
|
# a simple cpp kernel test case for testing the debug printer codegen
|
|
# on cpp kernel cpu device.
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
t = torch.tensor(x.size(-1), device="cpu", dtype=torch.float)
|
|
t = torch.sqrt(t * 3)
|
|
return x * t
|
|
|
|
example_inputs = (torch.randn(4, 4, device="cpu"),)
|
|
|
|
kernel_calls = [
|
|
("cpp_fused_mul_sqrt_0", 2),
|
|
]
|
|
|
|
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
|
|
result, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.compile, Model(), example_inputs
|
|
)
|
|
# check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected
|
|
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
|
|
# check the codegen for debug printing around the actual kernel call is expected
|
|
for kernel_call, count in kernel_calls:
|
|
FileCheck().check_count(
|
|
f"before_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
FileCheck().check_count(
|
|
f"after_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
|
|
def test_aoti_debug_printer_sym_inputs(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
from torch.testing._internal.triton_utils import add_kernel
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
maxlen = max(x.item(), 512)
|
|
a = torch.ones(maxlen, device=GPU_TYPE)
|
|
b = torch.ones(maxlen, device=GPU_TYPE)
|
|
out = torch.zeros_like(a)
|
|
# unbacked symint in grid
|
|
add_kernel[(1, 1, maxlen)](a, b, out, maxlen, 32)
|
|
return out
|
|
|
|
example_inputs = (torch.randint(high=1024, size=(1,), device=self.device),)
|
|
|
|
expected_scalar_args = [
|
|
"triton_poi_fused_zeros_like_0_xnumel",
|
|
"triton_poi_fused_1_xnumel",
|
|
"std::max(static_cast<int64_t>(512L), static_cast<int64_t>(u0))",
|
|
]
|
|
|
|
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
|
|
result, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.compile, Model(), example_inputs
|
|
)
|
|
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
|
|
for scalar in expected_scalar_args:
|
|
FileCheck().check_count(
|
|
f"{scalar}",
|
|
2,
|
|
).run(code)
|
|
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FP8,
|
|
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
|
|
)
|
|
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
|
@skipIfXpu
|
|
def test_aoti_debug_printer_fp8_dtype(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, dtype):
|
|
super().__init__()
|
|
self.out_dtype = dtype
|
|
|
|
def forward(self, x, weight, bias, scale_a, scale_b):
|
|
weight = weight.to(torch.float8_e4m3fn)
|
|
output = torch._scaled_mm(
|
|
x,
|
|
weight,
|
|
bias=input_bias,
|
|
out_dtype=self.out_dtype,
|
|
scale_a=scale_a,
|
|
scale_b=scale_b,
|
|
)
|
|
return output
|
|
|
|
dtype = torch.float16
|
|
|
|
a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
|
b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
|
input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
|
|
weight_shape = (32, 16)
|
|
weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
|
|
a_inverse_scale = 1 / a_scale
|
|
b_inverse_scale = 1 / b_scale
|
|
|
|
x_shape = (16, 16)
|
|
x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn)
|
|
|
|
kernel_calls = [
|
|
(f"aoti_torch_{GPU_TYPE}__scaled_mm_out", 5),
|
|
]
|
|
|
|
# test default debug printing all tensor values codegen
|
|
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
|
|
result, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.legacy_compile,
|
|
Model(dtype),
|
|
(x, weight, input_bias, a_inverse_scale, b_inverse_scale),
|
|
)
|
|
|
|
# check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected
|
|
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
|
|
|
|
# check the codegen for debug printing around the actual kernel call is expected and float8 dtype is printed as expected
|
|
for kernel_call, count in kernel_calls:
|
|
FileCheck().check_count(
|
|
f"before_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
FileCheck().check_count(
|
|
f"after_launch - {kernel_call}",
|
|
count,
|
|
).run(code)
|
|
|
|
def test_aoti_debug_printing_model_inputs_codegen(self):
|
|
if self.device != "cuda":
|
|
raise unittest.SkipTest("requires CUDA")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, a, b, c):
|
|
x = a * 3.14
|
|
y = torch.addmm(c, x, b)
|
|
z = torch.nn.functional.gelu(y)
|
|
return z
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 20, device="cuda"),
|
|
torch.randn(20, 30, device="cuda"),
|
|
torch.randn(10, 30, device="cuda"),
|
|
)
|
|
model = Model()
|
|
kernel_calls = [
|
|
("aoti_model_inputs", 3),
|
|
]
|
|
|
|
with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
|
|
result, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.compile, model, example_inputs
|
|
)
|
|
self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
|
|
|
|
# check if the triton kernel is printed as comment
|
|
self.assertEqual("def triton_" in code, True)
|
|
|
|
# check the codegen for debug printing around aoti model inputs is expected
|
|
for kernel_call, count in kernel_calls:
|
|
FileCheck().check_count(
|
|
f"{kernel_call}",
|
|
count,
|
|
).run(code)
|
|
|
|
def test_size_from_multi_output(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
_x, _i = torch.unique(x, sorted=True, return_inverse=True)
|
|
_x = _x.detach().clone()
|
|
return self.relu(_x), _i
|
|
|
|
example_inputs = (torch.randn(8, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@dynamo_config.patch({"capture_scalar_outputs": True})
|
|
def test_sym_i64_input_codegen(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
from torch.testing._internal.triton_utils import add_kernel
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
x_symint = x.item()
|
|
a = torch.ones(x_symint, device=GPU_TYPE)
|
|
b = torch.ones(x_symint, device=GPU_TYPE)
|
|
out = torch.zeros_like(a)
|
|
# unbacked symint in grid
|
|
add_kernel[(1, 1, x_symint)](a, b, out, x_symint, 32)
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randint(high=1024, size=(1,), device=self.device, dtype=torch.int32),
|
|
)
|
|
# This simple unit test case model generates two triton kernels:
|
|
# 1. triton_poi_fused_ones_1:
|
|
# triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i64'}
|
|
# 2. add_kernel:
|
|
# triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr': '*fp32', 'n_elements': 'i64'}
|
|
# input u0 was defined as int32_t initially, verify for every kernel var args downstream,
|
|
# it gets explicitly declared using its data types in the cpp wrapper codegen code.
|
|
expected_scalar_args = [
|
|
"buf3, u0",
|
|
"buf4, u0",
|
|
"buf4, buf5, buf3, u0",
|
|
]
|
|
if full_aoti_runtime_assert():
|
|
# we'll have one more assertion
|
|
expected_scalar_args = [
|
|
"buf4, u0",
|
|
"buf5, u0",
|
|
"buf5, buf6, buf4, u0",
|
|
]
|
|
# check the new behavior of codegen is expected
|
|
result, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.compile, Model(), example_inputs
|
|
)
|
|
for scalar_line in expected_scalar_args:
|
|
FileCheck().check_count(
|
|
scalar_line,
|
|
1,
|
|
).run(code)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_input_codegen_with_sympy_expr(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class MyModel(torch.nn.Module):
|
|
def forward(self, getitem_54, getitem_52, getitem_19, values_2, offsets):
|
|
bitwise_or = torch.bitwise_or(getitem_54, getitem_52)
|
|
combined = torch.cat([getitem_19, values_2], dim=0)
|
|
add = combined + bitwise_or
|
|
|
|
sliced = values_2[:-1] + offsets
|
|
return add, sliced
|
|
|
|
inps = (
|
|
torch.randint(0, 1, (240,), device=GPU_TYPE, dtype=torch.uint8),
|
|
torch.randint(0, 1, (240,), device=GPU_TYPE, dtype=torch.uint8),
|
|
torch.randn((192,), device=GPU_TYPE),
|
|
torch.randn((48,), device=GPU_TYPE),
|
|
torch.randint(0, 100, (47,), device=GPU_TYPE, dtype=torch.uint8),
|
|
)
|
|
|
|
dim = torch.export.Dim("dimensionality")
|
|
derived_dim = 2 * dim
|
|
spec = {
|
|
"getitem_54": (Dim.AUTO,), # [s33 + 2*s40 + 1]
|
|
"getitem_52": (Dim.AUTO,), # [s33 + 2*s40 + 1]
|
|
"getitem_19": (derived_dim,), # [2*s40]
|
|
"values_2": (Dim.AUTO,), # [s33 + 1]
|
|
"offsets": (Dim.AUTO,), # [s33]
|
|
}
|
|
|
|
self.check_model(MyModel(), inps, dynamic_shapes=spec)
|
|
|
|
@common_utils.parametrize("mark_unbacked", (True, False))
|
|
def test_unbacked_equals_input_size_runtime_assertion(self, mark_unbacked: bool):
|
|
# This test checks the unbacked symint runtime assertions, for the following cases:
|
|
# (A) an unbacked symint equals an unbacked symint (mark_unbacked=True)
|
|
# (B) an unbacked symint equals a backed symint (mark_unbacked=False)
|
|
class Model(torch.nn.Module):
|
|
def forward(self, a, b, c):
|
|
nz = torch.nonzero(a)
|
|
ones = a.new_ones([nz.size(0), b.size(0)])
|
|
torch._check(ones.size(0) >= 1)
|
|
equals = torch.add(ones, c)
|
|
return equals
|
|
|
|
model = Model()
|
|
example_inputs = (
|
|
torch.ones(64, device=self.device),
|
|
b := torch.randn((32,), device=self.device),
|
|
c := torch.randn((64, 32), device=self.device),
|
|
)
|
|
if mark_unbacked:
|
|
torch._dynamo.decorators.mark_unbacked(c, 0)
|
|
else:
|
|
torch._dynamo.mark_dynamic(c, 0)
|
|
|
|
# Check the runtime assertion is codegen'ed.
|
|
so_path, code = run_and_get_cpp_code(
|
|
AOTIRunnerUtil.legacy_compile, model, example_inputs
|
|
)
|
|
lowerbound_check = "u1 >= 1" if mark_unbacked else "u0 >= 2"
|
|
FileCheck().check_count(lowerbound_check, 1).run(code)
|
|
|
|
compiled = AOTIRunnerUtil.legacy_load(self.device, so_path)
|
|
compiled(*example_inputs)
|
|
|
|
# Check the runtime assertion.
|
|
with self.assertRaisesRegex(Exception, ""):
|
|
unexpected_inputs = (torch.ones(0, device=self.device), b, c)
|
|
compiled(*unexpected_inputs)
|
|
|
|
# Try it again without runtime assertions.
|
|
with config.patch({"scalar_asserts": False}):
|
|
AOTIRunnerUtil.run_multiple(model, [example_inputs, unexpected_inputs])
|
|
|
|
def test_none_args_aot_codegen(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
@triton.autotune(
|
|
configs=[
|
|
triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
|
|
triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
|
|
],
|
|
key=["n_elements"],
|
|
)
|
|
@triton.jit
|
|
def sin_kernel(
|
|
in_ptr0,
|
|
out_ptr,
|
|
# We want to include an arg known to be 1 at compile time
|
|
# This is because we remove None args from the arg list; changing the eq_1/constexpr arg indices.
|
|
# We want to make sure we recompute these correctly
|
|
EQ_1_ARG,
|
|
n_elements,
|
|
BLOCK_SIZE: "tl.constexpr",
|
|
):
|
|
pid = tl.program_id(axis=0)
|
|
block_start = pid * BLOCK_SIZE
|
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
|
mask = offsets < n_elements
|
|
if in_ptr0 is not None:
|
|
x = tl.load(in_ptr0 + offsets, mask=mask)
|
|
else:
|
|
x = 0.0
|
|
output = tl.sin(x) + EQ_1_ARG
|
|
tl.store(out_ptr + offsets, output, mask=mask)
|
|
|
|
def sin_triton(x, out):
|
|
n_elements = out.numel()
|
|
sin_kernel[(n_elements,)](x, out, 1, n_elements)
|
|
return out
|
|
|
|
x = torch.randn(65, device=self.device)
|
|
out = torch.empty_like(x)
|
|
|
|
not_none_inputs = (x, out)
|
|
none_inputs = (None, out)
|
|
|
|
# AOTI compilation specializes on either None or non-None inputs
|
|
# So we have to check twice here
|
|
|
|
self.check_model(sin_triton, none_inputs)
|
|
self.check_model(sin_triton, not_none_inputs)
|
|
|
|
def test_issue_140766(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mlp = torch.nn.Sequential(
|
|
torch.nn.Linear(128, 512),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(512, 128),
|
|
)
|
|
self.norm = torch.nn.LayerNorm(128)
|
|
self.attn = torch.nn.functional.scaled_dot_product_attention
|
|
|
|
def forward(self, x):
|
|
# [2, 128, 4096]
|
|
x = x.transpose(1, 2)
|
|
# [2, 4096, 128]
|
|
for _ in range(2):
|
|
x = self.forward_block(x)
|
|
return x
|
|
|
|
def forward_block(self, x):
|
|
# x: B, H*W, C
|
|
B = x.shape[0]
|
|
H, W, C = 64, 64, 128
|
|
shortcut = x
|
|
x = self.norm(x)
|
|
x = x.reshape(B, H, W, C)
|
|
# B, H, W, C
|
|
x = self.attn(x, x, x)
|
|
x = x.reshape(B, H // 8, W // 8, 8, 8, -1)
|
|
x = x.transpose(2, 3).reshape(B, H * W, -1)
|
|
|
|
x = shortcut + x
|
|
x = x + self.mlp(self.norm(x))
|
|
return x
|
|
|
|
bs = torch.export.Dim("bs", max=12)
|
|
example_inputs = (torch.randn(2, 128, 4096, device=self.device),)
|
|
self.check_model(Model(), example_inputs, dynamic_shapes={"x": {0: bs}})
|
|
|
|
def test_so_without_weight(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M, N, K = 128, 2048, 4096
|
|
model = Model(N, K, self.device)
|
|
a = torch.randn(M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
with torch.no_grad(), config.patch(
|
|
{
|
|
"always_keep_tensor_constants": True,
|
|
"aot_inductor.package_constants_in_so": True,
|
|
}
|
|
):
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
|
|
with torch.no_grad(), config.patch(
|
|
{
|
|
"always_keep_tensor_constants": True,
|
|
"aot_inductor.package_constants_in_so": False,
|
|
}
|
|
):
|
|
so_path_weightless = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
self.assertTrue(os.path.getsize(so_path) > 10_000_000)
|
|
self.assertTrue(os.path.getsize(so_path_weightless) < 10_000_000)
|
|
|
|
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path_weightless)
|
|
|
|
# Let's check whether the model has correct constant name mapping.
|
|
expected_original_fqns = {
|
|
"L__self___weight": "L__self___weight",
|
|
"L__self___bias": "L__self___bias",
|
|
}
|
|
self.assertEqual(
|
|
expected_original_fqns, runner.get_constant_names_to_original_fqns()
|
|
)
|
|
|
|
def runner_call(*args, **kwargs):
|
|
import torch.fx._pytree as fx_pytree
|
|
|
|
call_spec = runner.get_call_spec()
|
|
in_spec = pytree.treespec_loads(call_spec[0])
|
|
out_spec = pytree.treespec_loads(call_spec[1])
|
|
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
|
|
flat_outputs = runner.run(flat_inputs)
|
|
return pytree.tree_unflatten(flat_outputs, out_spec)
|
|
|
|
test_inputs = torch.randn(M, K, device=self.device)
|
|
attach_weights = {
|
|
"L__self___weight": model.weight,
|
|
"L__self___bias": model.bias,
|
|
}
|
|
runner.update_constant_buffer(attach_weights, False, False)
|
|
expected = model(test_inputs)
|
|
output = runner_call(test_inputs)
|
|
self.assertEqual(expected, output)
|
|
|
|
def test_weight_on_disk_legacy(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M, N, K = 128, 2048, 4096
|
|
model = Model(N, K, self.device)
|
|
a = torch.randn(M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
|
|
with torch.no_grad(), config.patch(
|
|
{
|
|
"always_keep_tensor_constants": True,
|
|
"aot_inductor.package_constants_in_so": False,
|
|
"aot_inductor.package_constants_on_disk": True,
|
|
"aot_inductor.package": True,
|
|
}
|
|
):
|
|
aoti_files = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
|
|
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
|
|
package_path = package_aoti(
|
|
f.name,
|
|
{"model": aoti_files},
|
|
)
|
|
pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
|
|
loaded1 = pt2_contents.aoti_runners["model"]
|
|
|
|
self.assertEqual(loaded1(a), model(a))
|
|
|
|
def test_extract_constants_map(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M, N, K = 8, 6, 16
|
|
model = Model(N, K, self.device)
|
|
a = torch.randn(M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
|
|
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
|
|
|
|
def runner_call(*args, **kwargs):
|
|
import torch.fx._pytree as fx_pytree
|
|
|
|
call_spec = runner.get_call_spec()
|
|
in_spec = pytree.treespec_loads(call_spec[0])
|
|
out_spec = pytree.treespec_loads(call_spec[1])
|
|
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
|
|
flat_outputs = runner.run(flat_inputs)
|
|
return pytree.tree_unflatten(flat_outputs, out_spec)
|
|
|
|
test_inputs = torch.randn(M, K, device=self.device)
|
|
expected = model(test_inputs)
|
|
output = runner_call(test_inputs)
|
|
self.assertEqual(expected, output)
|
|
|
|
original_weights = {
|
|
"L__self___weight": model.weight,
|
|
"L__self___bias": model.bias,
|
|
}
|
|
new_weights = {
|
|
"L__self___weight": torch.randn(N, K, device=self.device),
|
|
"L__self___bias": torch.randn(N, device=self.device),
|
|
}
|
|
|
|
# Extract weights with use_inactive = False, this should be the current weight.
|
|
extracted_original_weights = runner.extract_constants_map(False)
|
|
self.assertEqual(original_weights, extracted_original_weights)
|
|
|
|
# update the inactive weights with new_weights, extract inactive weights.
|
|
runner.update_constant_buffer(new_weights, True, False)
|
|
extracted_new_weights = runner.extract_constants_map(True)
|
|
self.assertEqual(new_weights, extracted_new_weights)
|
|
|
|
# Swap constant buffer, this should give us the opposite weights.
|
|
runner.swap_constant_buffer()
|
|
|
|
extracted_inactive_weights = runner.extract_constants_map(True)
|
|
extracted_active_weights = runner.extract_constants_map(False)
|
|
self.assertEqual(original_weights, extracted_inactive_weights)
|
|
self.assertEqual(new_weights, extracted_active_weights)
|
|
|
|
def test_update_constant_buffer(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M, N, K = 8, 6, 16
|
|
model = Model(N, K, self.device)
|
|
a = torch.randn(M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
# Attribute naming has changed in the new export API, so still use the legacy API here.
|
|
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
|
|
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
|
|
|
|
# Let's check whether the model has correct constant name mapping.
|
|
expected_original_fqns = {
|
|
"L__self___weight": "L__self___weight",
|
|
"L__self___bias": "L__self___bias",
|
|
}
|
|
self.assertEqual(
|
|
expected_original_fqns, runner.get_constant_names_to_original_fqns()
|
|
)
|
|
|
|
def runner_call(*args, **kwargs):
|
|
import torch.fx._pytree as fx_pytree
|
|
|
|
call_spec = runner.get_call_spec()
|
|
in_spec = pytree.treespec_loads(call_spec[0])
|
|
out_spec = pytree.treespec_loads(call_spec[1])
|
|
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
|
|
flat_outputs = runner.run(flat_inputs)
|
|
return pytree.tree_unflatten(flat_outputs, out_spec)
|
|
|
|
test_inputs = torch.randn(M, K, device=self.device)
|
|
expected = model(test_inputs)
|
|
output = runner_call(test_inputs)
|
|
self.assertEqual(expected, output)
|
|
|
|
new_weights = {
|
|
"L__self___weight": torch.randn(N, K, device=self.device),
|
|
"L__self___bias": torch.randn(N, device=self.device),
|
|
}
|
|
runner.update_constant_buffer(new_weights, False, False)
|
|
new_output = runner_call(test_inputs)
|
|
new_expected = torch.nn.functional.linear(
|
|
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
|
|
)
|
|
self.assertEqual(new_expected, new_output)
|
|
|
|
def test_update_inactive_constant_buffer(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M, N, K = 8, 6, 16
|
|
model = Model(N, K, self.device)
|
|
a = torch.randn(M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
|
|
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
|
|
|
|
def runner_call(*args, **kwargs):
|
|
import torch.fx._pytree as fx_pytree
|
|
|
|
call_spec = runner.get_call_spec()
|
|
in_spec = pytree.treespec_loads(call_spec[0])
|
|
out_spec = pytree.treespec_loads(call_spec[1])
|
|
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
|
|
flat_outputs = runner.run(flat_inputs)
|
|
return pytree.tree_unflatten(flat_outputs, out_spec)
|
|
|
|
test_inputs = torch.randn(M, K, device=self.device)
|
|
expected = model(test_inputs)
|
|
output = runner_call(test_inputs)
|
|
self.assertEqual(expected, output)
|
|
|
|
new_weights = {
|
|
"L__self___weight": torch.randn(N, K, device=self.device),
|
|
"L__self___bias": torch.randn(N, device=self.device),
|
|
}
|
|
new_expected = torch.nn.functional.linear(
|
|
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
|
|
)
|
|
|
|
runner.update_constant_buffer(new_weights, True, False)
|
|
output_before_swap = runner_call(test_inputs)
|
|
runner.swap_constant_buffer()
|
|
output_after_swap = runner_call(test_inputs)
|
|
|
|
self.assertEqual(expected, output_before_swap)
|
|
self.assertEqual(new_expected, output_after_swap)
|
|
|
|
def test_free_inactive_buffer(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M, N, K = 8, 6, 16
|
|
model = Model(N, K, self.device)
|
|
a = torch.randn(M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
|
|
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
|
|
|
|
def runner_call(*args, **kwargs):
|
|
import torch.fx._pytree as fx_pytree
|
|
|
|
call_spec = runner.get_call_spec()
|
|
in_spec = pytree.treespec_loads(call_spec[0])
|
|
out_spec = pytree.treespec_loads(call_spec[1])
|
|
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
|
|
flat_outputs = runner.run(flat_inputs)
|
|
return pytree.tree_unflatten(flat_outputs, out_spec)
|
|
|
|
test_inputs = torch.randn(M, K, device=self.device)
|
|
expected = model(test_inputs)
|
|
output = runner_call(test_inputs)
|
|
# Check the outputs, make sure the model is correct here.
|
|
self.assertEqual(expected, output)
|
|
|
|
new_weights = {
|
|
"L__self___weight": torch.randn(N, K, device=self.device),
|
|
"L__self___bias": torch.randn(N, device=self.device),
|
|
}
|
|
new_expected = torch.nn.functional.linear(
|
|
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
|
|
)
|
|
runner.update_constant_buffer(new_weights, True, False)
|
|
|
|
# Make sure we have swapped buffer
|
|
runner.swap_constant_buffer()
|
|
output_after_swap = runner_call(test_inputs)
|
|
self.assertEqual(new_expected, output_after_swap)
|
|
|
|
# Free the secondary buffer
|
|
runner.free_inactive_constant_buffer()
|
|
|
|
# Create a new set of weights to refill into the already freed buffer.
|
|
new_weights_1 = {
|
|
"L__self___weight": torch.randn(N, K, device=self.device),
|
|
"L__self___bias": torch.randn(N, device=self.device),
|
|
}
|
|
new_expected_1 = torch.nn.functional.linear(
|
|
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
|
|
)
|
|
runner.update_constant_buffer(new_weights_1, True, False)
|
|
|
|
output_after_swap_1 = runner_call(test_inputs)
|
|
self.assertEqual(new_expected_1, output_after_swap_1)
|
|
|
|
runner.free_inactive_constant_buffer()
|
|
|
|
def test_update_user_managed_buffer(self):
|
|
if self.device != "cuda":
|
|
raise unittest.SkipTest("requires CUDA")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, n, k, device):
|
|
super().__init__()
|
|
self.weight = torch.randn(n, k, device=device)
|
|
self.bias = torch.randn(n, device=device)
|
|
|
|
def forward(self, a):
|
|
return torch.nn.functional.linear(a, self.weight, self.bias)
|
|
|
|
M, N, K = 1024, 4096, 4096
|
|
model = Model(N, K, self.device)
|
|
a = torch.randn(M, K, device=self.device)
|
|
example_inputs = (a,)
|
|
# Attribute naming has changed in the new export API, so still use the legacy API here.
|
|
with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
|
|
so_path = AOTIRunnerUtil.legacy_compile(
|
|
model=model,
|
|
example_inputs=example_inputs,
|
|
)
|
|
|
|
runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path)
|
|
|
|
def runner_call(*args, **kwargs):
|
|
import torch.fx._pytree as fx_pytree
|
|
|
|
call_spec = runner.get_call_spec()
|
|
in_spec = pytree.treespec_loads(call_spec[0])
|
|
out_spec = pytree.treespec_loads(call_spec[1])
|
|
flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
|
|
flat_outputs = runner.run(flat_inputs)
|
|
return pytree.tree_unflatten(flat_outputs, out_spec)
|
|
|
|
test_inputs = torch.randn(M, K, device=self.device)
|
|
expected = model(test_inputs)
|
|
output = runner_call(test_inputs)
|
|
self.assertEqual(expected, output)
|
|
|
|
new_weights = {
|
|
"L__self___weight": torch.randn(N, K, device=self.device),
|
|
"L__self___bias": torch.randn(N, device=self.device),
|
|
}
|
|
mem_before, _ = torch.cuda.mem_get_info(self.device)
|
|
# Do not use user managed_buffer, should have less free memory.
|
|
runner.update_constant_buffer(new_weights, True, False, False)
|
|
mem_after, _ = torch.cuda.mem_get_info(self.device)
|
|
self.assertGreater(mem_before, mem_after)
|
|
|
|
runner.swap_constant_buffer()
|
|
new_output = runner_call(test_inputs)
|
|
new_expected = torch.nn.functional.linear(
|
|
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
|
|
)
|
|
self.assertEqual(new_expected, new_output)
|
|
|
|
# Inplace substitube tensor, without user managed buffer, result should be different.
|
|
new_weights["L__self___weight"].add_(1)
|
|
new_weights["L__self___bias"].add_(1)
|
|
|
|
new_output = runner_call(test_inputs)
|
|
# Same as the previous result
|
|
self.assertEqual(new_expected, new_output)
|
|
new_expected = torch.nn.functional.linear(
|
|
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
|
|
)
|
|
# Differ from latest result
|
|
self.assertNotEqual(new_expected, new_output)
|
|
|
|
# Clear out all buffers
|
|
runner.free_inactive_constant_buffer()
|
|
runner.swap_constant_buffer()
|
|
runner.free_inactive_constant_buffer()
|
|
|
|
new_weights = {
|
|
"L__self___weight": torch.randn(N, K, device=self.device),
|
|
"L__self___bias": torch.randn(N, device=self.device),
|
|
}
|
|
mem_before, _ = torch.cuda.mem_get_info(self.device)
|
|
# Try user managed_buffer, should have same free memory.
|
|
runner.update_constant_buffer(new_weights, True, False, True)
|
|
mem_after, _ = torch.cuda.mem_get_info(self.device)
|
|
self.assertEqual(mem_before, mem_after)
|
|
|
|
runner.swap_constant_buffer()
|
|
new_output = runner_call(test_inputs)
|
|
new_expected = torch.nn.functional.linear(
|
|
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
|
|
)
|
|
self.assertEqual(new_expected, new_output)
|
|
|
|
# Inplace substitube tensor, with user managed buffer, result should be the same.
|
|
new_weights["L__self___weight"].add_(1)
|
|
new_weights["L__self___bias"].add_(1)
|
|
|
|
new_output = runner_call(test_inputs)
|
|
new_expected = torch.nn.functional.linear(
|
|
test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
|
|
)
|
|
self.assertEqual(new_expected, new_output)
|
|
|
|
def test_cond_share_predicte(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, predicate, x):
|
|
y = torch.cond(
|
|
predicate,
|
|
lambda: x + 1,
|
|
lambda: x + 2,
|
|
)
|
|
|
|
z = torch.cond(
|
|
predicate,
|
|
lambda: y + 1,
|
|
lambda: y + 2,
|
|
)
|
|
return (z,)
|
|
|
|
example_inputs = (
|
|
torch.tensor([True]).to(self.device),
|
|
torch.tensor([1, 2, 3]).to(self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@unittest.skipIf(
|
|
IS_FBCODE,
|
|
"To enable after the C shim FC window ends",
|
|
)
|
|
def test_misaligned_input_1(self):
|
|
if self.device != "cuda":
|
|
raise unittest.SkipTest("CUDA test only")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x.sin() + x.cos()
|
|
|
|
N = 64 * 64 * 64 + 64
|
|
arg = torch.randn(N, device=self.device)
|
|
example_inputs = (arg,)
|
|
model = Model()
|
|
expected = model(*example_inputs)
|
|
package_path = AOTIRunnerUtil.compile(model, example_inputs)
|
|
optimized = torch._inductor.aoti_load_package(package_path)
|
|
# If the model is compiled with aligned inputs, the generated
|
|
# code will check inputs alignment at runtime
|
|
self.code_check_count(
|
|
model, example_inputs, "aoti_torch_clone_preserve_strides", 1
|
|
)
|
|
|
|
misaligned_arg = torch.zeros(N + 1, device=self.device)
|
|
misaligned_arg = misaligned_arg[1:]
|
|
misaligned_arg.copy_(arg)
|
|
actual = optimized(misaligned_arg)
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
def test_misaligned_input_2(self):
|
|
if self.device != "cuda":
|
|
raise unittest.SkipTest("CUDA test only")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x.sin() + x.cos()
|
|
|
|
N = 64 * 64 * 64 + 64
|
|
arg = torch.randn(N, device=self.device)
|
|
misaligned_arg = torch.zeros(N + 1, device=self.device)
|
|
misaligned_arg = misaligned_arg[1:]
|
|
misaligned_arg.copy_(arg)
|
|
example_inputs = (misaligned_arg,)
|
|
|
|
model = Model()
|
|
self.check_model(model, example_inputs)
|
|
# If the model is already compiled with a misaligned input, the
|
|
# generated code should NOT contain an alignment check for that input.
|
|
self.code_check_count(
|
|
model, example_inputs, "aoti_torch_clone_preserve_strides", 0
|
|
)
|
|
|
|
def test_autotuning_args_reuse(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
x_out = torch.empty_strided(
|
|
(x.size()[0], x.size()[1]), (x.size()[1], 1), device=GPU_TYPE
|
|
)
|
|
x_out = torch.permute(x_out, [0, 1])
|
|
add_kernel_autotuned[(4,)](x, x, x_out, 16)
|
|
|
|
y_out = torch.empty_strided(
|
|
(y.size()[0], y.size()[1]), (y.size()[1], 1), device=GPU_TYPE
|
|
)
|
|
y_out = torch.permute(y_out, [0, 1])
|
|
add_kernel_autotuned[(64,)](y, y, y_out, 64)
|
|
|
|
sub_kernel_autotuned[(4,)](x, x, x_out, 16)
|
|
|
|
return x_out, y_out
|
|
|
|
example_inputs = (
|
|
torch.randn(4, 4, device=GPU_TYPE),
|
|
torch.randn(8, 8, device=GPU_TYPE),
|
|
)
|
|
dim0_x = Dim("dim0_x", min=1, max=2048)
|
|
dim0_y = Dim("dim0_y", min=1, max=2048)
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
|
|
self.check_model(
|
|
Model(),
|
|
example_inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
options={"max_autotune": True},
|
|
)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
|
|
def test_stft(self):
|
|
N_FFT = 400
|
|
HOP_LENGTH = 160
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
window = torch.hann_window(N_FFT).to(x.device)
|
|
stft = torch.stft(
|
|
x, N_FFT, HOP_LENGTH, window=window, return_complex=True
|
|
)
|
|
magnitudes = stft[..., :-1].abs() ** 2
|
|
return magnitudes
|
|
|
|
model = Model()
|
|
example_inputs = (torch.randn(500, device=self.device),)
|
|
self.check_model(model, example_inputs)
|
|
|
|
def test_conv3d(self):
|
|
if self.device != GPU_TYPE or not is_big_gpu():
|
|
raise unittest.SkipTest("requires modern GPU to run max-autotune")
|
|
|
|
if not _has_sufficient_memory(self.device, 2**35):
|
|
raise unittest.SkipTest("insufficient memory")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(
|
|
self,
|
|
convert_element_type_1271,
|
|
convert_element_type_1272,
|
|
convert_element_type_1273,
|
|
):
|
|
return torch.ops.aten.convolution.default(
|
|
convert_element_type_1271,
|
|
convert_element_type_1272,
|
|
convert_element_type_1273,
|
|
[1, 1],
|
|
[1, 1],
|
|
[1, 1],
|
|
False,
|
|
[0, 0],
|
|
1,
|
|
)
|
|
|
|
example_inputs = (
|
|
torch.randn(1, 64, 5160, 5160, device=self.device),
|
|
torch.randn(3, 64, 3, 3, device=self.device),
|
|
torch.randn(3, device=self.device),
|
|
)
|
|
dynamic_shapes = {
|
|
"convert_element_type_1271": {
|
|
3: torch.export.Dim.DYNAMIC,
|
|
4: torch.export.Dim.DYNAMIC,
|
|
},
|
|
"convert_element_type_1272": None,
|
|
"convert_element_type_1273": None,
|
|
}
|
|
with config.patch(
|
|
{
|
|
"max_autotune": True,
|
|
"max_autotune_conv_backends": "TRITON",
|
|
}
|
|
):
|
|
self.check_model(
|
|
Model(),
|
|
example_inputs,
|
|
atol=0.1,
|
|
rtol=1e-3,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@skipIfXpu(
|
|
msg="The operator 'aten::_int_mm' is not currently implemented for the XPU device"
|
|
)
|
|
def test__int_mm(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
return torch._int_mm(x, y)
|
|
|
|
example_inputs = (
|
|
torch.randint(-10, 10, (64, 32), device=self.device, dtype=torch.int8),
|
|
torch.randint(-10, 10, (32, 64), device=self.device, dtype=torch.int8),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
@skipIfXpu(
|
|
msg="aten::convert_weight_to_int4pack is not currently implemented for XPU"
|
|
)
|
|
@parametrize("m", [32])
|
|
@parametrize("n", [64])
|
|
@parametrize("q_group", [32, 64])
|
|
@parametrize("num_groups", [1, 2])
|
|
def test__weight_int4pack_mm(self, m, n, q_group, num_groups):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, weight, scale_and_zeros) -> None:
|
|
super().__init__()
|
|
self.weight = weight
|
|
self.scale_and_zeros = scale_and_zeros
|
|
|
|
def forward(self, a):
|
|
return torch._weight_int4pack_mm(
|
|
a, self.weight, q_group, self.scale_and_zeros
|
|
)
|
|
|
|
def convert_weight_to_int4pack(b):
|
|
b_int32, b_scales_and_zeros = _group_quantize_tensor(
|
|
b, n_bit=4, q_group_size=q_group
|
|
)
|
|
b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)
|
|
return b_int4pack, b_scales_and_zeros
|
|
|
|
k = q_group * num_groups
|
|
a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
|
|
b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
|
|
b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b)
|
|
model = Model(b_int4pack, b_scales_and_zeros_f32)
|
|
self.check_model(model, (a,))
|
|
|
|
@parametrize("m", [32])
|
|
@parametrize("n", [64])
|
|
@parametrize("q_group", [32, 64])
|
|
@parametrize("num_groups", [1, 2])
|
|
def test__weight_int4pack_mm_with_scales_and_zeros(self, m, n, q_group, num_groups):
|
|
if "xpu" not in self.device:
|
|
raise unittest.SkipTest("requires Intel GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self, weight, scale, zeros) -> None:
|
|
super().__init__()
|
|
self.weight = weight
|
|
self.scale = scale
|
|
self.zeros = zeros
|
|
|
|
def forward(self, a):
|
|
return torch._weight_int4pack_mm_with_scales_and_zeros(
|
|
a, self.weight, q_group, self.scale, self.zeros
|
|
)
|
|
|
|
def _group_quantize_tensor_xpu(w, n_bit=4, q_group_size=16):
|
|
# w [k, n] = [32, 48]
|
|
assert w.dim() == 2
|
|
# w [n, k] = [48, 32]
|
|
w = w.transpose(0, 1).contiguous()
|
|
assert q_group_size > 1
|
|
assert w.shape[-1] % q_group_size == 0
|
|
|
|
# to_quant: [n * k / group_size, group_size]
|
|
to_quant = w.reshape(-1, q_group_size)
|
|
assert torch.isnan(to_quant).sum() == 0
|
|
|
|
max_val = to_quant.amax(dim=1, keepdim=True)
|
|
min_val = to_quant.amin(dim=1, keepdim=True)
|
|
max_int = 2**n_bit - 1
|
|
min_int = 0
|
|
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
|
assert torch.isnan(scales).sum() == 0
|
|
|
|
zeros = min_int - min_val.div(scales).round()
|
|
zeros = torch.clamp(zeros, min_int, max_int)
|
|
zeros = zeros.to(torch.int8)
|
|
assert torch.isnan(zeros).sum() == 0
|
|
|
|
out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int)
|
|
assert torch.isnan(out).sum() == 0
|
|
|
|
# [n, k]
|
|
out = out.to(dtype=torch.int32).reshape(w.shape)
|
|
if out.device != torch.device("cpu"):
|
|
out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8)
|
|
|
|
# Scales and zeros for the same q-group should be contiguous, so we can
|
|
# load as a 32-bit word
|
|
scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous()
|
|
zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous()
|
|
|
|
return out, scales, zeros
|
|
|
|
def convert_weight_to_int4pack(b):
|
|
# b_uint8 [n, k //2]
|
|
b_uint8, scales, zeros = _group_quantize_tensor_xpu(
|
|
b, n_bit=4, q_group_size=q_group
|
|
)
|
|
# b_int4pack [k//8, n]
|
|
b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=2)
|
|
|
|
return b_int4pack, scales, zeros
|
|
|
|
k = q_group * num_groups
|
|
a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
|
|
b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
|
|
b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b)
|
|
model = Model(b_int4pack, b_scales, zeros_int8)
|
|
self.check_model(model, (a,))
|
|
|
|
def test_assert_tensor_meta(self):
|
|
class Module(torch.nn.Module):
|
|
def forward(self, x):
|
|
torch.ops.aten._assert_tensor_metadata.default(
|
|
x,
|
|
dtype=torch.int32,
|
|
)
|
|
return (x + 1,)
|
|
|
|
example_inputs = (torch.tensor(1, dtype=torch.int32),)
|
|
with config.patch(
|
|
{
|
|
"implicit_fallbacks": False,
|
|
}
|
|
):
|
|
self.check_model(
|
|
Module(),
|
|
example_inputs,
|
|
atol=0.1,
|
|
rtol=1e-3,
|
|
)
|
|
|
|
@skipIfRocm # RoCM does not support the config block size in test suite.
|
|
def test_triton_autotuning(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y, m):
|
|
_M, K = x.shape
|
|
K, N = y.shape
|
|
M = torch.abs(m)
|
|
out = torch.empty((_M, N), device=x.device, dtype=torch.float32)
|
|
grid = lambda META: ( # noqa: E731
|
|
triton.cdiv(
|
|
4096 * 2046, META["BLOCK_SIZE_M"] * META["BLOCK_SIZE_N"]
|
|
),
|
|
)
|
|
strange_config_matmul_kernel[grid](
|
|
x,
|
|
y,
|
|
out,
|
|
M,
|
|
N,
|
|
K,
|
|
)
|
|
return out
|
|
|
|
x = torch.randn(4096, 1024, device=self.device)
|
|
y = torch.randn(1024, 2048, device=self.device)
|
|
m = torch.tensor([4096], dtype=torch.int32, device=self.device)
|
|
|
|
with config.patch("triton.autotune_with_sample_inputs", True):
|
|
# The tuned best config on XPU is different with CUDA.
|
|
grid_0 = 32736 if GPU_TYPE == "xpu" else 1023
|
|
self.code_check_count(
|
|
Model(), (x, y, m), f"uint32_t grid_0 = {grid_0}L;", 1
|
|
)
|
|
|
|
@skipIfRocm # RoCM does not support the config block size in test suite.
|
|
def test_triton_mutated_autotuning(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
@triton.jit
|
|
def add_one_kernel(X, Y, N):
|
|
pid = tl.program_id(axis=0)
|
|
block_start = pid
|
|
offsets = block_start + tl.arange(0, 1)
|
|
|
|
x = tl.load(X + offsets, mask=offsets < N)
|
|
y = x + 1
|
|
tl.store(Y + offsets, y, mask=offsets < N)
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y, m):
|
|
_M, K = x.shape
|
|
K, N = y.shape
|
|
M = torch.empty((1), device=x.device, dtype=torch.int32)
|
|
add_one_kernel[(1,)](m, M, 1)
|
|
out = torch.empty((_M, N), device=x.device, dtype=torch.float32)
|
|
grid = lambda META: ( # noqa: E731
|
|
triton.cdiv(
|
|
4096 * 2046, META["BLOCK_SIZE_M"] * META["BLOCK_SIZE_N"]
|
|
),
|
|
)
|
|
strange_config_matmul_kernel[grid](
|
|
x,
|
|
y,
|
|
out,
|
|
M,
|
|
N,
|
|
K,
|
|
)
|
|
return out
|
|
|
|
x = torch.randn(4096, 1024, device=self.device)
|
|
y = torch.randn(1024, 2048, device=self.device)
|
|
m = torch.tensor([4095], dtype=torch.int32, device=self.device)
|
|
|
|
with config.patch("triton.autotune_with_sample_inputs", True):
|
|
# The tuned best config on XPU is different with CUDA.
|
|
grid_0 = 32736 if GPU_TYPE == "xpu" else 1023
|
|
self.code_check_count(
|
|
Model(), (x, y, m), f"uint32_t grid_0 = {grid_0}L;", 1
|
|
)
|
|
|
|
@skipIfRocm
|
|
@patch.dict(os.environ, {"TRITON_DEBUG": "1"})
|
|
def test_triton_dynamic_launcher_grid(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
@triton.autotune(
|
|
configs=[
|
|
triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
|
|
triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
|
|
],
|
|
key=["numel"],
|
|
)
|
|
@triton.jit
|
|
def add_one_kernel(X, Y, numel, BLOCK_SIZE: "tl.constexpr"):
|
|
pid = tl.program_id(axis=0)
|
|
block_start = pid * BLOCK_SIZE
|
|
tl.device_assert(block_start < numel)
|
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
|
|
|
x = tl.load(X + offsets)
|
|
y = x + 1
|
|
tl.store(Y + offsets, y)
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, value):
|
|
numel = value.item()
|
|
out = torch.zeros_like(x, dtype=torch.float16)
|
|
|
|
grid = lambda META: ( # noqa: E731
|
|
triton.cdiv(numel, META["BLOCK_SIZE"]),
|
|
)
|
|
add_one_kernel[grid](x, out, numel)
|
|
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(1024, device=self.device),
|
|
torch.tensor([1024], dtype=torch.int32, device=self.device),
|
|
)
|
|
|
|
with config.patch("triton.autotune_with_sample_inputs", True):
|
|
dim0_x = Dim("dim0_x", min=2, max=8192)
|
|
dynamic_shapes = {"x": {0: dim0_x}, "value": {0: Dim.AUTO}}
|
|
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
@skipIfRocm
|
|
@patch.dict(os.environ, {"TRITON_DEBUG": "1"})
|
|
def test_triton_dynamic_launcher_grid_infer_from_tensor(self):
|
|
if self.device != GPU_TYPE:
|
|
raise unittest.SkipTest("requires GPU")
|
|
|
|
@triton.autotune(
|
|
configs=[
|
|
triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
|
|
triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
|
|
],
|
|
key=["numel"],
|
|
)
|
|
@triton.jit
|
|
def add_one_kernel(X, Y, numel, BLOCK_SIZE: "tl.constexpr"):
|
|
pid = tl.program_id(axis=0)
|
|
block_start = pid * BLOCK_SIZE
|
|
tl.device_assert(block_start < numel)
|
|
|
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
|
x = tl.load(X + offsets)
|
|
y = x + 1
|
|
tl.store(Y + offsets, y)
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, dim_D):
|
|
numel = x.shape[1] * dim_D.item()
|
|
x = x.repeat(dim_D, 1)
|
|
out = torch.zeros_like(x, dtype=torch.float16)
|
|
|
|
grid = lambda META: ( # noqa: E731
|
|
triton.cdiv(numel, META["BLOCK_SIZE"]),
|
|
)
|
|
add_one_kernel[grid](x, out, numel)
|
|
|
|
return out
|
|
|
|
example_inputs = (
|
|
torch.randn(1, 1024, device=self.device),
|
|
torch.tensor([2], dtype=torch.int32, device=self.device),
|
|
)
|
|
|
|
with config.patch("triton.autotune_with_sample_inputs", True):
|
|
dim1_x = Dim("dim1_x", min=2, max=8192)
|
|
dynamic_shapes = {"x": {0: Dim.AUTO, 1: dim1_x}, "dim_D": {0: Dim.AUTO}}
|
|
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_composed_dynamic_size(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x + 1
|
|
|
|
example_inputs = (torch.randn(10, device=self.device),)
|
|
dim = torch.export.Dim("dim_0")
|
|
dim_even = 2 * dim
|
|
dynamic_shapes = {
|
|
"x": {0: dim_even},
|
|
}
|
|
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_with_cudagraphs(self):
|
|
if self.device != "cuda":
|
|
raise unittest.SkipTest("requires CUDA")
|
|
|
|
# define CUDAGraph handling wrapper (only works with kwargs for simplicity)
|
|
def cudagraph(f):
|
|
_graphs = {}
|
|
|
|
def f_(**kwargs):
|
|
key = hash(
|
|
tuple(
|
|
tuple(kwargs[a].shape)
|
|
for a in sorted(kwargs.keys())
|
|
if isinstance(kwargs[a], torch.Tensor)
|
|
)
|
|
)
|
|
if key in _graphs:
|
|
wrapped, *_ = _graphs[key]
|
|
return wrapped(**kwargs)
|
|
g = torch.cuda.CUDAGraph()
|
|
in_tensors = {
|
|
k: v.clone() if isinstance(v, torch.Tensor) else v
|
|
for k, v in kwargs.items()
|
|
}
|
|
f(**in_tensors) # stream warmup
|
|
with torch.cuda.graph(g):
|
|
out_tensors = f(**in_tensors)
|
|
|
|
def wrapped(**kwargs):
|
|
for key in kwargs:
|
|
in_tensors[key].copy_(kwargs[key])
|
|
g.replay()
|
|
if isinstance(out_tensors, torch.Tensor):
|
|
return out_tensors.clone()
|
|
elif isinstance(out_tensors, (list, tuple)):
|
|
return type(out_tensors)(o.clone() for o in out_tensors)
|
|
raise ValueError("unsupported output type encountered")
|
|
|
|
_graphs[key] = (wrapped, g, in_tensors, out_tensors)
|
|
return wrapped(**kwargs)
|
|
|
|
return f_
|
|
|
|
# define a simple model
|
|
model = torch.nn.Linear(10, 20).to(device=self.device)
|
|
|
|
# export + AOTI
|
|
model_kwargs = {
|
|
"input": torch.randn(3, 10, device=self.device),
|
|
}
|
|
ep = torch.export.export(model, args=(), kwargs=model_kwargs, strict=True)
|
|
|
|
optimized = torch._inductor.aoti_load_package(
|
|
torch._inductor.aoti_compile_and_package(
|
|
ep,
|
|
inductor_configs={"max_autotune": True},
|
|
),
|
|
# NB: this flag avoids a CUDAGraph + AOTI runtime multi-threading conflict
|
|
# "Error: operation not permitted when stream is capturing"
|
|
run_single_threaded=True,
|
|
)
|
|
|
|
# enable CUDAGraphs
|
|
optimized = cudagraph(optimized)
|
|
|
|
# warmup -> run with CUDAGraphs
|
|
for _ in range(3):
|
|
optimized(**model_kwargs)
|
|
|
|
# compare against eager
|
|
self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
|
|
|
|
def test_clamp_decomposition(self):
|
|
class Model1(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x.clamp(min=1.5)
|
|
|
|
class Model2(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x.clamp(min=2)
|
|
|
|
x = torch.randint(4, (4,))
|
|
|
|
# the output should have float32 type, not int
|
|
self.check_model(Model1(), (x,))
|
|
# the output should have int type
|
|
self.check_model(Model2(), (x,))
|
|
|
|
def test_using_model_name_for_files(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x, y):
|
|
return x + self.linear(y)
|
|
|
|
example_inputs = (
|
|
torch.randn(10, 10, device=self.device),
|
|
torch.randn(10, 10, device=self.device),
|
|
)
|
|
model = Model().to(self.device)
|
|
with torch.no_grad():
|
|
package_path: str = AOTIRunnerUtil.compile(
|
|
model,
|
|
example_inputs,
|
|
inductor_configs={
|
|
"aot_inductor.model_name_for_generated_files": "test_model"
|
|
},
|
|
)
|
|
|
|
with zipfile.ZipFile(package_path, "r") as zip_ref:
|
|
all_files = zip_ref.namelist()
|
|
base_dir = "test_model.wrapper/data/aotinductor/model/test_model"
|
|
self.assertTrue(f"{base_dir}.wrapper.cpp" in all_files)
|
|
self.assertTrue(f"{base_dir}.kernel.cpp" in all_files)
|
|
self.assertTrue(f"{base_dir}.wrapper.so" in all_files)
|
|
|
|
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
|
|
self.assertEqual(aot_inductor_module(*example_inputs), model(*example_inputs))
|
|
|
|
|
|
class AOTInductorLoggingTest(LoggingTestCase):
|
|
@make_logging_test(dynamic=logging.DEBUG)
|
|
def test_shape_env_reuse(self, records):
|
|
# make sure ShapeEnv is only created once and reused afterwards
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x + 2
|
|
|
|
inputs = (torch.randn(4, 4),)
|
|
dynamic_shapes = {
|
|
"x": {0: Dim.AUTO, 1: Dim.AUTO},
|
|
}
|
|
ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes, strict=False)
|
|
with torch.no_grad():
|
|
torch._inductor.aot_compile(ep.module(), inputs)
|
|
self.assertEqual([r.msg == "create_env" for r in records].count(True), 1)
|
|
|
|
|
|
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
|
|
|
|
|
|
def fail_cpu(is_skip=False):
|
|
return TestFailure(
|
|
("cpu",),
|
|
is_skip=is_skip,
|
|
)
|
|
|
|
|
|
def fail_gpu(suffixes: tuple[str, ...], is_skip=False):
|
|
return TestFailure(
|
|
suffixes,
|
|
is_skip=is_skip,
|
|
)
|
|
|
|
|
|
# test_failures, xfail by default, set is_skip=True to skip
|
|
CPU_TEST_FAILURES = {
|
|
# TODO: failed internally
|
|
"test_multiple_output_alias": fail_cpu(is_skip=True),
|
|
}
|
|
|
|
# test_failures, xfail by default, set is_skip=True to skip
|
|
GPU_TEST_FAILURES = {
|
|
# quantized unsupported for GPU
|
|
"test_quantized_linear": fail_gpu(("cuda", "xpu")),
|
|
"test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")),
|
|
# No scaled_dot_product_efficient_attention implementation for XPU yet.
|
|
"test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)),
|
|
# No fft implementation for XPU yet.
|
|
"test_fft_c2c": fail_gpu(("xpu",), is_skip=True),
|
|
}
|
|
|
|
|
|
class AOTInductorTestABICompatibleCpu(TestCase):
|
|
device = "cpu"
|
|
device_type = "cpu"
|
|
check_model = check_model
|
|
check_model_with_multiple_inputs = check_model_with_multiple_inputs
|
|
code_check_count = code_check_count
|
|
allow_stack_allocation = False
|
|
use_minimal_arrayref_interface = False
|
|
|
|
|
|
copy_tests(
|
|
AOTInductorTestsTemplate,
|
|
AOTInductorTestABICompatibleCpu,
|
|
"cpu",
|
|
CPU_TEST_FAILURES,
|
|
)
|
|
|
|
|
|
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
|
|
class AOTInductorTestABICompatibleGpu(TestCase):
|
|
device = GPU_TYPE
|
|
device_type = GPU_TYPE
|
|
check_model = check_model
|
|
check_model_with_multiple_inputs = check_model_with_multiple_inputs
|
|
code_check_count = code_check_count
|
|
allow_stack_allocation = False
|
|
use_minimal_arrayref_interface = False
|
|
|
|
|
|
copy_tests(
|
|
AOTInductorTestsTemplate,
|
|
AOTInductorTestABICompatibleGpu,
|
|
GPU_TYPE,
|
|
GPU_TEST_FAILURES,
|
|
)
|
|
|
|
if __name__ == "__main__":
|
|
from torch._inductor.test_case import run_tests
|
|
|
|
# cpp_extension N/A in fbcode
|
|
if HAS_GPU or sys.platform == "darwin":
|
|
run_tests(needs="filelock")
|