[inductor][triton] support profile_scratch launcher arg (#159772)

This adds support for Triton after https://github.com/triton-lang/triton/pull/7258 landed. That Triton PR adds a new argument to all Triton kernels - a profile_scratch argument, similar to global_scratch. This PR updates the static CUDA launcher and the AOTI kernel callers to pass these arguments in when calling the Triton kernel.
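
For context, the launch-side handling boils down to the following. This is a minimal, self-contained sketch with made-up names (`append_scratch_args` is not a real Inductor function): for every scratch parameter the compiled kernel declares but does not actually need, the caller appends one extra null argument.

```python
# Minimal sketch of the launcher-side behavior; names are illustrative and do
# not match the real Inductor/Triton APIs.
from typing import Any, Optional


def append_scratch_args(
    args: tuple[Any, ...],
    global_scratch_size: Optional[int],
    profile_scratch_size: Optional[int],
) -> tuple[Any, ...]:
    # None: the kernel was compiled without that scratch parameter.
    # 0:    the parameter exists, but no buffer is needed; pass a null pointer.
    # >0:   a real buffer would have to be allocated (not handled in this sketch).
    for size in (global_scratch_size, profile_scratch_size):
        if size is None:
            continue
        if size > 0:
            raise NotImplementedError("non-zero scratch sizes not handled here")
        args = (*args, None)  # placeholder for the unused scratch pointer
    return args


# e.g. a kernel with both scratch params present but unused gets two extra Nones:
assert append_scratch_args((1, 2), 0, 0) == (1, 2, None, None)
```

The static CUDA launcher change below does essentially this, while the AOTI C++ wrapper generates an equivalent declaration per scratch space.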

Tests: https://github.com/pytorch/pytorch/pull/159158. I also verified these tests locally with Triton 3.2, 3.3, and 3.4.

Fixes:
* static_cuda_launcher (test/repro: `python tools/dynamo/verify_dynamo.py`)
* AOTI calling logic (test/repro: `TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_linalg_vander_cuda_float32`)

Differential Revision: [D79825121](https://our.internmc.facebook.com/intern/diff/D79825121)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159772
Approved by: https://github.com/NikhilAPatel, https://github.com/eellison
Author: David Berard
Date: 2025-08-07 21:58:18 -07:00
Committed by: PyTorch MergeBot
Parent: 7f4cb4a3e0
Commit: 62bac07981
6 changed files with 90 additions and 63 deletions

View File

@@ -362,8 +362,8 @@ class DeviceOpOverrides:
     def tma_descriptor_helpers(self) -> str:
         raise NotImplementedError
 
-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
         # optionally return (scratch definition, arg name)
         raise NotImplementedError
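
To make the new `prefix` parameter concrete, here is a small self-contained sketch. It mimics the hook rather than importing the real Inductor classes; `FakeScratchWorkspace` and `cpp_scratch_sketch` are illustrative stand-ins for `TritonScratchWorkspace` and the per-device overrides changed later in this diff.

```python
# Illustration only: how a backend can derive distinct variable names for the
# global and profile scratch declarations via the new prefix argument.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeScratchWorkspace:  # stand-in for TritonScratchWorkspace
    size: int


def cpp_scratch_sketch(
    idx: int, workspace: FakeScratchWorkspace, prefix: Optional[str] = None
) -> tuple[list[str], str]:
    # The prefix keeps the generated variable names for the two scratch
    # spaces from colliding when both are emitted for one kernel.
    name = f"{prefix}_scratch_{idx}" if prefix else f"scratch_{idx}"
    if workspace.size > 0:
        # The real CUDA override emits aoti_torch_empty_strided(...) calls here.
        return [f"// allocate {workspace.size} bytes for {name}"], name
    return [f"CUdeviceptr {name} = 0;"], name


print(cpp_scratch_sketch(0, FakeScratchWorkspace(size=0)))
print(cpp_scratch_sketch(1, FakeScratchWorkspace(size=0), prefix="profile"))
# (['CUdeviceptr scratch_0 = 0;'], 'scratch_0')
# (['CUdeviceptr profile_scratch_1 = 0;'], 'profile_scratch_1')
```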

View File

@@ -211,12 +211,17 @@ class DeferredTritonCallWrapper:
             ]
             arg_types = [arg_type_loookup[name] for name in call_args]
             arg_signatures = [triton_meta["signature"][name] for name in call_args]
+            scratch_spaces = {
+                name: params[name]
+                for name in ["global_scratch", "profile_scratch"]
+                if params.get(name, None) is not None
+            }
             call_args_str = wrapper.generate_args_decl(
                 prefix,
                 call_args,
                 arg_types,
                 arg_signatures,
-                workspace_size=params.get("global_scratch") or 0,
+                scratch_spaces=scratch_spaces,
             )
             prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
             launch_kernel_args = [
@@ -454,7 +459,7 @@ class CppWrapperGpu(CppWrapperCpu):
         arg_types,
         arg_signatures,
         is_triton_kernel=True,
-        workspace_size=0,
+        scratch_spaces: Optional[dict[str, int]] = None,
     ):
         """
         Generates any declarations of args to pass into a kernel call, and then returns the arg names.
@@ -572,22 +577,26 @@ class CppWrapperGpu(CppWrapperCpu):
         ):
             process_args(arg, arg_type, arg_signature)
 
-        if (
-            is_triton_kernel
-            and (
-                global_scratch := self.device_codegen.cpp_global_scratch(
-                    next(self.arg_var_id),
-                    workspace=TritonScratchWorkspace(
-                        size=workspace_size,
-                        generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)),
-                    ),
-                )
-            )
-            is not None
-        ):
-            global_scratch_def, global_scratch_var = global_scratch
-            code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def])
-            new_args.append(f"&{global_scratch_var}")
+        for scratch_name, workspace_size in (scratch_spaces or {}).items():
+            if (
+                is_triton_kernel
+                and (
+                    scratch := self.device_codegen.cpp_scratch(
+                        next(self.arg_var_id),
+                        workspace=TritonScratchWorkspace(
+                            size=workspace_size,
+                            generate_dtype_str=(
+                                lambda: self.codegen_dtype(torch.uint8)
+                            ),
+                        ),
+                        prefix=scratch_name,
+                    )
+                )
+                is not None
+            ):
+                scratch_def, scratch_var = scratch
+                code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def])
+                new_args.append(f"&{scratch_var}")
 
         return ", ".join(new_args)

View File

@@ -4,7 +4,6 @@ from typing import Optional
 
 import torch
 
-from ...utils import triton_version_uses_attrs_dict
 from ..common import (
     DeviceOpOverrides,
     register_device_op_overrides,
@@ -333,34 +332,33 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
     def cpp_device_ptr(self) -> str:
         return "CUdeviceptr"
 
-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
-        if triton_version_uses_attrs_dict():
-            var_name = f"global_scratch_{idx}"
-            if workspace.size > 0:
-                size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};"
-                stride_array = f"int64_t {var_name}_stride[] = {{1}};"
-                device_type = "cached_torch_device_type_cuda"
-                device_idx = "device_idx_"
-                return (
-                    [
-                        f"{size_array}",
-                        f"{stride_array}",
-                        f"AtenTensorHandle {var_name}_handle;",
-                        (
-                            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, "
-                            f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));"
-                        ),
-                        f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);",
-                        f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({var_name}_tensor.data_ptr());",
-                    ],
-                    var_name,
-                )
-            else:
-                return [f"CUdeviceptr {var_name} = 0;"], var_name
-        return None
+        prefix = f"{prefix}_" if prefix else ""
+        var_name = f"{prefix}scratch_{idx}"
+        if workspace.size > 0:
+            size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};"
+            stride_array = f"int64_t {var_name}_stride[] = {{1}};"
+            device_type = "cached_torch_device_type_cuda"
+            device_idx = "device_idx_"
+            return (
+                [
+                    f"{size_array}",
+                    f"{stride_array}",
+                    f"AtenTensorHandle {var_name}_handle;",
+                    (
+                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, "
+                        f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));"
+                    ),
+                    f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);",
+                    f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({var_name}_tensor.data_ptr());",
+                ],
+                var_name,
+            )
+        else:
+            return [f"CUdeviceptr {var_name} = 0;"], var_name
 
 
 register_device_op_overrides("cuda", CUDADeviceOpOverrides())

View File

@@ -58,8 +58,8 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
     def cpp_device_ptr(self) -> str:
         return "void *"
 
-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
         return None

View File

@@ -63,16 +63,21 @@ class StaticallyLaunchedCudaKernel:
             kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
         )
 
+        def needs_scratch_arg(scratch_name: str, param_name: str) -> bool:
+            if hasattr(kernel.metadata, param_name):
+                if getattr(kernel.metadata, param_name) > 0:
+                    raise NotImplementedError(
+                        f"{scratch_name} scratch not yet supported"
+                    )
+                return True
+            return False
+
         # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel.
         # Inductor never uses this field or enables it, but we still have to pass
         # an extra None into the set of params if its enabled
-        if hasattr(kernel.metadata, "global_scratch_size"):
-            if kernel.metadata.global_scratch_size > 0:
-                raise NotImplementedError("Global scratch not yet supported")
-            else:
-                self.has_global_scratch = True
-        else:
-            self.has_global_scratch = False
+        self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size")
+        # same situation for profile scratch - triton-lang/triton#7258
+        self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")
 
         self.arg_tys = self.arg_ty_from_signature(kernel.src)
         self.function: Optional[int] = (
@@ -214,12 +219,12 @@ class StaticallyLaunchedCudaKernel:
         # thing, it should always match.
         # Get rid of constants before passing to cubin launcher
 
-        # Add a None if triton wants an extra parameter to the cubin
-        if self.has_global_scratch:
-            arg_tys = self.arg_tys + "O"
-            args = (*args, None)
-        else:
-            arg_tys = self.arg_tys
+        # Add a None if triton wants extra parameters for scratch spaces
+        arg_tys = self.arg_tys
+        for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
+            if has_scratch:
+                arg_tys = arg_tys + "O"
+                args = (*args, None)
         assert len(args) == len(arg_tys)
 
         # TODO: can handle grid functions here or in C++, so
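
To make the effect of that loop concrete, here is the same logic run standalone with made-up values (the type string and argument tuple below are hypothetical, not real launcher signatures):

```python
# Standalone rerun of the loop above with made-up values: each enabled scratch
# space adds one "O"-typed None to the trailing end of the launch arguments.
arg_tys = "Xii"  # hypothetical type codes for the kernel's real arguments
args = (123, 1, 2)  # hypothetical kernel arguments
has_global_scratch, has_profile_scratch = True, True

for has_scratch in [has_global_scratch, has_profile_scratch]:
    if has_scratch:
        arg_tys = arg_tys + "O"
        args = (*args, None)

assert arg_tys == "XiiOO"
assert args == (123, 1, 2, None, None)
```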

View File

@@ -1061,6 +1061,7 @@ class CachingAutotuner(KernelInterface):
             "def_args": launcher.def_args,
             "call_args": launcher.call_args,
             "global_scratch": launcher.global_scratch,
+            "profile_scratch": launcher.profile_scratch,
         }
 
         from torch._inductor.codecache import CudaKernelParamCache
@@ -1754,9 +1755,23 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
         launcher.def_args = def_args
         launcher.call_args = call_args
         kernel_metadata = getattr(self.kernel, "metadata", None)
-        launcher.global_scratch = getattr(
-            kernel_metadata, "global_scratch_size", None
-        )
+
+        # for the scratch arguments: None indicates that the kernel doesn't
+        # take any scratch argument; otherwise a number indicates the number
+        # of bytes of scratch that need to be provided.
+        # in AMD's Triton backend, the global scratch size is never provided
+        # (but for AMD it's safe to pass an extra null arg, so always include it)
+        global_scratch: Optional[int] = getattr(
+            kernel_metadata,
+            "global_scratch_size",
+            (0 if torch.version.hip else None),
+        )
+        profile_scratch: Optional[int] = getattr(
+            kernel_metadata, "profile_scratch_size", None
+        )
+        launcher.global_scratch = global_scratch
+        launcher.profile_scratch = profile_scratch
+
         return launcher