Compare commits

...

9 Commits

SHA1        Message                                        Date
2fffa35acd  delete test as instructed [ghstack-poisoned]   2025-10-22 18:00:35 +00:00
8866cf36de  Update [ghstack-poisoned]                       2025-10-17 15:38:36 +00:00
f64ebc7405  Update [ghstack-poisoned]                       2025-10-16 21:18:06 +00:00
8cc17d1f1d  Update (base update) [ghstack-poisoned]         2025-10-16 21:18:06 +00:00
6f49e50103  Update [ghstack-poisoned]                       2025-10-15 19:07:47 +00:00
6c2d4559f2  Update [ghstack-poisoned]                       2025-10-02 22:20:59 +00:00
b014555502  Update [ghstack-poisoned]                       2025-10-02 22:11:17 +00:00
b4fb7be431  Update (base update) [ghstack-poisoned]         2025-09-30 21:02:37 +00:00
d3db59eefe  Update [ghstack-poisoned]                       2025-09-30 21:02:37 +00:00
4 changed files with 32 additions and 62 deletions

View File

@@ -8423,6 +8423,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        self.assertEqual(fn(x[0:]), x[16:][:16])
        self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])

    def test_index_float_zero(self):
        def fn(arg0, arg1, arg2):
            t1 = torch.tanh(arg0)
            t2 = t1.clone()
            t2.fill_(arg1.item())
            t3 = torch.clamp(t2, 0, arg2.size(0) - 1).to(torch.long)
            return torch.nn.functional.embedding(t3, arg2)

        arg0 = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device)
        arg1 = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device)
        arg2 = torch.rand([256, 88], dtype=torch.float16, device=self.device)
        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
        self.assertEqual(fn(arg0, arg1, arg2), cfn(arg0, arg1, arg2))

    # from GPT2ForSequenceClassification
    @skip_if_gpu_halide
    def test_index_tensor(self):

View File

@@ -13,6 +13,7 @@ import pytest

import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON


class TestFuzzerCompileIssues(TestCase):
@@ -220,67 +221,6 @@ class TestFuzzerCompileIssues(TestCase):
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164086")
    def test_fuzzer_issue_164086(self):
        torch.manual_seed(0)

        def foo(arg0, arg1, arg2, arg3, arg4, arg5):
            t0 = arg0  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t1 = torch.tanh(
                t0
            )  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t2 = t1.clone()
            t2.zero_()  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t3 = (
                arg1  # size=(50000, 128), stride=(50000, 1), dtype=float16, device=cuda
            )
            t4 = arg2  # size=(46, 128), stride=(46, 1), dtype=float16, device=cuda
            t5 = torch.nn.functional.linear(
                t3, t4
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t6 = arg3  # size=(50000, 4, 46), stride=(184, 46, 1), dtype=float16, device=cuda
            t7 = t6.max(
                dim=1
            ).values  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t8 = arg4  # size=(25786, 46), stride=(46, 1), dtype=float16, device=cuda
            t9 = arg5  # size=(24214, 46), stride=(46, 1), dtype=float16, device=cuda
            t10 = torch.cat(
                [t8, t9], dim=0
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t11 = torch.pow(
                torch.pow(torch.pow(torch.pow(t5, t7), t10), t5), t7
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t12 = torch.nn.functional.embedding(
                torch.clamp(t2, 0, t11.size(0) - 1).to(torch.long), t11
            )  # size=(42, 56, 46), stride=(2576, 46, 1), dtype=float16, device=cuda
            output = t12
            return output

        arg0 = torch.randint(0, 1000, [42, 56], dtype=torch.int64, device="cuda")
        arg1 = torch.rand(
            [50000, 128], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg2 = torch.rand(
            [46, 128], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg3 = torch.rand(
            [50000, 4, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg4 = torch.rand(
            [25786, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg5 = torch.rand(
            [24214, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )

        out_eager = foo(arg0, arg1, arg2, arg3, arg4, arg5)
        out_eager.sum().backward()
        print("Eager Success! ✅")

        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, arg5)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #163877")
    def test_fuzzer_issue_163877(self):
        torch.manual_seed(0)

View File

@@ -141,6 +141,15 @@ class MetalExprPrinter(ExprPrinter_):
        x = self.doprint(expr.args[0])
        return f"static_cast<float>({x})"

    def _print_Float(self, expr: sympy.Expr) -> str:
        if expr.is_integer:
            # sympy considers 0.0 to be integer, but Metal doesn't.
            # this workaround prints the float as an integer
            # xref: https://github.com/sympy/sympy/issues/26620
            return str(int(expr))
        else:
            return str(expr)

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        x = self.doprint(expr.args[0])
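The xref comment is the crux of both printer changes: sympy's assumption system reports Float(0.0) as an integer, so a value that is really the float 0.0 can reach _print_Float with is_integer set, and printing it as a float literal would leak "0.0" into otherwise integer-typed generated code. A minimal sketch of that sympy behavior (not part of the PR; assumes only that sympy is installed):

import sympy

zero = sympy.Float(0.0)

# sympy reports 0.0 as an integer (https://github.com/sympy/sympy/issues/26620) ...
print(zero.is_integer)   # True
# ... but its default printing is still a float literal,
print(str(zero))         # the float form, e.g. "0.0"
# which is why the new _print_Float branches emit an integer literal instead.
print(str(int(zero)))    # "0"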

View File

@@ -736,7 +736,12 @@ class TritonPrinter(PythonPrinter):
        )

    def _print_Float(self, expr: sympy.Expr) -> str:
        if config.is_fbcode() and torch.version.hip:
        if expr.is_integer:
            # sympy considers 0.0 to be integer, but triton doesn't.
            # this workaround prints the float as an integer
            # xref: https://github.com/sympy/sympy/issues/26620
            ret = str(int(expr))
        elif config.is_fbcode() and torch.version.hip:
            ret = f"{expr}"
        else:
            ret = f"tl.full([], {expr}, tl.float64)"