[Inductor/Triton] Customize triton codegen to optionally preserve input dtype on tl.load (#132406)

Differential Revision: D60536337

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132406
Approved by: https://github.com/jfix71, https://github.com/blaine-rister
Xinran / Allan Rui
2024-08-23 22:58:43 +00:00
committed by PyTorch MergeBot
parent 8ff3a5be1b
commit 1f19ccb5b3
3 changed files with 34 additions and 4 deletions
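
In short, this change gates Inductor's existing float16/bfloat16-to-float32 upcast in generated Triton kernels behind a new flag, triton.codegen_upcast_to_fp32, which defaults to True so the previous behavior is preserved. A minimal usage sketch, modeled on the new test added below and assuming the run_and_get_triton_code helper from torch._inductor.utils plus a CUDA device (both assumptions for illustration):

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code

@torch.compile
def mul(a, b):
    return a * b

# Two fp16 inputs; "cuda" is assumed here purely for illustration.
inps = (torch.rand((32, 32), device="cuda", dtype=torch.float16),) * 2

# With the flag disabled, loads keep the fp16 input dtype instead of being
# upcast to fp32 inside the kernel, so no float32 cast appears in the code.
with config.patch("triton.codegen_upcast_to_fp32", False):
    code = run_and_get_triton_code(mul, *inps)
    assert "float32" not in code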

View File

@@ -11396,6 +11396,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
     copy_tests(CommonTemplate, GPUTests, GPU_TYPE)
 
+    @instantiate_parametrized_tests
     class TritonCodeGenTests(TestCase):
         from torch._inductor.runtime.triton_heuristics import CachingAutotuner
@@ -11900,6 +11901,20 @@ if HAS_GPU and not TEST_WITH_ASAN:
             # it does not move the tensor constructor to cuda and keeps it on CPU.
             self.assertFalse("empty_strided_cuda(()" in code)
 
+        @requires_gpu()
+        @parametrize("upcast_to_fp32", [False, True])
+        def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
+            @torch.compile
+            def func(a, b):
+                return a * b
+
+            inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 2
+            with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
+                func_opt = torch._dynamo.optimize("inductor")(func)
+                code = run_and_get_triton_code(func_opt, *inps)
+                fp32_cast_in_code = "float32" in code
+                self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
+
         @config.patch("triton.use_block_ptr", False)
         def test_evict_last_non_coalesced_loads(self):
             @torch.compile

View File

@@ -539,7 +539,10 @@ def triton_compute_type(dtype):
     triton_type_name = str(dtype).split(".")[-1]
     if triton_type_name == "bool":
         triton_type_name = "int1"
-    elif triton_type_name in ("float16", "bfloat16"):
+    elif (
+        triton_type_name in ("float16", "bfloat16")
+        and config.triton.codegen_upcast_to_fp32
+    ):
         # float16 math is done in float32 inside the kernel
         triton_type_name = "float32"
     elif triton_type_name == "float8_e4m3fn":
@@ -557,7 +560,10 @@ def _get_primitive_bitwidth(dtype):
     if hasattr(dtype, "is_floating_point"):
         if dtype.is_floating_point:
             # triton_compute_type changes the bitwidth
-            if dtype in [torch.bfloat16, torch.float16]:
+            if (
+                dtype in [torch.bfloat16, torch.float16]
+                and config.triton.codegen_upcast_to_fp32
+            ):
                 return 32
             return torch.finfo(dtype).bits
         else:
@@ -669,7 +675,10 @@ class TritonOverrides(OpOverrides):
         # In such as case, we will have to convert the input tensor to
         # its src_type, perform bitcast, and then convert the bit-casted
         # tensor back to float to ensure we use values with the right precision.
-        if src_dtype in (torch.float16, torch.bfloat16):
+        if (
+            src_dtype in (torch.float16, torch.bfloat16)
+            and config.triton.codegen_upcast_to_fp32
+        ):
             triton_src_dtype = str(src_dtype).split(".")[-1]
             cast_x = f"{x}.to(tl.{triton_src_dtype})"
             if dtype in (torch.float16, torch.bfloat16):
@@ -1778,7 +1787,10 @@ class TritonKernel(SIMDKernel):
                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})"
 
            dtype = V.graph.get_dtype(name)
-            if dtype in (torch.float16, torch.bfloat16):
+            if (
+                dtype in (torch.float16, torch.bfloat16)
+                and config.triton.codegen_upcast_to_fp32
+            ):
                line += ".to(tl.float32)"
            if dtype == torch.bool and torch.version.hip is None:
                # Workaround for https://github.com/openai/triton/issues/2151

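For a simple fp16 elementwise kernel, the practical difference in the emitted load is roughly the following. This is an illustrative sketch only; the pointer and index names (in_ptr0, x0, xmask) are typical Inductor conventions, not verbatim output of this kernel:

# triton.codegen_upcast_to_fp32 = True (default, matches the old behavior)
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)

# triton.codegen_upcast_to_fp32 = False (input dtype preserved on tl.load)
tmp0 = tl.load(in_ptr0 + (x0), xmask)
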
View File

@@ -899,6 +899,9 @@ class triton:
     # Valid values: "compile_error", "runtime_error", "accuracy"
     inject_relu_bug_TESTING_ONLY: Optional[str] = None
 
+    # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
+    codegen_upcast_to_fp32 = True
+
 
 class aot_inductor:
     # AOTInductor output path
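
Since codegen_upcast_to_fp32 is defined as a plain attribute on the triton config class, it can also be flipped globally rather than through config.patch; a minimal sketch, assuming the usual torch._inductor.config entry point:

import torch._inductor.config as inductor_config

# Preserve fp16/bf16 dtypes on tl.load for all kernels compiled after this point.
inductor_config.triton.codegen_upcast_to_fp32 = False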