From 62bac0798100e0e06a86b7a4cee1788413e3d0ca Mon Sep 17 00:00:00 2001
From: David Berard
Date: Thu, 7 Aug 2025 21:58:18 -0700
Subject: [PATCH] [inductor][triton] support profile_scratch launcher arg (#159772)

This adds support for Triton after
https://github.com/triton-lang/triton/pull/7258 landed.

https://github.com/triton-lang/triton/pull/7258 adds a new argument to all
the Triton kernels - a profile_scratch argument, similar to global_scratch.
This PR updates the static cuda launcher and the AOTI kernel callers to pass
in these arguments when calling the Triton kernel.

Tests: https://github.com/pytorch/pytorch/pull/159158. I also verified these
tests locally with triton 3.2, 3.3, and 3.4.

Fixes:
* static_cuda_launcher (test/repro: `python tools/dynamo/verify_dynamo.py`)
* AOTI calling logic (test/repro: `TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_linalg_vander_cuda_float32`)

Differential Revision: [D79825121](https://our.internmc.facebook.com/intern/diff/D79825121)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159772
Approved by: https://github.com/NikhilAPatel, https://github.com/eellison
---
 torch/_inductor/codegen/common.py             |  4 +-
 torch/_inductor/codegen/cpp_wrapper_gpu.py    | 43 +++++++++------
 .../codegen/cuda/device_op_overrides.py       | 52 +++++++++----------
 .../codegen/xpu/device_op_overrides.py        |  4 +-
 .../_inductor/runtime/static_cuda_launcher.py | 31 ++++++-----
 torch/_inductor/runtime/triton_heuristics.py  | 19 ++++++-
 6 files changed, 90 insertions(+), 63 deletions(-)

diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 471c9030f1e6..40ebbed13ddd 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -362,8 +362,8 @@ class DeviceOpOverrides:
     def tma_descriptor_helpers(self) -> str:
         raise NotImplementedError

-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
         # optionally return (scratch definition, arg name)
         raise NotImplementedError
diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py
index 430511ce4ebf..6bbbab859900 100644
--- a/torch/_inductor/codegen/cpp_wrapper_gpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py
@@ -211,12 +211,17 @@ class DeferredTritonCallWrapper:
             ]
             arg_types = [arg_type_loookup[name] for name in call_args]
             arg_signatures = [triton_meta["signature"][name] for name in call_args]
+            scratch_spaces = {
+                name: params[name]
+                for name in ["global_scratch", "profile_scratch"]
+                if params.get(name, None) is not None
+            }
             call_args_str = wrapper.generate_args_decl(
                 prefix,
                 call_args,
                 arg_types,
                 arg_signatures,
-                workspace_size=params.get("global_scratch") or 0,
+                scratch_spaces=scratch_spaces,
             )
             prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
             launch_kernel_args = [
@@ -454,7 +459,7 @@ class CppWrapperGpu(CppWrapperCpu):
         arg_types,
         arg_signatures,
         is_triton_kernel=True,
-        workspace_size=0,
+        scratch_spaces: Optional[dict[str, int]] = None,
     ):
         """
         Generates any declarations of args to pass into a kernel call, and then returns the arg names.
@@ -572,22 +577,26 @@ class CppWrapperGpu(CppWrapperCpu):
         ):
             process_args(arg, arg_type, arg_signature)

-        if (
-            is_triton_kernel
-            and (
-                global_scratch := self.device_codegen.cpp_global_scratch(
-                    next(self.arg_var_id),
-                    workspace=TritonScratchWorkspace(
-                        size=workspace_size,
-                        generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)),
-                    ),
+        for scratch_name, workspace_size in (scratch_spaces or {}).items():
+            if (
+                is_triton_kernel
+                and (
+                    scratch := self.device_codegen.cpp_scratch(
+                        next(self.arg_var_id),
+                        workspace=TritonScratchWorkspace(
+                            size=workspace_size,
+                            generate_dtype_str=(
+                                lambda: self.codegen_dtype(torch.uint8)
+                            ),
+                        ),
+                        prefix=scratch_name,
+                    )
                 )
-            )
-            is not None
-        ):
-            global_scratch_def, global_scratch_var = global_scratch
-            code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def])
-            new_args.append(f"&{global_scratch_var}")
+                is not None
+            ):
+                scratch_def, scratch_var = scratch
+                code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def])
+                new_args.append(f"&{scratch_var}")

         return ", ".join(new_args)

diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py
index 0ba067742294..147515e0decf 100644
--- a/torch/_inductor/codegen/cuda/device_op_overrides.py
+++ b/torch/_inductor/codegen/cuda/device_op_overrides.py
@@ -4,7 +4,6 @@ from typing import Optional

 import torch

-from ...utils import triton_version_uses_attrs_dict
 from ..common import (
     DeviceOpOverrides,
     register_device_op_overrides,
@@ -333,34 +332,33 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
     def cpp_device_ptr(self) -> str:
         return "CUdeviceptr"

-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
-        if triton_version_uses_attrs_dict():
-            var_name = f"global_scratch_{idx}"
-            if workspace.size > 0:
-                size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};"
-                stride_array = f"int64_t {var_name}_stride[] = {{1}};"
-                device_type = "cached_torch_device_type_cuda"
-                device_idx = "device_idx_"
+        prefix = f"{prefix}_" if prefix else ""
+        var_name = f"{prefix}scratch_{idx}"
+        if workspace.size > 0:
+            size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};"
+            stride_array = f"int64_t {var_name}_stride[] = {{1}};"
+            device_type = "cached_torch_device_type_cuda"
+            device_idx = "device_idx_"

-            return (
-                [
-                    f"{size_array}",
-                    f"{stride_array}",
-                    f"AtenTensorHandle {var_name}_handle;",
-                    (
-                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, "
-                        f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));"
-                    ),
-                    f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);",
-                    f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({var_name}_tensor.data_ptr());",
-                ],
-                var_name,
-            )
-        else:
-            return [f"CUdeviceptr {var_name} = 0;"], var_name
-        return None
+            return (
+                [
+                    f"{size_array}",
+                    f"{stride_array}",
+                    f"AtenTensorHandle {var_name}_handle;",
+                    (
+                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, "
+                        f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));"
+                    ),
+                    f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);",
+                    f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({var_name}_tensor.data_ptr());",
+                ],
+                var_name,
+            )
+        else:
+            return [f"CUdeviceptr {var_name} = 0;"], var_name

register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 632cfd29f174..99502ca2dd97 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -58,8 +58,8 @@ class XPUDeviceOpOverrides(DeviceOpOverrides): def cpp_device_ptr(self) -> str: return "void *" - def cpp_global_scratch( - self, idx: int, workspace: TritonScratchWorkspace + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: return None diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index a52df4745f59..3290e25eeae4 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -63,16 +63,21 @@ class StaticallyLaunchedCudaKernel: kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared ) + def needs_scratch_arg(scratch_name: str, param_name: str) -> bool: + if hasattr(kernel.metadata, param_name): + if getattr(kernel.metadata, param_name) > 0: + raise NotImplementedError( + f"{scratch_name} scratch not yet supported" + ) + return True + return False + # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel. # Inductor never uses this field or enables it, but we still have to pass # an extra None into the set of params if its enabled - if hasattr(kernel.metadata, "global_scratch_size"): - if kernel.metadata.global_scratch_size > 0: - raise NotImplementedError("Global scratch not yet supported") - else: - self.has_global_scratch = True - else: - self.has_global_scratch = False + self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size") + # same situation for profile scratch - triton-lang/triton#7258 + self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size") self.arg_tys = self.arg_ty_from_signature(kernel.src) self.function: Optional[int] = ( @@ -214,12 +219,12 @@ class StaticallyLaunchedCudaKernel: # thing, it should always match. 
         # Get rid of constants before passing to cubin launcher

-        # Add a None if triton wants an extra parameter to the cubin
-        if self.has_global_scratch:
-            arg_tys = self.arg_tys + "O"
-            args = (*args, None)
-        else:
-            arg_tys = self.arg_tys
+        # Add a None if triton wants extra parameters for scratch spaces
+        arg_tys = self.arg_tys
+        for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
+            if has_scratch:
+                arg_tys = arg_tys + "O"
+                args = (*args, None)
         assert len(args) == len(arg_tys)

         # TODO: can handle grid functions here or in C++, so
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index ba8de8f9829e..8425cba55795 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -1061,6 +1061,7 @@ class CachingAutotuner(KernelInterface):
             "def_args": launcher.def_args,
             "call_args": launcher.call_args,
             "global_scratch": launcher.global_scratch,
+            "profile_scratch": launcher.profile_scratch,
         }

         from torch._inductor.codecache import CudaKernelParamCache
@@ -1754,9 +1755,23 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
         launcher.def_args = def_args
         launcher.call_args = call_args
         kernel_metadata = getattr(self.kernel, "metadata", None)
-        launcher.global_scratch = getattr(
-            kernel_metadata, "global_scratch_size", None
+
+        # for the scratch arguments: None indicates that the kernel doesn't
+        # take any scratch argument; otherwise a number indicates the number
+        # of bytes of scratch that need to be provided.
+
+        # in AMD's Triton backend, the global scratch size is never provided
+        # (but for AMD it's safe to pass an extra null arg, so always include it)
+        global_scratch: Optional[int] = getattr(
+            kernel_metadata,
+            "global_scratch_size",
+            (0 if torch.version.hip else None),
         )
+        profile_scratch: Optional[int] = getattr(
+            kernel_metadata, "profile_scratch_size", None
+        )
+        launcher.global_scratch = global_scratch
+        launcher.profile_scratch = profile_scratch
         return launcher
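Note (illustrative, not part of the diff above): the launcher-side convention this patch handles is that a kernel compiled by newer Triton may expect trailing scratch pointer arguments (global scratch, and with triton-lang/triton#7258 also profile scratch) even though Inductor never allocates scratch space itself, so the launcher appends one None placeholder per scratch argument the kernel declares. The sketch below is a minimal standalone rendering of that idea under stated assumptions; KernelMetadata and append_scratch_args are hypothetical names used only here and are not PyTorch or Triton APIs.

    # Minimal sketch only; not the actual StaticallyLaunchedCudaKernel code.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class KernelMetadata:
        # Scratch sizes in bytes; None means the Triton version predates the field.
        global_scratch_size: Optional[int] = None
        profile_scratch_size: Optional[int] = None

    def needs_scratch_arg(meta: KernelMetadata, field: str) -> bool:
        """True if the compiled kernel's signature includes this scratch argument."""
        size = getattr(meta, field, None)
        if size is None:
            return False  # older Triton: no such launcher argument at all
        if size > 0:
            # a nonzero scratch requirement is not handled in this sketch
            raise NotImplementedError(f"{field} > 0 not supported here")
        return True  # argument exists but is unused; pass a null placeholder

    def append_scratch_args(args: tuple, meta: KernelMetadata) -> tuple:
        """Append one None per scratch space the compiled kernel expects."""
        for field in ("global_scratch_size", "profile_scratch_size"):
            if needs_scratch_arg(meta, field):
                args = (*args, None)
        return args

    # Example: a kernel built with new Triton reports both fields as 0,
    # so two trailing Nones are appended to the launch arguments.
    print(append_scratch_args((1, 2, 3), KernelMetadata(0, 0)))
    # -> (1, 2, 3, None, None)

In the patch itself, this corresponds to the has_global_scratch / has_profile_scratch flags on StaticallyLaunchedCudaKernel and the extra "O" type codes appended to arg_tys before launch.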