[inductor] in emulate_precision_casts, disable fma fusion in triton (#163073)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163073
Approved by: https://github.com/eellison, https://github.com/jansel
Markus Hoehnerbach
2025-09-23 14:07:16 -07:00
committed by PyTorch MergeBot
parent ee75c3d91f
commit eb3fbf5b08
4 changed files with 63 additions and 0 deletions
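In brief: when torch._inductor.config.emulate_precision_casts is set, Inductor now compiles its generated Triton kernels with enable_fp_fusion=False, so the backend does not contract a multiply followed by an add into a single fma and the emulated low-precision rounding can match eager mode. A minimal sketch of how the switch is exercised from user code, mirroring the new test below (the device string and shapes are illustrative, not taken from the PR):

import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_code

def fn(a, b):
    return 2.001 * a + b

a = torch.full([256], 0.5001, device="cuda", dtype=torch.float16)
b = torch.full([256], -1.0, device="cuda", dtype=torch.float16)

with inductor_config.patch(emulate_precision_casts=True):
    compiled = torch.compile(fn)
    out, (code,) = run_and_get_code(compiled, a, b)

# The generated kernel's metadata should carry the disabled flag,
# and the compiled output should match eager bit-for-bit.
assert "'enable_fp_fusion': False" in code
torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)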


@@ -14196,6 +14196,20 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        with self.assertRaises(RuntimeError):
            compiled = torch.compile(fn, backend="inductor")(a, b)

    @requires_cuda_and_triton
    @config.patch(emulate_precision_casts=True)
    def test_emulate_precision_triton_fp_fusion(self):
        def fn(a, b):
            return 2.001 * a + b

        a = torch.full([256], 0.5001, device=GPU_TYPE, dtype=torch.float16)
        b = torch.full([256], -1, device=GPU_TYPE, dtype=torch.float16)
        compiled = torch.compile(fn)
        out, (code,) = run_and_get_code(compiled, a, b)
        self.assertTrue("'enable_fp_fusion': False" in code)
        torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)

    # end of class CommonTemplate - add new tests here
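The atol=0, rtol=0 comparison above only holds when the fused multiply-add is suppressed: an fma rounds a*x + b once, whereas eager-style evaluation rounds the product to fp16 before the add. A small NumPy illustration of that double-rounding gap, reusing the test's constants (this only mimics the fp16 rounding; it is not the code the kernel executes):

import numpy as np

scale = 2.001                # python float, kept at full precision
a = np.float16(0.5001)       # 0.5001 is not representable in fp16; stored as 0.5
b = np.float16(-1.0)

# fma-like: compute scale*a + b exactly, then round once to fp16
single_rounding = np.float16(scale * float(a) + float(b))
# unfused: round the product to fp16 first, then add in fp16
double_rounding = np.float16(scale * float(a)) + b

print(single_rounding, double_rounding)  # 0.0005 vs 0.000977 -- the results differ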


@@ -2567,6 +2567,52 @@ def forward(self, arg0_1, arg1_1):
        expected = torch.compile(fn, fullgraph=True)(inp)
        self.assertEqual(actual, expected)

    @requires_gpu
    @inductor_config.patch("emulate_precision_casts", True)
    def test_triton_kernel_emulate_precision_unaffected(self):
        @triton.jit
        def triton_(in_ptr, out_ptr, numel, add_amount, BLOCK_SIZE: tl.constexpr):
            offsets = tl.arange(0, BLOCK_SIZE)
            x = tl.load(in_ptr + offsets, mask=(offsets < numel))
            output = x * x
            if add_amount is not None:
                output = output + add_amount
            tl.store(out_ptr + offsets, output, mask=(offsets < numel))

        def fn(x):
            y = torch.empty_like(x)
            BLOCK_SIZE = 256
            grid = (1,)
            triton_[grid](x, y, x.numel(), None, BLOCK_SIZE)
            return y

        t1 = torch.rand(5, device=GPU_TYPE)
        fn = torch.compile(fn)
        _, (code,) = run_and_get_code(fn, t1)
        self.assertTrue("enable_fp_fusion" not in code)

    @requires_gpu
    @inductor_config.patch("emulate_precision_casts", True)
    @inductor_config.patch("max_autotune_gemm_backends", "TRITON")
    def test_triton_kernel_emulate_precision_mm_kernels_do_not_change(self):
        from torch._inductor.utils import run_and_get_code

        @torch.compile(mode="max-autotune")
        def fn(a, b):
            return a @ b

        t1 = torch.rand(512, 512, device=GPU_TYPE)
        t2 = torch.rand(512, 512, device=GPU_TYPE)
        try:
            _, (code,) = run_and_get_code(fn, t1, t2)
            self.assertTrue("enable_fp_fusion" not in code)
        except Exception as e:
            if "NoValidChoicesError" in str(e):
                raise unittest.SkipTest(
                    "where inductor has no triton mm kernels available, this test is meaningless"
                ) from e
            raise


def make_mutation_test(fn):
    @requires_gpu


@@ -4416,6 +4416,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
        for arg_num in equal_1_arg_indices(signature):  # type: ignore[index]
            triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index,union-attr]

        triton_meta["enable_fp_fusion"] = not config.emulate_precision_casts

        self.triton_meta = triton_meta


@@ -755,6 +755,8 @@ class CachingAutotuner(KernelInterface):
            "debug": compile_meta["debug"],
            "sanitize_overflow": False,  # turn off additional asserts added for overflow checks
        }
        if "enable_fp_fusion" in compile_meta:
            options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
        if HAS_WARP_SPEC:
            options.update(
                {
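End-to-end: TritonKernel stamps the flag into triton_meta, and CachingAutotuner forwards it here into the options handed to Triton's compiler. The same switch can be exercised on a hand-written kernel; a minimal sketch, assuming a recent Triton that accepts enable_fp_fusion as a launch-time compile option (the kernel and sizes are illustrative, not from the PR):

import torch
import triton
import triton.language as tl

@triton.jit
def mul_add(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    y = tl.load(y_ptr + offs)
    tl.store(out_ptr + offs, 2.001 * x + y)

x = torch.rand(256, device="cuda", dtype=torch.float32)
y = torch.rand(256, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)

# With fp fusion disabled, the backend is asked not to contract the
# multiply and add into a single fma instruction in the generated code.
mul_add[(1,)](x, y, out, BLOCK=256, enable_fp_fusion=False)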