mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[inductor] in emulate_precision_casts, disable fma fusion in triton (#163073)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163073
Approved by: https://github.com/eellison, https://github.com/jansel
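Background: emulate_precision_casts makes inductor round intermediate values to the tensor dtype the way eager mode does, while triton's default enable_fp_fusion=True may lower a multiply-add into a single fma that skips the intermediate rounding of the product, so compiled and eager results can diverge bitwise. A minimal sketch of turning the emulation on (the config name comes from the diff below; the lambda is only an illustrative workload, not code from this PR):

    import torch
    import torch._inductor.config as inductor_config

    # Round intermediates like eager mode does; with this PR, inductor-generated
    # triton kernels are also compiled with enable_fp_fusion=False so an fma
    # cannot skip the emulated rounding of the product.
    inductor_config.emulate_precision_casts = True

    compiled = torch.compile(lambda a, b: 2.001 * a + b)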
committed by PyTorch MergeBot
parent ee75c3d91f
commit eb3fbf5b08
@@ -14196,6 +14196,20 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        with self.assertRaises(RuntimeError):
            compiled = torch.compile(fn, backend="inductor")(a, b)

    @requires_cuda_and_triton
    @config.patch(emulate_precision_casts=True)
    def test_emulate_precision_triton_fp_fusion(self):
        def fn(a, b):
            return 2.001 * a + b

        a = torch.full([256], 0.5001, device=GPU_TYPE, dtype=torch.float16)
        b = torch.full([256], -1, device=GPU_TYPE, dtype=torch.float16)

        compiled = torch.compile(fn)
        out, (code,) = run_and_get_code(compiled, a, b)
        self.assertTrue("'enable_fp_fusion': False" in code)
        torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)

    # end of class CommonTemplate - add new tests here
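The exact bitwise check (atol=0, rtol=0) works because the chosen inputs make the one skipped fp16 rounding visible after cancellation against -1. A small eager-only illustration of the divergence this test guards against (assumes fp16 ops promote to fp32 opmath and round back per op, as eager does; not code from this PR):

    import torch

    a = torch.full([1], 0.5001, dtype=torch.float16)  # rounds to exactly 0.5 in fp16
    b = torch.full([1], -1.0, dtype=torch.float16)

    # Unfused, eager-style: the product is rounded to fp16 before the add.
    unfused = 2.001 * a + b                         # ~0.000977

    # Fma-style: multiply and add at fp32, a single rounding at the end.
    fused = (2.001 * a.float() + b.float()).half()  # ~0.0005

    # The cancellation against -1 amplifies the single rounding difference.
    print(unfused.item(), fused.item())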
@@ -2567,6 +2567,52 @@ def forward(self, arg0_1, arg1_1):
        expected = torch.compile(fn, fullgraph=True)(inp)
        self.assertEqual(actual, expected)

    @requires_gpu
    @inductor_config.patch("emulate_precision_casts", True)
    def test_triton_kernel_emulate_precision_unaffected(self):
        @triton.jit
        def triton_(in_ptr, out_ptr, numel, add_amount, BLOCK_SIZE: tl.constexpr):
            offsets = tl.arange(0, BLOCK_SIZE)
            x = tl.load(in_ptr + offsets, mask=(offsets < numel))
            output = x * x
            if add_amount is not None:
                output = output + add_amount
            tl.store(out_ptr + offsets, output, mask=(offsets < numel))

        def fn(x):
            y = torch.empty_like(x)
            BLOCK_SIZE = 256
            grid = (1,)
            triton_[grid](x, y, x.numel(), None, BLOCK_SIZE)
            return y

        t1 = torch.rand(5, device=GPU_TYPE)
        fn = torch.compile(fn)
        _, (code,) = run_and_get_code(fn, t1)
        self.assertTrue("enable_fp_fusion" not in code)

    @requires_gpu
    @inductor_config.patch("emulate_precision_casts", True)
    @inductor_config.patch("max_autotune_gemm_backends", "TRITON")
    def test_triton_kernel_emulate_precision_mm_kernels_do_not_change(self):
        from torch._inductor.utils import run_and_get_code

        @torch.compile(mode="max-autotune")
        def fn(a, b):
            return a @ b

        t1 = torch.rand(512, 512, device=GPU_TYPE)
        t2 = torch.rand(512, 512, device=GPU_TYPE)
        try:
            _, (code,) = run_and_get_code(fn, t1, t2)
            self.assertTrue("enable_fp_fusion" not in code)
        except Exception as e:
            if "NoValidChoicesError" in str(e):
                raise unittest.SkipTest(
                    "when inductor has no triton mm kernels available, this test is meaningless"
                ) from e
            raise


def make_mutation_test(fn):
    @requires_gpu
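These two tests pin down the scope of the change: user-written @triton.jit kernels and triton mm template kernels are asserted to contain no enable_fp_fusion option in their generated code, i.e. they keep triton's own defaults, and only the kernels inductor generates itself carry the override. For reference, the default can be inspected on triton's CUDA backend options (assumed triton-internal introspection; CUDAOptions is not a public API and its import path may move between triton releases):

    # CUDAOptions is a dataclass holding per-kernel compile options, including
    # enable_fp_fusion, which defaults to True (fused fp math allowed).
    from triton.backends.nvidia.compiler import CUDAOptions

    print(CUDAOptions.__dataclass_fields__["enable_fp_fusion"].default)  # True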
@@ -4416,6 +4416,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
        for arg_num in equal_1_arg_indices(signature):  # type: ignore[index]
            triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index,union-attr]
        triton_meta["enable_fp_fusion"] = not config.emulate_precision_casts

        self.triton_meta = triton_meta
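The triton_meta dict assembled here is printed verbatim into the generated kernel source, which is why test_emulate_precision_triton_fp_fusion above can simply grep the code for the literal string 'enable_fp_fusion': False. A representative, abridged shape of that generated metadata (illustrative only; the exact fields and signature format vary by inductor version):

    triton_meta = {
        "signature": {"in_ptr0": "*fp16", "in_ptr1": "*fp16", "out_ptr0": "*fp16", "xnumel": "i32"},
        "constants": {},
        "enable_fp_fusion": False,  # emitted when emulate_precision_casts is on
    }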
@@ -755,6 +755,8 @@ class CachingAutotuner(KernelInterface):
            "debug": compile_meta["debug"],
            "sanitize_overflow": False,  # turn off additional asserts added for overflow checks
        }
        if "enable_fp_fusion" in compile_meta:
            options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
        if HAS_WARP_SPEC:
            options.update(
                {
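Taken together, the two production hunks form a simple hand-off: codegen records the decision per kernel in triton_meta, and the autotuner forwards it into triton's compile options only when the key is present, presumably so metadata that lacks the key is still accepted. A self-contained sketch of that flow (simplified; FakeConfig stands in for torch._inductor.config, and the dicts mirror the names in the diff):

    class FakeConfig:  # stand-in for torch._inductor.config
        emulate_precision_casts = True

    config = FakeConfig()

    # Codegen side (TritonKernel): record the decision per kernel.
    triton_meta = {"constants": {}}
    triton_meta["enable_fp_fusion"] = not config.emulate_precision_casts

    # Runtime side (CachingAutotuner): forward it into triton's options,
    # guarded so metadata without the key is left at triton's default.
    compile_meta = {**triton_meta, "debug": False}
    options = {"debug": compile_meta["debug"], "sanitize_overflow": False}
    if "enable_fp_fusion" in compile_meta:
        options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]

    assert options["enable_fp_fusion"] is False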