# pytorch/test/inductor/test_cuda_repro.py
# Owner(s): ["module: inductor"]
# ruff: noqa: F841
import copy
import functools
import gc
import math
import os
import sys
import unittest
import torch
import torch._dynamo.config as dynamo_config
import torch.backends.cuda
import torch.nn.functional as F
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.runtime.hints import DeviceProperties
from torch._inductor.utils import (
run_and_get_code,
run_and_get_graph_lowering,
run_fw_bw_and_get_code,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
SM80OrLater,
SM90OrLater,
TEST_MULTIGPU,
)
from torch.testing._internal.common_utils import (
DeterministicGuard,
freeze_rng_state,
IS_FBCODE,
MI350_ARCH,
skipIfRocmArch,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
xfailIfPy312Plus,
)
from torch.testing._internal.inductor_utils import IS_BIG_GPU
if TEST_WITH_ROCM:
config.force_layout_optimization = 1
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
requires_multigpu = functools.partial(
unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
)
from torch.testing._internal.inductor_utils import skipCUDAIf
try:
try:
import triton # @manual
from triton import language as tl # @manual
except ImportError:
raise unittest.SkipTest("requires triton") # noqa: B904
try:
from . import test_torchinductor
except ImportError:
import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
except unittest.SkipTest:
if __name__ == "__main__":
sys.exit(0)
raise
TestCase = test_torchinductor.TestCase
ToTuple = test_torchinductor.ToTuple
check_model_cuda = test_torchinductor.check_model_cuda
aten = torch.ops.aten
class CudaReproTests(TestCase):
device = "cuda"
common = check_model_cuda
def test_mm_out_dtype_compile(self):
a = torch.randn(1, 3, device="cuda", dtype=torch.float16)
b = torch.randn(3, 2, device="cuda", dtype=torch.float16)
def fn(x, y):
return torch.mm(x, y, out_dtype=torch.float32)
compiled = torch.compile(fn, backend="inductor", fullgraph=True)
result = compiled(a, b)
expected = fn(a, b)
self.assertEqual(result.dtype, expected.dtype)
self.assertEqual(result, expected)
def test_index_put_issue(self):
def forward(
self,
arg76_1,
expand_default,
full_like_default,
_to_copy_default_67,
zeros,
):
sum_sym_int_19 = torch.ops.aten.sum(_to_copy_default_67, [0], True)
view_default_57 = torch.ops.aten.view.default(sum_sym_int_19, [512, 768])
where_self = torch.ops.aten.where.self(
expand_default, view_default_57, full_like_default
)
clone_default_12 = torch.ops.aten.clone.default(zeros)
index_put__default = torch.ops.aten.index_put_.default(
clone_default_12, [arg76_1], where_self, True
)
return (index_put__default,)
inps = [
(torch.Size([512]), torch.int64),
(torch.Size([512, 768]), torch.bool),
(torch.Size([512, 768]), torch.float16),
(torch.Size([4, 512, 768]), torch.float16),
(torch.Size([512, 768]), torch.float16),
]
inps = [torch.zeros(())] + [
torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps
]
mod = make_fx(forward)(*inps)
compiled = compile_fx_inner(mod, inps)
compiled(inps)
def test_view_replay_padding_issue_163328(self):
class ReproModule(nn.Module):
def __init__(self):
super().__init__()
self.num_points_out = 120
self.lc_num = 2
input_channels = 16
self.linear_main = nn.Linear(input_channels, self.num_points_out * 2)
self.linear_lc = nn.Linear(input_channels, self.num_points_out * 2)
def forward(self, x: torch.Tensor):
bs, num_lat, num_lon, channels = x.shape
index = num_lat - self.lc_num
main_x = x[:, :index].reshape(bs * index * num_lon, channels)
lc_x = x[:, index:].reshape(bs * self.lc_num * num_lon, channels)
refline = self.linear_main(main_x).reshape(bs, index, num_lon, -1)
lc_refline = self.linear_lc(lc_x).reshape(bs, self.lc_num, num_lon, -1)
base = torch.cat([refline, lc_refline], dim=1).contiguous()
out0 = base.reshape(bs, num_lat, num_lon, self.num_points_out, 2)
out1 = base.reshape(bs, num_lat * num_lon, self.num_points_out * 2)
return {"ten0": out0, "ten1": out1}
torch.manual_seed(0)
model = ReproModule().cuda()
inputs = torch.randn(36, 9, 7, 16, device="cuda", requires_grad=True)
eager_out = model(inputs)
compiled_model = torch.compile(
copy.deepcopy(model),
backend="inductor",
mode="reduce-overhead",
fullgraph=True,
)
compiled_out = compiled_model(inputs)
self.assertEqual(compiled_out["ten0"], eager_out["ten0"])
self.assertEqual(compiled_out["ten1"], eager_out["ten1"])
def test_effn_attn_bias_padding(self):
batch_size, num_heads, seq_len, head_dim = 2, 32, 512, 128
def fn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_tensor: torch.Tensor, # This will be our starting point
):
# Input tensor is [2, 1, seq_len, 1] with appropriate strides
bias = torch.ops.aten.expand(
input_tensor, [2, 32, seq_len, seq_len]
) # Expands with broadcast (zero) strides in dims 1 and 3
return torch.ops.aten._scaled_dot_product_efficient_attention(
query,
key,
value,
bias,
compute_log_sumexp=True,
dropout_p=0.0,
is_causal=False,
scale=None,
)
query = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda")
key = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda")
value = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda")
input_tensor = torch.rand([2, 1, seq_len, 1], device="cuda")
out, code = run_and_get_code(torch.compile(fn), query, key, value, input_tensor)
input_tensor2 = torch.rand([2, 32, seq_len, seq_len], device="cuda").copy_(
input_tensor
)
# even though the last dim is broadcast, it needs stride 1 for alignment,
# but the dim-1 stride can be 0
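# For reference: (262144, 0, 512, 1) is the stride tuple of a [2, 32, 512, 512]
# buffer in which dim 1 is broadcast (stride 0), dim 2 steps by a full row of
# 512, the last dim is contiguous, and 512 * 512 = 262144 is the per-batch step.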
FileCheck().check("buf0").check("(262144, 0, 512, 1").run(code[0])
# don't check the RNG state
self.assertEqual(out[:2], fn(query, key, value, input_tensor2)[:2])
@skipIfRocmArch(MI350_ARCH)
def test_effn_attn_bias_padding_misaligned(self):
seqlen_start = 1008
for offset in range(-1, 2):
seqlen = seqlen_start + offset
torch._dynamo.reset()
bsz = 32
q = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
mask = torch.ones([bsz, 1, seqlen, seqlen], dtype=torch.bool, device="cuda")
inputs = [q, k, v, mask]
def f(q, k, v, mask):
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
return F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0
)
f_compiled = torch.compile(f)
out, code = run_and_get_code(f_compiled, *inputs)
# padded bias should have an expanded dim
FileCheck().check("buf0 =").check_same(", 0, ").run(code[0])
# single fused padded kernel
FileCheck().check_count("empty_strided_cuda(", 1, exactly=True).check(
"return"
).run(code[0])
self.assertEqual(out, f(*inputs))
def test_input_channels_last(self):
m = torch.nn.Sequential(
torch.nn.Conv2d(3, 3, 1, 1),
ToTuple(),
).cuda()
inp = torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda()
self.common(
m,
(inp,),
check_lowp=False,
)
@torch.compile()
def foo(m, inp):
return m(inp)
self.assertTrue(foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last))
# https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527
def test_unspec_inputs_interop(self):
class Repro(torch.nn.Module):
def forward(self, x, y):
unsqueeze = torch.ops.aten.unsqueeze.default(x, 4)
permute = torch.ops.aten.permute.default(unsqueeze, [0, 1, 2, 4, 3])
add = torch.ops.aten.add.Tensor(y, 1)
return [permute, add]
inps = [
rand_strided((12, 3, 512, 64), (64, 196608, 768, 1), torch.float32, "cuda"),
rand_strided((), (), torch.int64, "cpu"),
]
mod = make_fx(Repro().to(device="cuda"))(*inps)
compiled = compile_fx_inner(mod, inps)
compiled(inps)
@unittest.skipIf(
IS_FBCODE, "RuntimeError: Triton Error [CUDA]: invalid device context"
)
def test_backward_context(self):
def fn(x):
return x * 3
x = torch.randn(4, device="cuda", requires_grad=True)
gO = torch.rand_like(x)
opt_fn = torch.compile(fn)
out = opt_fn(x)
out.backward(gO)
@config.patch(fallback_random=True)
def test_dtype_factory_issue(self):
def forward():
randn = torch.ops.aten.randn.default(
[12, 64, 1, 64],
dtype=torch.float32,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1)
return (unsqueeze_default_2,)
mod = make_fx(forward)()
compiled = compile_fx_inner(mod, ())
assert compiled([])[0].device.type == "cuda"
@config.patch({"triton.cudagraphs": True})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_no_device_idx_repro_cudagraphs(self):
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self):
full = torch.ops.aten.full.default(
[8, 512],
1,
dtype=torch.float32,
layout=torch.strided,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
full_1 = torch.ops.aten.full.default(
[8, 512],
0,
dtype=torch.int64,
layout=torch.strided,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
return (full_1, full)
self.common(Repro(), ())
@config.patch({"triton.cudagraphs": True})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_expanded_inputs_cudagraphs(self):
@torch.compile(backend="inductor")
def fn(x, y):
return x + y
inputs = (
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
)
self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
@config.patch({"triton.cudagraphs": True})
@dynamo_config.patch(
automatic_dynamic_shapes=True,
assume_static_by_default=False,
)
def test_dynamic_to_static_cudagraphs(self):
for b in [False, True]:
with config.patch({"triton.cudagraph_trees": b}):
@torch.compile(backend="inductor")
def fn(x, y):
r = x + y
return r, r.size(0)
inputs = (
torch.randn((5, 5), device="cuda"),
torch.randn((5, 5), device="cuda"),
)
self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 5)))
inputs = (
torch.randn((6, 6), device="cuda"),
torch.randn((6, 6), device="cuda"),
)
self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 6)))
def _test_split_reduction_impl(self, x):
def max(x):
return torch.max(x)
max_c = torch.compile(max)
out, code = run_and_get_code(max_c, x)
self.assertEqual(out, max(x))
if DO_PERF_TEST:
ms_c = benchmarker.benchmark_gpu(lambda: max_c(x))
ms_eager = benchmarker.benchmark_gpu(lambda: max(x))
print(f"compile {ms_c=:.03f}, eager {ms_eager=:.03f}")
def test_split_reduction_transposed(self):
x = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda")
x = x.t().contiguous().t()
self._test_split_reduction_impl(x)
def test_split_reduction_channels_last(self):
x = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda")
x = x.reshape([256, 256, 256, 2]).to(memory_format=torch.channels_last)
self._test_split_reduction_impl(x)
@config.patch({"emulate_precision_casts": True})
def test_bool_emulate_low_precision(self):
from torch import device
inf = float("inf")
def forward():
full_1 = torch.ops.aten.full.default(
[6, 6],
1,
dtype=torch.float32,
layout=torch.strided,
device=device(type="cpu"),
pin_memory=False,
)
device_put_3 = torch.ops.prims.device_put.default(
full_1, device(type="cuda", index=0)
)
full_1 = None
convert_element_type_40 = torch.ops.prims.convert_element_type.default(
device_put_3, torch.bool
)
device_put_3 = None
unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_40, 1)
convert_element_type_40 = None
unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
unsqueeze_4 = None
expand = torch.ops.aten.expand.default(unsqueeze_5, [-1, 256, -1, 256])
unsqueeze_5 = None
clone = torch.ops.aten.clone.default(
expand, memory_format=torch.contiguous_format
)
expand = None
view_15 = torch.ops.aten.reshape.default(clone, [1536, 1536])
clone = None
scalar_tensor = torch.ops.aten.scalar_tensor.default(
-inf, dtype=torch.float16, device=device(type="cuda", index=0)
)
scalar_tensor_1 = torch.ops.aten.scalar_tensor.default(
0.0,
dtype=torch.float16,
layout=torch.strided,
device=device(type="cuda", index=0),
)
where = torch.ops.aten.where.self(view_15, scalar_tensor_1, scalar_tensor)
view_15 = scalar_tensor_1 = scalar_tensor = None
return where
from torch._inductor import config
config.emulate_precision_casts = True
self.assertEqual(torch.compile(forward)(), forward())
@config.patch({"emulate_precision_casts": True})
def test_emulate_low_precision(self):
def foo(x):
return torch.nn.functional.gelu(x) * 10.0
inp = torch.rand([32], device="cuda", requires_grad=True, dtype=torch.bfloat16)
out, codes = run_fw_bw_and_get_code(lambda: torch.compile(foo)(inp))
# fwd, backward
for code in codes:
f = FileCheck()
# in eager, there are two down casts
for _ in range(2):
f.check(".to(tl.bfloat16)").check_next(".to(tl.float32)")
f.run(code)
self.assertEqual(foo(inp), out)
# TODO: Abstract this out, test more extensively
@torch._dynamo.config.patch(assume_static_by_default=False)
def test_dynamic_shapes(self):
torch._dynamo.reset() # Needed since everywhere else uses "inductor"
def f(x):
return x.cos().view(x.shape).sin()
cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
f2 = torch.compile(f, backend=cnts)
f2(torch.randn(32))
inp = torch.randn(16)
real_out = f(inp)
compiled_out = f2(inp)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(real_out, compiled_out)
torch._dynamo.reset()
@config.patch({"triton.cudagraphs": True, "size_asserts": False})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_expanded_inputs_cudagraphs_no_size_asserts(self):
@torch.compile(backend="inductor")
def fn(x, y):
return x + y
inputs = (
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
)
self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
@config.patch({"triton.cudagraph_trees": False})
@config.patch({"triton.cudagraphs": True})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_inplace_updates_cudagraphs(self):
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.weight1 = torch.nn.Parameter(
torch.randn(10, 20, requires_grad=True)
)
def forward(self, x):
x = torch.matmul(x, self.weight1)
return x
from copy import deepcopy
model = Repro().cuda()
model_ref = deepcopy(model)
model_opt = torch.compile(model, backend="inductor")
input = torch.randn(10, 10, device="cuda", requires_grad=True)
for i in range(2):
output_ref = model_ref(input)
output_res = model_opt(input)
output_ref.sum().backward()
output_res.sum().backward()
for p_ref, p_res in zip(model_ref.parameters(), model_opt.parameters()):
self.assertEqual(p_ref.grad, p_res.grad)
with torch.no_grad():
for param in model_ref.parameters():
param.add_(1.0)
for param in model_opt.parameters():
param.add_(1.0)
# https://github.com/pytorch/torchdynamo/issues/1850
def test_inductor_output_aliases_intermediate(self):
def foo(x):
out = x + x
return out.t()
foo_opt = torch.compile(foo, backend="inductor")
inpt = torch.randn(10, 10, device="cuda", requires_grad=True)
# TODO: this is broken, fix later
# out = foo_opt(inpt)
# out.add_(2)
out_ref = foo(inpt)
out_ref.add_(2)
# self.assertEqual(out_ref, out)
def test_accuracy_issue1(self):
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(
in_features=768, out_features=2, bias=True
)
def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
linear = self.linear(x)
split = linear.split(1, dim=-1)
getitem = split[0]
squeeze = getitem.squeeze(-1)
clamp = start_positions.clamp(0, 128)
cross_entropy = torch.nn.functional.cross_entropy(
squeeze, clamp, None, None, 128, None, "mean", 0.0
)
return cross_entropy
mod = Repro().cuda()
opt_mod = torch.compile(mod, backend="inductor")
mod.eval()
opt_mod.eval()
args = [
((1,), (1,), torch.int64, "cuda", False),
((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True),
]
args = [
rand_strided(sh, st, dt, dev).requires_grad_(rg)
for (sh, st, dt, dev, rg) in args
]
with torch.cuda.amp.autocast(enabled=False):
assert same_two_models(mod, opt_mod, args), "Dynamo failed"
@config.patch(allow_buffer_reuse=False)
def test_issue103461(self):
def forward(add_1):
var_mean = torch.ops.aten.var_mean.correction(
add_1, [2], correction=0, keepdim=True
)
getitem_1 = var_mean[1]
return getitem_1
x = torch.randn(1, 8, 768, device="cuda")
correct = forward(x)
actual = torch.compile(forward, fullgraph=True)(x)
self.assertEqual(actual, correct)
def test_full_copy(self):
def forward(x):
full_10 = torch.ops.aten.full.default(
[204, 204, 28],
0,
dtype=torch.float64,
layout=torch.strided,
device="cuda",
pin_memory=False,
)
return x + full_10.to("cpu")
o = torch.randn([204, 204, 28], dtype=torch.float64)
correct = forward(o)
actual = torch.compile(forward, fullgraph=True)(o)
self.assertEqual(actual, correct)
def test_autotune_inplace_kernel(self):
"""
This unit test exercises autotuning on an in-place kernel. Autotuning should not
contaminate the input buffers when tuning with multiple configs. For more details,
refer to
https://github.com/triton-lang/triton/issues/781
https://github.com/pytorch/torchdynamo/issues/1670
"""
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
from torch._inductor.runtime.hints import AttrsDescriptorWrapper, HeuristicType
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._inductor.utils import triton_version_uses_attrs_dict
def autotune(configs, meta):
def decorator(fn):
if triton_version_uses_attrs_dict():
# Newer versions of Triton put constexpr args in the signature
# Ref: https://github.com/pytorch/pytorch/pull/145051
meta["signature"]["XBLOCK"] = "constexpr"
return CachingAutotuner(
# force autotune by setting save_cache_hook to False
fn,
triton_meta=meta,
configs=configs,
save_cache_hook=False,
mutated_arg_names=["in_out_ptr0"],
reset_to_zero_arg_names=[],
optimize_mem=True,
heuristic_type=HeuristicType.POINTWISE,
inductor_meta={"grid_type": "Grid1D"},
)
return decorator
@autotune(
configs=[
triton.Config({"XBLOCK": 1}),
triton.Config({"XBLOCK": 2}),
],
meta={
"signature": {
"in_out_ptr0": "*fp32",
"in_ptr0": "*fp32",
"xnumel": "i32",
},
"device": DeviceProperties.create(torch.device("cuda")),
"configs": [
AttrsDescriptorWrapper(divisible_by_16=(0, 1), equal_to_1=())
],
"constants": {},
},
)
@triton.jit
def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):
pid = tl.program_id(0)
block_start = pid * XBLOCK
offsets = block_start + tl.arange(0, XBLOCK)
mask = offsets < xnumel
x = tl.load(in_out_ptr0 + offsets, mask=mask, other=0.0)
y = tl.load(in_ptr0 + offsets, mask=mask, other=0.0)
output = x + y
tl.store(in_out_ptr0 + offsets, output, mask=mask)
xnumel = 384
in0 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
inout1 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
inout2 = inout1.clone()
stream0 = get_cuda_stream(0)
kernel.run(inout1, in0, xnumel, stream=stream0)
kernel.run(inout2, in0, xnumel, stream=stream0)
assert same(inout1, inout2, tol=0.001, equal_nan=True), (
"failed autotune with inplace kernel"
)
def test_sort_stride_issue(self):
# This minified testcase comes from detectron2_maskrcnn_r_50_fpn
# There was a false error from our size_assert code
@torch.compile(fullgraph=True)
def forward(pred_objectness_logits_3_: torch.Tensor):
sort_3 = pred_objectness_logits_3_.sort(descending=True, dim=1)
getitem_12 = sort_3[0]
return getitem_12
args = [((1, 100), (0, 1), torch.float16, "cuda", False)]
args = [
rand_strided(sh, st, dt, dev).requires_grad_(rg)
for (sh, st, dt, dev, rg) in args
]
result = forward(*args)
assert same(result, torch.sort(args[0], descending=True, dim=1)[0])
def test_scalar_triton_index(self):
# The indirect indexing via a scalar like below used to lead to
# bad triton code that made triton segfault when compiling.
# See https://github.com/pytorch/torchdynamo/issues/1515
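# "Indirect indexing" here means the gather index is read from tensor data at
# run time (the int64 `zero` tensor) instead of being a compile-time affine
# expression, so the generated kernel roughly does tmp = tl.load(zero_ptr + i)
# followed by a dependent tl.load(a_ptr + tmp); that dependent-load pattern is
# what used to make Triton segfault.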
def fn(a):
zero = torch.zeros((16,), device=a.device, dtype=torch.int64)
return (a[zero],)
a = torch.randn((8,), dtype=torch.float32, device="cuda")
fn_optimized = torch.compile(fn, backend="inductor")
assert same(fn(a), fn_optimized(a))
def test_indirect_indexing_dense_mask(self):
def fn(x, y):
ne = torch.ops.aten.ne.Scalar(x, 1)
sum_1 = torch.ops.aten.sum.dim_IntList(ne, [1])
sub = torch.ops.aten.sub.Tensor(sum_1, 1)
unsqueeze = torch.ops.aten.unsqueeze.default(sub, -1)
gather = torch.ops.aten.gather.default(x, 1, unsqueeze)
squeeze = torch.ops.aten.squeeze.default(gather)
out = torch.ops.aten.multiply(y, squeeze)
return (out,)
a = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
b = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
fn_optimized = torch.compile(fn, backend="inductor")
assert same(fn(a, b), fn_optimized(a, b))
def test_simplify_dims(self):
def fn(a):
return (a + 1,)
self.common(fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],))
@config.patch(permute_fusion=True)
def test_permute_fusion(self):
class Repro(torch.nn.Module):
def forward(self, view, reshape_2):
permute = view.permute(0, 2, 1)
view = None
reshape = torch.reshape(permute, (-1, 642))
bmm = torch.bmm(permute, reshape_2)
return (bmm,)
args = [
((1024, 642, 160), (102720, 160, 1), torch.float32, "cuda", True),
((1024, 642, 20), (12840, 20, 1), torch.float32, "cuda", True),
]
args = [
rand_strided(sh, st, dt, dev).requires_grad_(rg)
for (sh, st, dt, dev, rg) in args
]
mod = Repro()
opt_mod = torch.compile(mod, backend="inductor")
ref = mod(*args)
res = opt_mod(*args)
self.assertTrue(same(ref, res))
@config.patch({"triton.autotune_pointwise": True})
def test_inplace_add_alpha_autotune(self):
def fn(x, y):
aten.add_.Tensor(x, y, alpha=0.55)
return (x,)
x1 = torch.zeros(2, 3, 4, 10, device="cuda")
x2 = torch.zeros(2, 3, 4, 10, device="cuda")
x3 = torch.zeros(2, 3, 4, 10, device="cuda")
y = torch.randn(2, 3, 4, 10, device="cuda").to(
memory_format=torch.channels_last
)
fn_fx = make_fx(fn)(x1, y)
fn_compiled = compile_fx_inner(fn_fx, [x1, y])
fn(x2, y)
fn_compiled([x3, y])
assert same(x2, x3)
@config.patch({"triton.autotune_pointwise": True})
def test_inplace_buffer_autotune(self):
def foo(x, y, z):
a = x @ y
return a.unsqueeze(0).unsqueeze(0) + z
x = torch.zeros(5, 5, device="cuda")
y = torch.zeros(5, 5, device="cuda")
z = torch.zeros(1, 1, 5, 5, device="cuda").to(memory_format=torch.channels_last)
self.common(
foo,
(x, y, z),
check_lowp=False,
)
def test_memory_history_inductor(self):
def called_inside_compile(x, w, b):
a = x @ w + b
return torch.sigmoid(a)
@torch.compile
def fn(x, w, b):
x = called_inside_compile(x, w, b)
return called_inside_compile(x, w, b)
w = torch.rand(3, 3, device="cuda")
b = torch.rand(3, device="cuda")
x = torch.rand(3, device="cuda")
try:
torch.cuda.memory.empty_cache()
torch.cuda.memory._record_memory_history(True)
r = fn(x, w, b)
finally:
torch.cuda.memory._record_memory_history(False)
snapshot = str(torch.cuda.memory._snapshot())
self.assertTrue("called_inside_compile" in snapshot)
def test_negative_arange_dynamic_shapes(self):
# Repro from alibi relative encodings
def sign(x):
return (x > 0) - (x < 0)
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
nheads = 16
start = math.log2(0.5)
end = math.log2(1 / (2**8))
self.scales = nn.Buffer(
2
** torch.arange(
start,
end + 1e-6 * sign(end - start),
(end - start) / (nheads - 1),
).view(1, nheads, 1, 1),
)
self.emb = nn.Embedding(1024, 256)
self.dec_layer = nn.TransformerDecoderLayer(
256, 16, 512, batch_first=True, norm_first=True
)
self.head = nn.Linear(256, 1024)
def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor):
padmask = dec_in == 0
dec_mask = padmask.unsqueeze(-1) == padmask.unsqueeze(-2)
dec_mask = dec_mask.to(dtype=torch.float32)
dec_mask = dec_mask.tril(diagonal=0).cuda()
q_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
k_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
rel_pos = k_pos[None, :] - q_pos[:, None]
values = rel_pos.abs().neg().unsqueeze(0).unsqueeze(0)
dec_bias = values * self.scales
dec_bias.tril_(diagonal=0)
dec_mask = dec_mask + dec_bias[0]
out = self.emb(dec_in)
out = self.dec_layer(out, enc_out, tgt_mask=dec_mask)
return self.head(out)
mod = Repro().cuda()
opt_mod = torch.compile(mod, backend="inductor", dynamic=True)
mod.eval()
opt_mod.eval()
enc_out = torch.rand(1, 512, 256).cuda()
dec_inputs = [
torch.randint(0, 512, (1, i + 1), dtype=torch.long).cuda() for i in range(8)
]
for dec_inp in dec_inputs:
assert same_two_models(mod, opt_mod, [enc_out, dec_inp], only_fwd=True), (
"Inductor with dynamic shapes failed"
)
def test_issue97695_1input(self):
def fn(arg3_1, relu, permute_1):
addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
cat_2 = torch.ops.aten.cat.default([addmm_1], 1)
return (cat_2,)
args = [
((96,), (1,), torch.float32, "cuda"),
((10, 256), (256, 1), torch.float32, "cuda"),
((256, 96), (1, 256), torch.float32, "cuda"),
]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
correct = fn(*args)
mod = make_fx(fn, tracing_mode="real")(*args)
compiled = compile_fx_inner(mod, args)
ref = compiled(list(args))
assert same(ref, correct)
ref = torch.compile(fn, fullgraph=True)(*args)
assert same(ref, correct)
def test_issue_103924(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.temperature = 1
self.layer = torch.nn.Softmax(dim=1)
def forward(self, x):
n_samples, _ = x.shape
y = 1.0 * torch.ones(n_samples, dtype=x.dtype, device=x.device)
inp = x / y[..., None]
return self.layer(inp)
x = torch.rand([4, 4], device="cuda")
m = MyModule()
opt_m = torch.compile(backend="inductor")(m)
self.assertEqual(opt_m(x), m(x))
def test_issue97695_2input(self):
def fn(arg3_1, arg3_2, relu, permute_1):
addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
addmm_2 = torch.ops.aten.addmm.default(arg3_2, relu, permute_1)
cat_2 = torch.ops.aten.cat.default([addmm_1, addmm_2], 1)
return (cat_2,)
args = [
((96,), (1,), torch.float32, "cuda"),
((96,), (1,), torch.float32, "cuda"),
((10, 256), (256, 1), torch.float32, "cuda"),
((256, 96), (1, 256), torch.float32, "cuda"),
]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
correct = fn(*args)
ref = torch.compile(fn, fullgraph=True)(*args)
assert same(ref, correct)
def test_scatter_index_not_wrapped(self):
src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], device=self.device)
index = torch.tensor([0, 1, 0, 1, 2, 0], device=self.device)
input = torch.tensor([1.0, 2.0, 3.0, 4.0], device=self.device)
compiled_sr = torch.compile(torch.scatter_reduce)
input_orig = input.clone()
out, code = run_and_get_code(compiled_sr, input, 0, index, src, "sum")
# tmp0: negative indices are not wrapped around; the range is guarded by a device assert instead
FileCheck().check("tl.device_assert(((0 <= tmp0) & (tmp0 < 4))").check_next(
"atomic_add"
).run(code[0])
self.assertEqual(
out, torch.scatter_reduce(input_orig.clone(), 0, index, src, "sum")
)
def test_normalize_norm_leq_one(self):
def fn(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.normalize(x, dim=-1)
inp = torch.tensor([[3.799999, 0.0, 0.0]], device="cuda", dtype=torch.float32)
compiled = torch.compile(fn, backend="inductor", fullgraph=True)
out = compiled(inp)
norm = out.norm(dim=-1)
self.assertTrue(
torch.all(norm <= 1.0), f"expected norm <= 1.0 but got {norm.item()}"
)
def test_libdevice_routing(self):
def foo(x):
return x.exp()
inp = torch.ones(64, device="cuda").to(torch.float64)
out, code = run_and_get_code(torch.compile(foo), inp)
FileCheck().check("libdevice.exp").run(code[0])
self.assertEqual(foo(inp), out)
inp = inp.to(torch.float)
out, code = run_and_get_code(torch.compile(foo), inp)
FileCheck().check_not("tl_math.exp").check("libdevice.exp").run(code[0])
self.assertEqual(foo(inp), out)
def foo(x):
return x.sigmoid()
inp = torch.ones(64, device="cuda").to(torch.float64)
out, code = run_and_get_code(torch.compile(foo), inp)
FileCheck().check("libdevice.exp").run(code[0])
self.assertEqual(foo(inp), out)
def test_uint_view_copy(self):
@torch.compile
def view_copy(target, source):
assert target.dtype == torch.bfloat16
assert source.dtype == torch.uint16
target.view(torch.uint16).copy_(source)
target = torch.ones(1024, dtype=torch.bfloat16, device="cuda")
source = torch.full_like(target, 4, dtype=torch.uint16)
out = target.view(torch.uint16).copy_(source).clone()
view_copy(target, source)
self.assertEqual(out, target.view(torch.uint16))
def test_embedding_var_mean(self):
def forward(arg0_1):
full = torch.ops.aten.full.default(
[1, 2048],
1,
dtype=torch.float32,
layout=torch.strided,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
convert_element_type_1 = torch.ops.prims.convert_element_type.default(
full, torch.int64
)
cumsum = torch.ops.aten.cumsum.default(convert_element_type_1, 1)
mul = torch.ops.aten.mul.Tensor(cumsum, convert_element_type_1)
sub_1 = torch.ops.aten.sub.Tensor(mul, 1)
slice_5 = torch.ops.aten.slice.Tensor(sub_1, 0, 0, 9223372036854775807)
slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 0, 9223372036854775807)
add_2 = torch.ops.aten.add.Tensor(slice_6, 2)
embedding_1 = torch.ops.aten.embedding.default(arg0_1, add_2)
var_mean = torch.ops.aten.var_mean.correction(
embedding_1, [2], correction=0, keepdim=True
)
return [var_mean[0], var_mean[1], add_2]
emb = torch.randn([2050, 768], device="cuda")
gm = make_fx(forward)(emb)
opt = torch._inductor.compile_fx.compile_fx_inner(gm, [emb])
opt([emb])
torch.cuda.synchronize()
def test_deterministic_algorithms(self):
N = 10000
@torch.compile
def fn(idx, values):
x = torch.zeros(1, device="cuda")
x[idx] += values
return x
idx = torch.zeros(N, dtype=torch.int64, device="cuda")
values = torch.randn(N, device="cuda")
r0 = fn(idx, values)
with DeterministicGuard(True):
r1 = fn(idx, values)
for _ in range(10):
rn = fn(idx, values)
self.assertEqual(r1, rn, atol=0, rtol=0)
# https://github.com/pytorch/pytorch/issues/96406
def test_linear_cpu_input(self):
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(4, 4)
def forward(self, data):
data = data.to("cuda")
return self.linear(data)
mod = Model().cuda().eval()
with torch.no_grad():
self.common(mod, (torch.randn(4, 4),))
@config.patch({"fallback_random": True, "triton.cudagraphs": True})
def test_xlnet_lm_stride_repro(self):
class Repro(nn.Module):
def __init__(self) -> None:
super().__init__()
self.dropout = nn.Dropout(p=0.1, inplace=False)
def forward(self, x):
y = torch._C._nn.gelu(x)
return self.dropout(y)
mod = Repro()
x = torch.randn((512, 1, 4096), requires_grad=True, device="cuda")
y = torch.compile(mod)(x)
# Inductor claims the output layout of gelu's saved variable for
# backwards will be (4096, 4096, 1) but in actuality it is (4096,
# 2097152, 1). Fortunately this doesn't actually matter in practice.
y.sum().backward()
def test_lookup_seed_backward(self):
@torch.compile(fullgraph=True)
def forward(inductor_seeds, mul_4, view_15):
inductor_lookup_seed_2 = torch.ops.prims.inductor_lookup_seed.default(
inductor_seeds, 2
)
inductor_random_2 = torch.ops.prims.inductor_random.default(
[2, 512, 768], inductor_lookup_seed_2, "rand"
)
gt_2 = torch.ops.aten.gt.Scalar(inductor_random_2, 0.1)
mul_7 = torch.ops.aten.mul.Tensor(gt_2, view_15)
mul_8 = torch.ops.aten.mul.Tensor(mul_7, 1.1111111111111112)
add_5 = torch.ops.aten.add.Tensor(mul_8, mul_4)
var_mean_1 = torch.ops.aten.var_mean.correction(
add_5, [2], correction=0, keepdim=True
)
getitem_3 = var_mean_1[1]
sub_3 = torch.ops.aten.sub.Tensor(add_5, getitem_3)
return (sub_3,)
buf0 = torch.zeros((37,), dtype=torch.int64, device="cuda")
buf1 = torch.zeros((2, 512, 768), device="cuda")
buf2 = torch.zeros((2, 512, 768), device="cuda")
forward(buf0, buf1, buf2)
def test_issue100806(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(10, 20)
self.linear2 = torch.nn.Linear(20, 30)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = torch.cat((x, x), dim=1)
x = x.view(-1, 2, 30)
x = x[:, 1, :]
x = self.relu(x)
return x
device = "cuda"
batch_size = 2
x = torch.randn(batch_size, 10).to(device)
func = Model().to(device)
with torch.no_grad():
func.train(False)
jit_func = torch.compile(func)
res1 = func(x)
res2 = jit_func(x)
self.assertEqual(res1, res2)
def test_issue103481(self):
def fn(x, y):
# NOTE: using 6 dimensions is important! the failure does not reproduce with 5 dimensions
mean = torch.mean(x, [2, 3, 4, 5], keepdim=True)
add = mean + y
return add
x = torch.rand(4, 4, 4, 4, 4, 4, device="cuda")
y = torch.rand((), device="cuda")
expect = fn(x, y)
opt_fn = torch.compile(fn)
actual = opt_fn(x, y)
self.assertEqual(expect, actual)
@config.patch({"triton.dense_indexing": True})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_bucketize_dynamic_dense(self):
"""
Make sure that ops.bucketize() can handle dense_indexing, which previously
caused issues due to incorrect handling of the size of offsets.
"""
def fn(values, offsets):
return torch.bucketize(values, offsets)
values = torch.rand((64, 64), device="cuda")
offsets = torch.tensor([0.05, 0.1, 0.5, 0.8, 0.85, 0.95], device="cuda")
expect = fn(values, offsets)
opt_fn = torch.compile(fn, dynamic=True)
actual = opt_fn(values, offsets)
self.assertEqual(expect, actual)
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
@config.patch(
{
"max_autotune_gemm_backends": "TRITON",
"triton.disallow_failing_autotune_kernels_TESTING_ONLY": True,
"compile_threads": 1,
}
)
def test_bucketize_epilogue(self):
"""
See https://github.com/pytorch/pytorch/issues/148764.
Make sure that when torch.bucketize appears as an epilogue, the codegen is valid.
Note: during autotuning, there is also the option to _not_ do the fusion. With
standard configs the fused kernel would fail during autotuning, a non-fused
kernel would be selected instead, and the test would still pass (Inductor would
merely log some errors). We therefore set
disallow_failing_autotune_kernels_TESTING_ONLY=True so the autotuner does not
swallow failures, and compile_threads=1 so that compile failures aren't caught
by the async runner infra.
"""
def fn(x: torch.Tensor, y: torch.Tensor, buckets: torch.Tensor) -> torch.Tensor:
z = torch.mm(x, y)
return torch.bucketize(z, buckets)
buckets = torch.arange(-100, 100, 10, device="cuda")
x = torch.randn(64, 64, device="cuda").clamp(-99, 99)
y = torch.randn(64, 64, device="cuda").clamp(-99, 99)
opt_fn = torch.compile(fn, mode="max-autotune")
expected = fn(x, y, buckets)
actual = opt_fn(x, y, buckets)
self.assertEqual(expected, actual)
def test_float64_constants(self):
def fn():
# NOTE: tensors of all the same value are constant folded, so we
# need a tensor with two distinct values
a = torch.tensor([1 / 10, 2 / 10], dtype=torch.float64, device="cuda")
return a * 2e50
cfn = torch.compile(fn)
expect = fn()
actual = cfn()
self.assertEqual(expect, actual, atol=0, rtol=0)
def test_issue104759(self):
def fn(arg7_1, add_1, permute_2, select_scatter, slice_8):
slice_scatter_4 = torch.ops.aten.slice_scatter.default(
permute_2, select_scatter, 0, 1, 9223372036854775807
)
permute_3 = torch.ops.aten.permute.default(slice_scatter_4, [1, 3, 0, 2, 4])
view_6 = torch.ops.aten.view.default(permute_3, [1, 1000, 48])
view_7 = torch.ops.aten.view.default(view_6, [1000, 48])
view_8 = torch.ops.aten.view.default(view_7, [1, 1000, 48])
view_9 = torch.ops.aten.view.default(view_8, [1, 1000, 3, 4, 4])
permute_4 = torch.ops.aten.permute.default(view_9, [2, 0, 3, 1, 4])
slice_7 = torch.ops.aten.slice.Tensor(permute_4, 0, 1, 9223372036854775807)
slice_scatter_5 = torch.ops.aten.slice_scatter.default(
slice_8, slice_7, 4, 0, 9223372036854775807
)
slice_scatter_6 = torch.ops.aten.slice_scatter.default(
arg7_1, slice_scatter_5, 3, 0, 1000
)
mul_8 = torch.ops.aten.mul.Scalar(add_1, 0.7071067811865476)
slice_9 = torch.ops.aten.slice.Tensor(slice_scatter_6, 3, 0, 1000)
slice_10 = torch.ops.aten.slice.Tensor(slice_9, 4, 0, 9223372036854775807)
select_2 = torch.ops.aten.select.int(slice_10, 0, 0)
permute_5 = torch.ops.aten.permute.default(select_2, [0, 1, 3, 2])
mul_9 = torch.ops.aten.mul.Scalar(permute_5, 0.7071067811865476)
expand = torch.ops.aten.expand.default(mul_8, [1, 4, 1000, 4])
view_10 = torch.ops.aten.view.default(expand, [4, 1000, 4])
expand_1 = torch.ops.aten.expand.default(mul_9, [1, 4, 4, 1000])
view_11 = torch.ops.aten.view.default(expand_1, [4, 4, 1000])
bmm = torch.ops.aten.bmm.default(view_10, view_11)
return (bmm,)
args = []
args.append(torch.randn((2, 1, 4, 1200, 4), dtype=torch.float16, device="cuda"))
args.append(
rand_strided(
(1, 4, 1000, 4), (16000, 4, 16, 1), dtype=torch.float16, device="cuda"
)
)
args.append(
rand_strided(
(3, 1, 4, 1000, 4),
(16, 48000, 4, 48, 1),
dtype=torch.float16,
device="cuda",
)
)
args.append(
rand_strided(
(2, 1, 4, 1000, 4),
(16, 48000, 4, 48, 1),
dtype=torch.float16,
device="cuda",
)
)
args.append(
rand_strided(
(2, 1, 4, 1000, 4),
(19200, 19200, 4800, 4, 1),
dtype=torch.float16,
device="cuda",
)
)
correct = fn(*args)
mod = make_fx(fn, tracing_mode="real")(*args)
compiled = compile_fx_inner(mod, args)
ref = compiled(list(args))
assert same(ref, correct)
@config.patch({"triton.cudagraphs": True})
def test_index_put_inplace_cudagraph(self):
def fn(x, y, z):
x = torch.zeros_like(x)
return x.index_put_([y], z, True)
x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
z = torch.ones((512, 512), device="cuda", dtype=torch.bool)
opt_fn = torch.compile(fn, backend="inductor")
ref = fn(x, y, z)
# run it twice to test cuda graph issue
res = opt_fn(x, y, z)
res = opt_fn(x, y, z)
self.assertEqual(ref, res)
@config.patch({"triton.cudagraphs": True})
@config.patch({"fx_graph_cache": True})
def test_index_put_cudagraph(self):
for _ in range(2):
def fn(x, y, z):
x = torch.zeros_like(x)
return x.index_put([y], z, True)
x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
z = torch.ones((512, 512), device="cuda", dtype=torch.bool)
opt_fn = torch.compile(fn, backend="inductor")
ref = fn(x, y, z)
# run it twice to test cuda graph issue
res = opt_fn(x, y, z)
res = opt_fn(x, y, z)
self.assertEqual(ref, res)
torch._dynamo.reset()
gc.collect()
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
)
def test_flash_attention_dynamic(self):
class Model(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.q = nn.Linear(1024, 1024)
self.k = nn.Linear(1024, 1024)
self.v = nn.Linear(1024, 1024)
def forward(self, x):
batch_size, seq_len, _ = x.size()
queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
attn = F.scaled_dot_product_attention(
queries,
keys,
values,
)
return attn
cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
model = Model().cuda().half()
model = torch.compile(model, backend=cnts, dynamic=True)
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False,
enable_cudnn=False,
):
input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)
out1 = model(input1)
out2 = model(input2)
out3 = model(input3)
self.assertEqual(cnts.frame_count, 2)
@config.patch({"triton.cudagraphs": True})
def test_index_put_no_fallback_cudagraph(self):
def fn(x, y, z):
x = torch.zeros_like(x)
return x.index_put([y], z, True)
x = torch.zeros((512, 512), device="cuda", dtype=torch.int32)
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
z = torch.ones((512, 512), device="cuda", dtype=torch.int32)
opt_fn = torch.compile(fn, backend="inductor")
ref = fn(x, y, z)
# run it twice to test cuda graph issue
res = opt_fn(x, y, z)
res = opt_fn(x, y, z)
self.assertEqual(ref, res)
@torch._inductor.config.patch(emulate_precision_casts=True)
def test_emulate_precision_casts_norm_rounding(self):
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
x = torch.rand(1000, device="cuda", dtype=torch.bfloat16)
scalar = torch.rand([], device="cuda", dtype=torch.float32)
def fn(inp, scale):
y = inp.norm()
return y, y + scale
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)
expected = fn(x, scalar)
actual = opt_fn(x, scalar)
self.assertEqual(expected, actual)
@torch._inductor.config.patch(emulate_precision_casts=True)
def test_emulate_precision_casts_min_pow_chain(self):
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
with dynamo_config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
):
arg0 = torch.rand(
[383, 55, 2, 3],
dtype=torch.float16,
device="cuda",
requires_grad=True,
)
arg1 = torch.rand(
[383, 55], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg2 = torch.rand(
[383, 55], dtype=torch.float32, device="cuda", requires_grad=True
)
arg3 = torch.rand(
[383, 55], dtype=torch.float32, device="cuda", requires_grad=True
)
def fn(a0, a1, a2, a3):
t1 = a0.min(dim=2).values
t2 = t1.sum(dim=2)
t6 = ((((a1) - a2) - a3) - a3) - a3
t7 = t6 + t2
t8 = torch.pow(torch.pow(torch.pow(torch.pow(t2, t7), t7), t7), t7)
return t7, t8
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)
eager_out = fn(arg0, arg1, arg2, arg3)
compiled_args = [
arg0.clone().detach().requires_grad_(True),
arg1.clone().detach().requires_grad_(True),
arg2.clone().detach().requires_grad_(True),
arg3.clone().detach().requires_grad_(True),
]
compiled_out = opt_fn(*compiled_args)
for eager_tensor, compiled_tensor in zip(eager_out, compiled_out):
torch.testing.assert_close(
eager_tensor,
compiled_tensor,
rtol=1e-3,
atol=1e-3,
)
@torch._inductor.config.patch(emulate_precision_casts=True)
def test_emulate_precision_casts_mean_ratio_chain(self):
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
with dynamo_config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
):
arg0 = torch.rand(
[125070], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg1 = torch.rand(
[1895, 3, 11], dtype=torch.float16, device="cuda", requires_grad=True
)
arg2 = torch.rand(
[1895, 3, 11], dtype=torch.float32, device="cuda", requires_grad=True
)
arg3 = torch.rand(
[1895, 3, 11], dtype=torch.float32, device="cuda", requires_grad=True
)
arg4 = torch.rand(
[1895, 3, 11], dtype=torch.float32, device="cuda", requires_grad=True
)
arg5 = torch.rand(
[5, 379, 165], dtype=torch.float32, device="cuda", requires_grad=True
)
def fn(a0, a1, a2, a3, a4, a5):
t2 = a0.view(379, 165, 2).mean(dim=2)
t7 = ((((a1) - a2) - a3) - a2) - a4
t8 = t7.view(379, 165)
t11 = torch.nn.functional.relu(a5).mean(dim=0)
t12 = t2 - t11
t13 = (((t2) / t8) / t11) / t12
return t13
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)
eager_out = fn(arg0, arg1, arg2, arg3, arg4, arg5)
compiled_args = [
tensor.clone().detach().requires_grad_(True)
for tensor in (arg0, arg1, arg2, arg3, arg4, arg5)
]
compiled_out = opt_fn(*compiled_args)
torch.testing.assert_close(
eager_out,
compiled_out,
rtol=5e-3,
atol=1e-1,
)
@torch._inductor.config.patch(emulate_precision_casts=True)
def test_dont_inplace_disjoint_accesses(self):
# TODO: we would not need the mms if we could annotate donated buffers.
def forward( # noqa: F821, F722
arg0_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722
arg1_1: "bf16[8, 4096, 2048][8388608, 2048, 1]cuda:0", # noqa: F821, F722
arg2_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722
arg3_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722
arg4_1: "bf16[2048][1]cuda:0", # noqa: F821, F722
arg5_1: "bf16[2048][1]cuda:0", # noqa: F821, F722
arg6_1: "f32[4096, 128][128, 1]cuda:0", # noqa: F821, F722
arg7_1: "f32[4096, 128][128, 1]cuda:0", # noqa: F821, F722
):
permute = torch.ops.aten.permute.default(arg0_1, [1, 0])
arg0_1 = None
view = torch.ops.aten.view.default(arg1_1, [32768, 2048])
mm = torch.ops.aten.mm.default(view, permute)
view = permute = None
view_1 = torch.ops.aten.view.default(mm, [8, 4096, 2048])
mm = None
permute_1 = torch.ops.aten.permute.default(arg2_1, [1, 0])
arg2_1 = None
view_2 = torch.ops.aten.view.default(arg1_1, [32768, 2048])
mm_1 = torch.ops.aten.mm.default(view_2, permute_1)
view_2 = permute_1 = None
view_3 = torch.ops.aten.view.default(mm_1, [8, 4096, 2048])
mm_1 = None
permute_2 = torch.ops.aten.permute.default(arg3_1, [1, 0])
arg3_1 = None
view_4 = torch.ops.aten.view.default(arg1_1, [32768, 2048])
arg1_1 = None
mm_2 = torch.ops.aten.mm.default(view_4, permute_2)
view_4 = permute_2 = None
view_5 = torch.ops.aten.view.default(mm_2, [8, 4096, 2048])
mm_2 = None
convert_element_type_6 = torch.ops.prims.convert_element_type.default(
view_1, torch.float32
)
view_1 = None
pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_6, 2)
mean = torch.ops.aten.mean.dim(pow_1, [-1], True)
pow_1 = None
add = torch.ops.aten.add.Tensor(mean, 1e-06)
mean = None
rsqrt = torch.ops.aten.rsqrt.default(add)
add = None
mul = torch.ops.aten.mul.Tensor(convert_element_type_6, rsqrt)
convert_element_type_6 = rsqrt = None
convert_element_type_7 = torch.ops.prims.convert_element_type.default(
arg4_1, torch.float32
)
arg4_1 = None
mul_1 = torch.ops.aten.mul.Tensor(convert_element_type_7, mul)
convert_element_type_7 = mul = None
convert_element_type_8 = torch.ops.prims.convert_element_type.default(
mul_1, torch.bfloat16
)
mul_1 = None
convert_element_type_9 = torch.ops.prims.convert_element_type.default(
view_3, torch.float32
)
view_3 = None
pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_9, 2)
mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True)
pow_2 = None
add_1 = torch.ops.aten.add.Tensor(mean_1, 1e-06)
mean_1 = None
rsqrt_1 = torch.ops.aten.rsqrt.default(add_1)
add_1 = None
mul_2 = torch.ops.aten.mul.Tensor(convert_element_type_9, rsqrt_1)
convert_element_type_9 = rsqrt_1 = None
convert_element_type_10 = torch.ops.prims.convert_element_type.default(
arg5_1, torch.float32
)
arg5_1 = None
mul_3 = torch.ops.aten.mul.Tensor(convert_element_type_10, mul_2)
convert_element_type_10 = mul_2 = None
convert_element_type_11 = torch.ops.prims.convert_element_type.default(
mul_3, torch.bfloat16
)
mul_3 = None
view_6 = torch.ops.aten.view.default(
convert_element_type_8, [8, 4096, -1, 128]
)
convert_element_type_8 = None
view_7 = torch.ops.aten.view.default(
convert_element_type_11, [8, 4096, -1, 128]
)
convert_element_type_11 = None
view_8 = torch.ops.aten.view.default(view_5, [8, 4096, -1, 128])
view_5 = None
convert_element_type_12 = torch.ops.prims.convert_element_type.default(
view_6, torch.float32
)
view_6 = None
convert_element_type_13 = torch.ops.prims.convert_element_type.default(
view_7, torch.float32
)
view_7 = None
unsqueeze = torch.ops.aten.unsqueeze.default(arg6_1, 0)
unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
unsqueeze = None
unsqueeze_2 = torch.ops.aten.unsqueeze.default(arg7_1, 0)
unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2)
unsqueeze_2 = None
mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_12, unsqueeze_3)
unsqueeze_3 = None
view_9 = torch.ops.aten.view.default(
convert_element_type_12, [8, 4096, 16, 2, 64]
)
convert_element_type_12 = None
unbind = torch.ops.aten.unbind.int(view_9, -2)
view_9 = None
getitem = unbind[0]
getitem_1 = unbind[1]
unbind = None
neg = torch.ops.aten.neg.default(getitem_1)
getitem_1 = None
cat = torch.ops.aten.cat.default([neg, getitem], -1)
neg = getitem = None
mul_5 = torch.ops.aten.mul.Tensor(cat, unsqueeze_1)
cat = unsqueeze_1 = None
add_2 = torch.ops.aten.add.Tensor(mul_4, mul_5)
mul_4 = mul_5 = None
unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg6_1, 0)
arg6_1 = None
unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2)
unsqueeze_4 = None
unsqueeze_6 = torch.ops.aten.unsqueeze.default(arg7_1, 0)
arg7_1 = None
unsqueeze_7 = torch.ops.aten.unsqueeze.default(unsqueeze_6, 2)
unsqueeze_6 = None
mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_13, unsqueeze_7)
unsqueeze_7 = None
view_10 = torch.ops.aten.view.default(
convert_element_type_13, [8, 4096, 16, 2, 64]
)
convert_element_type_13 = None
unbind_1 = torch.ops.aten.unbind.int(view_10, -2)
view_10 = None
getitem_2 = unbind_1[0]
getitem_3 = unbind_1[1]
unbind_1 = None
neg_1 = torch.ops.aten.neg.default(getitem_3)
getitem_3 = None
cat_1 = torch.ops.aten.cat.default([neg_1, getitem_2], -1)
neg_1 = getitem_2 = None
mul_7 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5)
cat_1 = unsqueeze_5 = None
add_3 = torch.ops.aten.add.Tensor(mul_6, mul_7)
mul_6 = mul_7 = None
convert_element_type_14 = torch.ops.prims.convert_element_type.default(
add_2, torch.bfloat16
)
add_2 = None
convert_element_type_15 = torch.ops.prims.convert_element_type.default(
add_3, torch.bfloat16
)
add_3 = None
permute_3 = torch.ops.aten.permute.default(
convert_element_type_14, [0, 2, 1, 3]
)
convert_element_type_14 = None
permute_4 = torch.ops.aten.permute.default(
convert_element_type_15, [0, 2, 1, 3]
)
convert_element_type_15 = None
permute_5 = torch.ops.aten.permute.default(view_8, [0, 2, 1, 3])
view_8 = None
return (permute_3, permute_4, permute_5)
from torch._dynamo.debug_utils import aot_graph_input_parser
kwargs = aot_graph_input_parser(forward)
out, code = run_and_get_code(torch.compile(forward), **kwargs)
# ignore tiny values; prior to this fix the absolute error was ~28
self.assertEqual(forward(**kwargs), out, atol=0.01, rtol=2)
FileCheck().check_not("in_out").run(code[0])
# https://github.com/pytorch/pytorch/issues/104937
def test_linear_with_zero_infeature_size(self):
m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")
x = torch.rand(1, 1, 0, device="cuda")
expect = m(x)
opt_fn = torch.compile(m)
actual = opt_fn(x)
self.assertEqual(expect, actual)
@config.patch(fallback_random=True)
def test_multi_output_layout_fallback(self):
mod = nn.RReLU(lower=3.2350976, upper=8.4220314, inplace=True)
inp = torch.rand([4, 4]).cuda()
m = torch.compile(mod)
with freeze_rng_state():
o1 = m(inp.clone())
o2 = mod(inp.clone())
self.assertEqual(o1, o2)
def test_sorted_masks(self):
@torch.compile()
def foo(x, y):
return (x + y).sum(dim=1)
x = torch.rand([255, 255], device="cuda")
y = torch.rand([255, 255], device="cuda")
_, code = run_and_get_code(foo, x, y)
FileCheck().check("tl.load").check_same("r0_mask").check_same("xmask").run(
code[0]
)
def test_cat_int8_one_kernel(self):
@torch.compile()
def cat(inps):
return torch.cat(inps) + 1
for dtype in [torch.uint8, torch.int8]:
inps = [
torch.empty([256, 256], dtype=dtype, device="cuda") for _ in range(4)
]
out, code = run_and_get_code(cat, inps)
self.assertEqual(torch.cat(inps) + 1, out)
FileCheck().check_not("aten.cat.default(").check_count(
".run(", 1, exactly=True
).run(code[0])
@config.patch("triton.use_block_ptr", True)
def test_selecsls42b_misaligned_address(self):
# https://github.com/triton-lang/triton/issues/2836
@torch.compile(fullgraph=True)
def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3):
div = torch.ops.aten.div.Scalar(expand, 16)
where = torch.ops.aten.where.self(arg207_1, full, div)
convert_element_type_43 = torch.ops.prims.convert_element_type.default(
where, torch.float32
)
sum_2 = torch.ops.aten.sum.dim_IntList(convert_element_type_43, [0, 2, 3])
sub = torch.ops.aten.sub.Tensor(convert_element_type_40, arg208_1)
mul = torch.ops.aten.mul.Tensor(convert_element_type_43, sub)
sum_3 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3])
mul_1 = torch.ops.aten.mul.Tensor(sum_2, 0.0078125)
unsqueeze = torch.ops.aten.unsqueeze.default(mul_1, 0)
unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
unsqueeze_2 = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3)
mul_2 = torch.ops.aten.mul.Tensor(sum_3, 0.0078125)
mul_4 = torch.ops.aten.mul.Tensor(mul_2, mul_3)
unsqueeze_3 = torch.ops.aten.unsqueeze.default(mul_4, 0)
unsqueeze_4 = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2)
unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
mul_6 = torch.ops.aten.mul.Tensor(sub, unsqueeze_5)
sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_43, mul_6)
sub_2 = torch.ops.aten.sub.Tensor(sub_1, unsqueeze_2)
return (sub_2,)
args = [
torch.randn((8, 1024, 4, 4), device="cuda") > 0, # torch.bool tensor
torch.randn((1, 1024, 1, 1), device="cuda"),
torch.randn((8, 1024, 4, 4), device="cuda"),
torch.randn((8, 1024, 1, 1), dtype=torch.float16, device="cuda").expand(
(8, 1024, 4, 4)
),
torch.randn((), device="cuda"),
torch.randn((1024,), device="cuda"),
]
fn(*args)
torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address
def test_mutated_aligned_tensor(self):
t = torch.rand(4096, device="cuda", dtype=torch.float16)
def foo(x):
return x.add_(1)
foo_c = torch.compile(dynamic=False)(foo)
t_orig = t.clone()
# First invocation: assume alignment. Second invocation: copy to an aligned
# buffer, then apply the mutation back after the fn invocation.
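# Roughly why: t[:-1] starts at storage offset 0, so it keeps the allocation's
# alignment, while t[1:] starts one fp16 element (2 bytes) into the storage and
# is no longer aligned as inductor assumes, so inductor has to copy it into an
# aligned buffer and write the in-place add back to the original view afterwards.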
self.assertEqual(foo_c(t[:-1]), foo(t_orig[:-1]))
self.assertEqual(t, t_orig)
self.assertEqual(foo_c(t[1:]), foo(t_orig[1:]))
self.assertEqual(t, t_orig)
def test_non_commutative_scan_op(self):
from torch._higher_order_ops.associative_scan import associative_scan
a = torch.randn(1024, 8192, dtype=torch.float64, device="cuda")
b = torch.randn(1024, 8192, dtype=torch.float64, device="cuda")
def baseline(v, u):
A = []
A.append(b[:, 0])
for i in range(1, v.shape[1]):
A.append(a[:, i] * A[i - 1] + b[:, i])
return torch.stack(A, dim=1)
def combine_fn(i, j):
ia, ib = i
ja, jb = j
return ia * ja, ib * ja + jb
@torch.compile
def compiled_scan(a, b):
return associative_scan(combine_fn, (a, b), dim=-1)[1]
out1 = baseline(a, b)
out2 = compiled_scan(a, b)
self.assertEqual(out1, out2)
def test_dynamic_persistent_reductions(self):
@torch.compile(dynamic=True)
def inner_reduce(x):
assert x.shape[1] <= 1024
return x.sum(1)
a = torch.randn(50, 600, device="cuda")
out, code = run_and_get_code(inner_reduce, a)
self.assertEqual(inner_reduce(a), out)
self.assertTrue("for roffset" not in code)
@torch.compile(dynamic=True)
def outer_reduce(x):
assert x.shape[0] <= 64
return x.sum(0)
out, code = run_and_get_code(outer_reduce, a)
self.assertEqual(outer_reduce(a), out)
self.assertTrue("for roffset" not in code)
def test_scaled_dot_product_efficient_attention_backward(self):
from torch import nn, Tensor
class SelfAttention(nn.Module):
def __init__(
self,
num_attention_heads: int = 12,
hidden_size: int = 768,
attention_probs_dropout_prob: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_size = hidden_size // num_attention_heads
self.query = nn.Linear(hidden_size, hidden_size)
self.key = nn.Linear(hidden_size, hidden_size)
self.value = nn.Linear(hidden_size, hidden_size)
self.dropout_prob = attention_probs_dropout_prob
def transpose_for_scores(self, x: Tensor) -> Tensor:
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
return x.view(new_x_shape).permute(0, 2, 1, 3)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=False,
)
return attn_output
device = torch.device("cuda")
num_attention_heads = 8
hidden_size = 512
attention_probs_dropout_prob = 0.0
model = SelfAttention(
num_attention_heads=num_attention_heads,
hidden_size=hidden_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
).to(device)
model = torch.compile(model)
# runs without failure
batch_size = 8
length = 1
inputs_embeds = torch.randn(batch_size, length, hidden_size, device=device)
attention_mask = torch.ones(batch_size, 1, length, length, device=device)
attn_output = model(hidden_states=inputs_embeds, attention_mask=attention_mask)[
0
]
loss = attn_output.mean()
loss.backward()
def test_non_contiguous_unaligned_input_indices(self):
from torch._inductor.compile_fx import remove_unaligned_input_idxs
inputs = [torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")[1:]]
idxs = remove_unaligned_input_idxs(inputs, [1])
self.assertEqual(idxs, [])
inputs = [
torch.ones(2, 2, device="cuda"),
torch.ones(2, 2, device="cuda"),
torch.ones(2, 2, device="cuda")[1:],
]
idxs = remove_unaligned_input_idxs(inputs, [0, 2])
self.assertEqual(idxs, [0])
@config.patch("triton.cudagraphs", True)
def test_unused_cpu_input_cudagraphs(self):
def fn(x, y):
return x.sin().sin().sin().sin().cos() + 1
fx_graph = torch.fx.symbolic_trace(fn)
inp = [torch.randn(64, device="cuda"), torch.randn(64, device="cpu")]
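# The unused CPU input should not disable cudagraphs, and only the cuda device
# should be recorded on the lowered graph.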
compiled_fn, (graph,) = run_and_get_graph_lowering(
torch._inductor.compile, fx_graph, inp
)
self.assertEqual(graph.disable_cudagraphs_reason, None)
self.assertEqual(graph.device_types, {"cuda"})
self.assertEqual(compiled_fn(*inp), fn(*inp))
def test_epilogue_fusion_with_view(self):
class ToyModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.linear = torch.nn.Linear(262144, 100)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = x.view(x.size(0), -1)
return self.relu(self.linear(x))
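# The flattening view between the conv and the linear exercises epilogue fusion
# across a reshape under max-autotune.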
m = ToyModel().to(device="cuda:0")
input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
from torch._inductor.utils import fresh_cache
with fresh_cache():
cm = torch.compile(m, mode="max-autotune")
out = cm(input_tensor)
out2 = m(input_tensor)
self.assertEqual(out, out2, atol=1e-3, rtol=1e-3)
@config.patch("triton.cudagraphs", True)
def test_cpu_index(self):
@torch.compile(fullgraph=True)
def fn(x):
return x[torch.arange(32)]
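# The arange index is created inside the compiled function; it should not pull a
# CPU device into the graph or disable cudagraphs.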
result, (graph,) = run_and_get_graph_lowering(
fn, torch.randn(64, device="cuda")
)
self.assertEqual(graph.disable_cudagraphs_reason, None)
self.assertEqual(graph.device_types, {"cuda"})
inp = torch.randn(64, device="cuda", requires_grad=True)
result, (graph,) = run_and_get_graph_lowering(fn, inp)
self.assertEqual(graph.disable_cudagraphs_reason, None)
self.assertEqual(graph.device_types, {"cuda"})
result, (graph,) = run_and_get_graph_lowering(lambda: result.sum().backward())
self.assertEqual(graph.disable_cudagraphs_reason, None)
self.assertEqual(graph.device_types, {"cuda"})
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
def test_triton_interpret(self):
import subprocess
script = """
import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
@torch.compile()
def foo(x):
return x + 1
# interpreter mode somehow gives different results; still, check that it doesn't error
foo(torch.rand([256], device="cuda"))
"""
subprocess.run([sys.executable, "-c", script], check=True)
def test_reflection_pad_loop_order(self):
def fn(x, y):
a = torch.nn.functional.pad(x, (5, 5, 5, 5), mode="reflect")
b = torch.nn.functional.pad(y, (5, 5, 5, 5), mode="reflect")
return a + b
cfn = torch.compile(fn)
a = torch.rand((10, 10, 10), device="cuda")
b = torch.rand((10, 10, 10), device="cuda")
expect = fn(a, b)
actual, code = run_and_get_code(cfn, a, b)
self.assertEqual(expect, actual)
# Expect the generated code to iterate in contiguous order and not be tiled
lines = code[0].split("\n")
start = lines.index("@triton.jit")
kernel_code = "\n".join(lines[start : start + 14])
self.assertExpectedInline(
kernel_code,
"""\
@triton.jit
def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = (xindex % 20)
x1 = ((xindex // 20) % 20)
x2 = xindex // 400
x3 = xindex
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
tmp2 = tmp0 + tmp1
tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950
)
@skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
def test_int64_index_intermediate(self):
def foo(inp):
view_23 = torch.ops.aten.view.default(inp, [-1, 8192, 8192])
split_1 = torch.ops.aten.split.Tensor(view_23, 1024, 1)
view_23 = None
getitem_17 = split_1[0]
getitem_18 = split_1[1]
getitem_19 = split_1[2]
getitem_20 = split_1[3]
getitem_21 = split_1[4]
getitem_22 = split_1[5]
getitem_23 = split_1[6]
getitem_24 = split_1[7]
split_1 = None
cat_1 = torch.ops.aten.cat.default(
[
getitem_17,
getitem_18,
getitem_19,
getitem_20,
getitem_21,
getitem_22,
getitem_23,
getitem_24,
]
)
getitem_17 = getitem_18 = getitem_19 = getitem_20 = getitem_21 = (
getitem_22
) = getitem_23 = getitem_24 = None
return cat_1
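# Large split/cat whose intermediate index computations exercise int64 indexing;
# run with and without dynamic shapes.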
for mark_dynamic in [False, True]:
inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device="cuda")
if mark_dynamic:
torch._dynamo.mark_dynamic(inp, 0)
foo_c = torch.compile(foo)
torch.testing.assert_close(foo(inp), foo_c(inp))
@skipCUDAIf(
not SM90OrLater, "uses bfloat16 atomic add instrs, which require SM >= 90"
)
def test_float8_e8m0fnu(self):
device = "cuda"
dtype = torch.float8_e8m0fnu
hp_dtype = torch.float32 # and torch.bfloat16
def foo(x0):
x1 = x0.to(dtype)
x2 = x1.to(hp_dtype)
return x2
x0 = torch.randn(16, 16, device=device, dtype=hp_dtype)
foo_c = torch.compile(foo, backend="inductor", fullgraph=True)
with torch.no_grad():
y_c = foo_c(x0)
self.assertEqual(foo(x0), y_c)
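# Second case: add one to a uint8 tensor, bitcast-view it as float8_e8m0fnu, and
# reshape; the generated code should not fall back to aten.reshape.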
dtype = torch.float8_e8m0fnu
def foo(x0):
x1 = x0 + 1
x2 = x1.view(dtype).view([16 * 16])
return x2
x0 = torch.randint(0, 255, (16, 16), device=device, dtype=torch.uint8)
foo_c = torch.compile(foo, backend="inductor", fullgraph=True)
with torch.no_grad():
result, code = run_and_get_code(foo_c, x0)
FileCheck().check("call").check_not("torch.ops.aten.reshape.default(").run(
code[0]
)
self.assertEqual(foo(x0), result)
@unittest.skipIf(
not config.is_fbcode(),
"bfloat16 atomic add is only supported in fbcode today #97016",
)
@skipCUDAIf(
not SM90OrLater, "uses bfloat16 atomic add instrs, which require SM >= 90"
)
def test_atomic_add_bfloat16(self):
def f(x, y):
return torch.index_select(x, 0, y)
x = torch.randn(
2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
y = torch.ones(713268, dtype=torch.int64, device="cuda")
x_ref = x.clone().detach().requires_grad_(True)
y_ref = y.clone().detach()
out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
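# The backward of index_select is a scatter-add; with bf16 atomics available,
# the backward kernel should use tl.atomic_add.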
fc = FileCheck()
fc.check("tl.atomic_add")
fc.run(bw_code)
self.assertEqual(f(x_ref, y_ref), out)
def test_red_dtype_mismatch(self):
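# Run the same graph once with persistent reductions (the default) and once with
# them disabled.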
for per in (True, False):
torch._dynamo.reset()
if not per:
torch._inductor.config.triton.persistent_reductions = False
def f(arg0_1, arg1_1):
embedding = torch.ops.aten.embedding.default(arg1_1, arg0_1)
view = torch.ops.aten.view.default(embedding, [64, 3072])
unsqueeze = torch.ops.aten.unsqueeze.default(view, 0)
expand = torch.ops.aten.expand.default(unsqueeze, [576, -1, -1])
view_1 = torch.ops.aten.view.default(expand, [2, 8, 36, 64, 3072])
permute = torch.ops.aten.permute.default(view_1, [0, 1, 3, 2, 4])
clone = torch.ops.aten.clone.default(
permute, memory_format=torch.contiguous_format
)
view_2 = torch.ops.aten.view.default(clone, [2, 18432, 3072])
iota = torch.ops.prims.iota.default(
36,
start=0,
step=1,
dtype=torch.int64,
device="cuda",
requires_grad=False,
)
view_3 = torch.ops.aten.view.default(iota, [1, 36])
max_1 = torch.ops.aten.max.default(view_3)
return (max_1,)
x = torch.ones(1, 64, device="cuda", dtype=torch.int64)
y = torch.randn(64, 3072, device="cuda", dtype=torch.bfloat16)
out = f(x, y)
self.assertEqual(torch.compile(f)(x, y), out)
@unittest.skipIf(
not config.is_fbcode(),
"bfloat16 atomic add is only supported in fbcode today #97016",
)
@skipCUDAIf(
not SM90OrLater, "uses bfloat16 atomic add instrs, which require SM >= 90"
)
@config.patch({"bfloat16_atomic_adds_enabled": False})
def test_atomic_add_bfloat16_config(self):
def f(x, y):
return torch.index_select(x, 0, y)
x = torch.randn(
2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
y = torch.ones(713268, dtype=torch.int64, device="cuda")
x_ref = x.clone().detach().requires_grad_(True)
y_ref = y.clone().detach()
out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
fc = FileCheck()
fc.check_not("tl.atomic_add")
fc.run(bw_code)
self.assertEqual(f(x_ref, y_ref), out)
@skipCUDAIf(
not SM90OrLater, "uses bfloat16 atomic add instrs, which require SM >= 90"
)
@unittest.skipIf(
config.is_fbcode(),
"bfloat16 atomic add is supported in fbcode, so we won't fallback",
)
def test_index_add_fallback(self):
def f(x, y):
return torch.index_select(x, 0, y)
x = torch.randn(
2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
y = torch.ones(713268, dtype=torch.int64, device="cuda")
x_ref = x.clone().detach().requires_grad_(True)
y_ref = y.clone().detach()
out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
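# Without bf16 atomic-add support, the index_select backward should fall back to
# aten.index_add rather than a Triton atomic add.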
fc = FileCheck()
fc.check("aten.index_add")
fc.run(bw_code)
self.assertEqual(f(x_ref, y_ref), out)
@requires_multigpu()
def test_not_initializing_wrong_device(self):
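# Compiling and running on cuda:1 should not allocate memory on cuda:0; compare
# cuda:0's peak active memory before and after.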
device_stats = torch.cuda.memory_stats("cuda:0")
@torch.compile()
def foo(x, y):
return x @ y
x = torch.rand([256, 256], device="cuda:1", requires_grad=True)
y = torch.rand([256, 256], device="cuda:1", requires_grad=True)
foo(x, y).sum().backward()
device_stats2 = torch.cuda.memory_stats("cuda:0")
self.assertTrue(
device_stats2["active.all.peak"] <= device_stats["active.all.peak"]
)
@config.patch(
{
"triton.prefer_nd_tiling": True,
"triton.max_tiles": 3,
}
)
def test_3d_tiling(self):
full_size, view_size, num_block_pointers, num_tiles = (
(5, 5, 5, 5, 5),
(3, 3, 5, 3, 5),
1,
2,
)
GPU_TYPE = "cuda"
def get_input() -> torch.Tensor:
device = torch.device(GPU_TYPE)
full = torch.randn(full_size).to(device)
return torch.as_strided(full, view_size, full.stride())
a, b = get_input(), get_input()
opt_fn = torch.compile(functools.partial(torch.add))
result, (code,) = run_and_get_code(opt_fn, a, b)
self.assertEqual(result, a + b)
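# With prefer_nd_tiling and max_tiles=3, this strided-view add should be tiled
# over three dimensions; "znumel" only appears when a z tiling dimension is used.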
self.assertIn("znumel", code)
@xfailIfPy312Plus # https://github.com/pytorch/pytorch/issues/142032
@unittest.skipIf(config.is_fbcode(), "Dependence on functorch.einops")
def test_repeated_masked_load(self):
target_size = (8, 2)
mem_eff_temporal_upsampling_interp_chunks = 2
from functorch.einops import rearrange
x = torch.randn(1, 8, 12, 12, 4, dtype=torch.float16, device="cuda")
x = x.permute(0, 1, 4, 2, 3) # make non-contiguous
x = rearrange(x, "b c t h w -> b c t (h w)")
def interpolate_chunked(x):
# chunk along c
chunks = x.chunk(chunks=mem_eff_temporal_upsampling_interp_chunks, dim=1)
r = []
for t in chunks:
r.append(
torch.nn.functional.interpolate(
t.float(), size=target_size, mode="nearest"
).to(t.dtype)
)
out_chunked = torch.cat(r, dim=1)
return out_chunked
out_eager = interpolate_chunked(x)
out_compiled = torch.compile(interpolate_chunked)(x)
self.assertEqual(out_eager, out_compiled)
def test_max_autotune_nograd(self):
"""
https://github.com/pytorch/pytorch/issues/155688
Smallest repro for max-autotune not working under no_grad.
Before an __int__ method was added to torch.utils._sympy.functions.Identity,
running in max-autotune mode would raise:
TypeError: Expected a number but got Identity
"""
class ToyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear_layers = nn.ModuleList(
[
nn.Linear(4, 1, bias=True),
nn.Linear(5, 1, bias=True),
nn.Linear(6, 1, bias=True),
nn.Linear(7, 1, bias=True),
nn.Linear(8, 1, bias=True),
]
)
def forward(self, x):
for layer in self.linear_layers:
x2 = layer(x)
x2 = F.relu(x2)
x = torch.cat((x, x2), dim=1)
return x
model = ToyModel().to("cuda")
input_tensor = torch.randn((2, 4)).to("cuda")
compile_default = torch.compile(model, mode="default")
compile_max_autotune = torch.compile(model, mode="max-autotune")
with torch.no_grad():
default_output = compile_default(input_tensor)
max_autotune_output = compile_max_autotune(input_tensor)
self.assertEqual(default_output, max_autotune_output)
def test_adaptive_avg_pool3d_issue_157248(self):
"""Test for GitHub issue #157248: Conv2d-unsqueeze-AdaptiveAvgPool3d produces incorrect results"""
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.adaptive_pool = torch.nn.AdaptiveAvgPool3d((4, 4, 4))
def forward(self, x):
x = self.conv(x)
# This specific unsqueeze position was problematic due to zero strides
x = x.unsqueeze(1)
x = self.adaptive_pool(x)
return x
model = Model().cuda()
model.eval()
test_cases = [
(1, 3, 8, 8),
(2, 3, 16, 16),
(1, 3, 32, 32),
(1, 3, 15, 15),
(2, 3, 13, 13),
]
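# Includes spatial sizes (15, 13) that are not multiples of the 4x4x4 output to
# exercise uneven pooling windows.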
for batch, channels, h, w in test_cases:
with self.subTest(input_shape=(batch, channels, h, w)):
input_tensor = torch.randn(batch, channels, h, w, device="cuda")
# Test eager mode
with torch.no_grad():
eager_output = model(input_tensor)
# Test compiled mode with inductor
compiled_model = torch.compile(model, backend="inductor")
with torch.no_grad():
compiled_output = compiled_model(input_tensor)
# They should be identical (or very close)
self.assertTrue(
torch.allclose(eager_output, compiled_output, rtol=1e-5, atol=1e-5),
f"Results differ for input shape {(batch, channels, h, w)}. "
f"Max diff: {torch.max(torch.abs(eager_output - compiled_output)):.6f}",
)
def test_qwen2_7b_sdpa_input_alignment_requires_recompile(self):
# SDPA's constraints ensure inputs have an alignment of 8.
device = "cuda"
def forward(q_proj, k_proj, attn_mask):
scale = 0.08838834764831845 # 1/sqrt(128)
B = attn_mask.size(0)
S = attn_mask.size(3)
D = 128
d_model = q_proj.size(1)
query_states = q_proj.view(B, S, -1, D).transpose(1, 2) # [B, Hq, S, D]
q = query_states.contiguous()
Hkv = k_proj.size(1) // D
Hq = query_states.size(1)
nrepeats = Hq // Hkv
key_states = k_proj.view(B, S, -1, D).transpose(1, 2) # [B, Hkv, S, D]
kv_repeated = key_states[:, :, None, :].expand(B, Hkv, nrepeats, S, D)
kv_repeated = kv_repeated.contiguous()
k = kv_repeated.reshape(B, Hq, S, D)
v = k.clone() # value tensor
inf = torch.scalar_tensor(
float("-inf"), dtype=torch.bfloat16, device=device
)
zero = torch.scalar_tensor(0.0, dtype=torch.bfloat16, device=device)
where = torch.where(condition=attn_mask, input=zero, other=inf)
pad_amount = 8 - (S % 8)
padded = torch.nn.functional.pad(
where, (0, pad_amount), value=0.0
) # pad last-dim
sliced = padded[..., :S] # back to [B,1,S,S]
attn_bias = sliced.expand(B, Hq, S, S)
sdpa_out, logsumexp, seed, offset = (
torch.ops.aten._scaled_dot_product_efficient_attention.default(
q,
k,
v,
attn_bias,
dropout_p=0.0,
is_causal=True,
scale=scale,
compute_log_sumexp=True,
)
)
zeros = torch.zeros(B, S, d_model, device=device, dtype=torch.bfloat16)
zeros = zeros.reshape(B, S, Hq, D)
grad_out = zeros.permute(0, 2, 1, 3)
out = (
torch.ops.aten._scaled_dot_product_efficient_attention_backward.default(
grad_out,
q,
k,
v,
attn_bias,
sdpa_out,
logsumexp,
seed,
offset,
dropout_p=0.0,
scale=scale,
grad_input_mask=[True, True, True, False],
)
)
return out
B = 2
S = 6144
D = 128
Hq = 28
Hkv = 4
example_inputs = (
torch.randn((B * S, Hq * D), dtype=torch.bfloat16, device=device), # q_proj
torch.randn(
(B * S, Hkv * D), dtype=torch.bfloat16, device=device
), # k_proj
torch.zeros((B, 1, S, S), dtype=torch.bool, device=device), # attn_mask
)
correct = forward(*example_inputs)
compiled = torch.compile(forward, dynamic=True)
actual = compiled(*example_inputs)
self.assertEqual(actual, correct)
# Run once more with a seqlen that isn't divisible by 8, which should force a recompile
S = 6102
example_inputs = (
torch.randn((S * B, Hq * D), dtype=torch.bfloat16, device=device), # q_proj
torch.randn(
(S * B, Hkv * D), dtype=torch.bfloat16, device=device
), # k_proj
torch.zeros((B, 1, S, S), dtype=torch.bool, device=device), # attn_mask
)
correct = forward(*example_inputs)
actual = compiled(*example_inputs)
self.assertEqual(actual, correct)
def test_truediv_numerics_with_eager(self):
from decimal import Decimal
y, x = 7.0, 11.0
@torch.compile
def compiled_divide(x, y):
return x / y
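# Compare exact Decimal values so that any precision difference between eager and
# compiled scalar/tensor division is caught.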
for y_dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
for x_dtype in [
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]:
y_ten = torch.tensor([y], dtype=y_dtype, device="cuda")
x_ten = torch.tensor([x], dtype=x_dtype, device="cuda")
torch._dynamo.reset()
compiled_div = Decimal(compiled_divide(x, y_ten).item())
eager_div = Decimal((x / y_ten).item())
self.assertEqual(eager_div, compiled_div)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
if HAS_CUDA_AND_TRITON and not TEST_WITH_ASAN:
run_tests(needs="filelock")