Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor/Triton] Customize triton codegen to optionally preserve input dtype on tl.load (#132406)
Differential Revision: D60536337
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132406
Approved by: https://github.com/jfix71, https://github.com/blaine-rister
Committed by: PyTorch MergeBot
Parent: 8ff3a5be1b
Commit: 1f19ccb5b3
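Usage sketch (not part of the diff itself): the new knob lives under torch._inductor.config.triton and is exercised exactly as the test added below does. This is a minimal sketch; the function, tensor shapes, and the "cuda" device are illustrative assumptions, not taken from the PR.

    # Minimal sketch, mirroring the parametrized test added in this PR.
    # With codegen_upcast_to_fp32=False, Inductor keeps the fp16/bf16 dtype
    # on tl.load instead of upcasting the loaded values to float32.
    import torch
    from torch._inductor import config

    @torch.compile
    def mul(a, b):  # illustrative toy function
        return a * b

    inps = (torch.rand((32, 32), device="cuda", dtype=torch.float16),) * 2
    with config.patch("triton.codegen_upcast_to_fp32", False):
        out = mul(*inps)  # generated Triton kernel loads and computes in fp16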
@@ -11396,6 +11396,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
     copy_tests(CommonTemplate, GPUTests, GPU_TYPE)
 
+    @instantiate_parametrized_tests
     class TritonCodeGenTests(TestCase):
         from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 
@@ -11900,6 +11901,20 @@ if HAS_GPU and not TEST_WITH_ASAN:
             # it does not move the tensor constructor to cuda and keeps it on CPU.
             self.assertFalse("empty_strided_cuda(()" in code)
 
+        @requires_gpu()
+        @parametrize("upcast_to_fp32", [False, True])
+        def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
+            @torch.compile
+            def func(a, b):
+                return a * b
+
+            inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 2
+            with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
+                func_opt = torch._dynamo.optimize("inductor")(func)
+                code = run_and_get_triton_code(func_opt, *inps)
+                fp32_cast_in_code = "float32" in code
+                self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
+
         @config.patch("triton.use_block_ptr", False)
         def test_evict_last_non_coalesced_loads(self):
             @torch.compile
@@ -539,7 +539,10 @@ def triton_compute_type(dtype):
     triton_type_name = str(dtype).split(".")[-1]
     if triton_type_name == "bool":
         triton_type_name = "int1"
-    elif triton_type_name in ("float16", "bfloat16"):
+    elif (
+        triton_type_name in ("float16", "bfloat16")
+        and config.triton.codegen_upcast_to_fp32
+    ):
         # float16 math is done in float32 inside the kernel
         triton_type_name = "float32"
     elif triton_type_name == "float8_e4m3fn":
@@ -557,7 +560,10 @@ def _get_primitive_bitwidth(dtype):
     if hasattr(dtype, "is_floating_point"):
         if dtype.is_floating_point:
             # triton_compute_type changes the bitwidth
-            if dtype in [torch.bfloat16, torch.float16]:
+            if (
+                dtype in [torch.bfloat16, torch.float16]
+                and config.triton.codegen_upcast_to_fp32
+            ):
                 return 32
             return torch.finfo(dtype).bits
         else:
@@ -669,7 +675,10 @@ class TritonOverrides(OpOverrides):
         # In such as case, we will have to convert the input tensor to
         # its src_type, perform bitcast, and then convert the bit-casted
         # tensor back to float to ensure we use values with the right precision.
-        if src_dtype in (torch.float16, torch.bfloat16):
+        if (
+            src_dtype in (torch.float16, torch.bfloat16)
+            and config.triton.codegen_upcast_to_fp32
+        ):
             triton_src_dtype = str(src_dtype).split(".")[-1]
             cast_x = f"{x}.to(tl.{triton_src_dtype})"
             if dtype in (torch.float16, torch.bfloat16):
@@ -1778,7 +1787,10 @@ class TritonKernel(SIMDKernel):
             line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})"
 
         dtype = V.graph.get_dtype(name)
-        if dtype in (torch.float16, torch.bfloat16):
+        if (
+            dtype in (torch.float16, torch.bfloat16)
+            and config.triton.codegen_upcast_to_fp32
+        ):
             line += ".to(tl.float32)"
         if dtype == torch.bool and torch.version.hip is None:
            # Workaround for https://github.com/openai/triton/issues/2151
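For readability, a standalone sketch of the gating rule in the hunk above; this is not Inductor's actual code, and the names in_ptr0 / x0 / xmask are illustrative placeholders for what the generated kernel would contain.

    # Sketch: the ".to(tl.float32)" suffix on the emitted load is now gated
    # on the codegen_upcast_to_fp32 flag for fp16/bf16 inputs.
    import torch

    def emit_load(var, index, mask, dtype, upcast_to_fp32):
        line = f"tl.load({var} + ({index}), {mask})"
        if dtype in (torch.float16, torch.bfloat16) and upcast_to_fp32:
            line += ".to(tl.float32)"
        return line

    print(emit_load("in_ptr0", "x0", "xmask", torch.float16, True))
    # -> tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    print(emit_load("in_ptr0", "x0", "xmask", torch.float16, False))
    # -> tl.load(in_ptr0 + (x0), xmask)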
@@ -899,6 +899,9 @@ class triton:
     # Valid values: "compile_error", "runtime_error", "accuracy"
     inject_relu_bug_TESTING_ONLY: Optional[str] = None
 
+    # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
+    codegen_upcast_to_fp32 = True
+
 
 class aot_inductor:
     # AOTInductor output path