Compare commits

...

8 Commits

Author SHA1 Message Date
8866cf36de Update
[ghstack-poisoned]
2025-10-17 15:38:36 +00:00
f64ebc7405 Update
[ghstack-poisoned]
2025-10-16 21:18:06 +00:00
8cc17d1f1d Update (base update)
[ghstack-poisoned]
2025-10-16 21:18:06 +00:00
6f49e50103 Update
[ghstack-poisoned]
2025-10-15 19:07:47 +00:00
6c2d4559f2 Update
[ghstack-poisoned]
2025-10-02 22:20:59 +00:00
b014555502 Update
[ghstack-poisoned]
2025-10-02 22:11:17 +00:00
b4fb7be431 Update (base update)
[ghstack-poisoned]
2025-09-30 21:02:37 +00:00
d3db59eefe Update
[ghstack-poisoned]
2025-09-30 21:02:37 +00:00
4 changed files with 33 additions and 2 deletions

View File

@ -8423,6 +8423,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertEqual(fn(x[0:]), x[16:][:16])
self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])
def test_index_float_zero(self):
    # Regression test: an index tensor built via fill_() with a scalar
    # .item() value, clamped to the embedding-table size, then used for an
    # embedding lookup under dynamic-shape compilation. The name suggests it
    # targets float-zero constant printing in the backend — TODO confirm
    # against the linked printer changes.
    def fn(arg0, arg1, arg2):
        activated = torch.tanh(arg0)
        filled = activated.clone()
        filled.fill_(arg1.item())
        indices = torch.clamp(filled, 0, arg2.size(0) - 1).to(torch.long)
        return torch.nn.functional.embedding(indices, arg2)

    arg0 = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device)
    arg1 = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device)
    arg2 = torch.rand([256, 88], dtype=torch.float16, device=self.device)
    compiled = torch.compile(fullgraph=True, dynamic=True)(fn)
    self.assertEqual(fn(arg0, arg1, arg2), compiled(arg0, arg1, arg2))
# from GPT2ForSequenceClassification
@skip_if_gpu_halide
def test_index_tensor(self):

View File

@ -13,6 +13,7 @@ import pytest
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
class TestFuzzerCompileIssues(TestCase):
@ -220,7 +221,7 @@ class TestFuzzerCompileIssues(TestCase):
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164086")
@pytest.mark.skipUnless(HAS_CUDA_AND_TRITON, "Needs CUDA")
def test_fuzzer_issue_164086(self):
torch.manual_seed(0)

View File

@ -141,6 +141,15 @@ class MetalExprPrinter(ExprPrinter_):
x = self.doprint(expr.args[0])
return f"static_cast<float>({x})"
def _print_Float(self, expr: sympy.Expr) -> str:
if expr.is_integer:
# sympy considers 0.0 to be integer, but Metal doesn't.
# this workaround prints the float as an integer
# xref: https://github.com/sympy/sympy/issues/26620
return str(int(expr))
else:
return str(expr)
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
x = self.doprint(expr.args[0])

View File

@ -736,7 +736,12 @@ class TritonPrinter(PythonPrinter):
)
def _print_Float(self, expr: sympy.Expr) -> str:
if config.is_fbcode() and torch.version.hip:
if expr.is_integer:
# sympy considers 0.0 to be integer, but triton doesn't.
# this workaround prints the float as an integer
# xref: https://github.com/sympy/sympy/issues/26620
ret = str(int(expr))
elif config.is_fbcode() and torch.version.hip:
ret = f"{expr}"
else:
ret = f"tl.full([], {expr}, tl.float64)"