[inductor][triton] support profile_scratch launcher arg (#159772)
This adds support for Triton after https://github.com/triton-lang/triton/pull/7258 landed. That Triton PR adds a new argument to all Triton kernels: a profile_scratch argument, similar to global_scratch. This PR updates the static CUDA launcher and the AOTI kernel callers to pass these arguments when calling a Triton kernel.

Tests: https://github.com/pytorch/pytorch/pull/159158. I also verified these tests locally with Triton 3.2, 3.3, and 3.4.

Fixes:
* static_cuda_launcher (test/repro: `python tools/dynamo/verify_dynamo.py`)
* AOTI calling logic (test/repro: `TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_linalg_vander_cuda_float32`)

Differential Revision: [D79825121](https://our.internmc.facebook.com/intern/diff/D79825121)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159772
Approved by: https://github.com/NikhilAPatel, https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: 7f4cb4a3e0
Commit: 62bac07981
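To make the change concrete before the diff: for each scratch space a compiled Triton kernel declares (global scratch, and now profile scratch), the caller must append one extra pointer-sized argument, which Inductor fills with a null value. Below is a minimal, hedged Python sketch of that idea, simplified from the static-launcher change further down; the function name and the `"Pi"` type string are illustrative, not PyTorch APIs.

```python
def append_scratch_args(arg_tys, args, has_global_scratch, has_profile_scratch):
    # One extra "O" (object/pointer) slot and one extra None per scratch space
    # the compiled kernel expects, mirroring the static-launcher change below.
    for has_scratch in (has_global_scratch, has_profile_scratch):
        if has_scratch:
            arg_tys = arg_tys + "O"
            args = (*args, None)
    return arg_tys, args


# Example with a made-up base signature "Pi" and both scratch spaces enabled:
tys, vals = append_scratch_args("Pi", (0x7F00, 1024), True, True)
assert tys == "PiOO"
assert vals == (0x7F00, 1024, None, None)
```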
@@ -362,8 +362,8 @@ class DeviceOpOverrides:
     def tma_descriptor_helpers(self) -> str:
         raise NotImplementedError

-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
         # optionally return (scratch definition, arg name)
         raise NotImplementedError
@@ -211,12 +211,17 @@ class DeferredTritonCallWrapper:
         ]
         arg_types = [arg_type_loookup[name] for name in call_args]
         arg_signatures = [triton_meta["signature"][name] for name in call_args]
+        scratch_spaces = {
+            name: params[name]
+            for name in ["global_scratch", "profile_scratch"]
+            if params.get(name, None) is not None
+        }
         call_args_str = wrapper.generate_args_decl(
             prefix,
             call_args,
             arg_types,
             arg_signatures,
-            workspace_size=params.get("global_scratch") or 0,
+            scratch_spaces=scratch_spaces,
         )
         prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
         launch_kernel_args = [
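For illustration only, the `scratch_spaces` dict built above maps each scratch kind recorded in the cached kernel params to its size in bytes. The values below are made up, not output from this PR:

```python
# Hypothetical cached kernel params; the byte sizes are made-up values.
params = {"global_scratch": 0, "profile_scratch": 512}
scratch_spaces = {
    name: params[name]
    for name in ["global_scratch", "profile_scratch"]
    if params.get(name, None) is not None
}
assert scratch_spaces == {"global_scratch": 0, "profile_scratch": 512}
```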
@@ -454,7 +459,7 @@ class CppWrapperGpu(CppWrapperCpu):
         arg_types,
         arg_signatures,
         is_triton_kernel=True,
-        workspace_size=0,
+        scratch_spaces: Optional[dict[str, int]] = None,
     ):
         """
         Generates any declarations of args to pass into a kernel call, and then returns the arg names.
@@ -572,22 +577,26 @@ class CppWrapperGpu(CppWrapperCpu):
         ):
             process_args(arg, arg_type, arg_signature)

-        if (
-            is_triton_kernel
-            and (
-                global_scratch := self.device_codegen.cpp_global_scratch(
-                    next(self.arg_var_id),
-                    workspace=TritonScratchWorkspace(
-                        size=workspace_size,
-                        generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)),
-                    ),
-                )
-            )
-            is not None
-        ):
-            global_scratch_def, global_scratch_var = global_scratch
-            code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def])
-            new_args.append(f"&{global_scratch_var}")
+        for scratch_name, workspace_size in (scratch_spaces or {}).items():
+            if (
+                is_triton_kernel
+                and (
+                    scratch := self.device_codegen.cpp_scratch(
+                        next(self.arg_var_id),
+                        workspace=TritonScratchWorkspace(
+                            size=workspace_size,
+                            generate_dtype_str=(
+                                lambda: self.codegen_dtype(torch.uint8)
+                            ),
+                        ),
+                        prefix=scratch_name,
+                    )
+                )
+                is not None
+            ):
+                scratch_def, scratch_var = scratch
+                code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def])
+                new_args.append(f"&{scratch_var}")

         return ", ".join(new_args)

@@ -4,7 +4,6 @@ from typing import Optional

 import torch

-from ...utils import triton_version_uses_attrs_dict
 from ..common import (
     DeviceOpOverrides,
     register_device_op_overrides,
@@ -333,34 +332,33 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
     def cpp_device_ptr(self) -> str:
         return "CUdeviceptr"

-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
-        if triton_version_uses_attrs_dict():
-            var_name = f"global_scratch_{idx}"
-            if workspace.size > 0:
-                size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};"
-                stride_array = f"int64_t {var_name}_stride[] = {{1}};"
-                device_type = "cached_torch_device_type_cuda"
-                device_idx = "device_idx_"
+        prefix = f"{prefix}_" if prefix else ""
+        var_name = f"{prefix}scratch_{idx}"
+        if workspace.size > 0:
+            size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};"
+            stride_array = f"int64_t {var_name}_stride[] = {{1}};"
+            device_type = "cached_torch_device_type_cuda"
+            device_idx = "device_idx_"

-                return (
-                    [
-                        f"{size_array}",
-                        f"{stride_array}",
-                        f"AtenTensorHandle {var_name}_handle;",
-                        (
-                            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, "
-                            f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));"
-                        ),
-                        f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);",
-                        f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({var_name}_tensor.data_ptr());",
-                    ],
-                    var_name,
-                )
-            else:
-                return [f"CUdeviceptr {var_name} = 0;"], var_name
-        return None
+            return (
+                [
+                    f"{size_array}",
+                    f"{stride_array}",
+                    f"AtenTensorHandle {var_name}_handle;",
+                    (
+                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, "
+                        f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));"
+                    ),
+                    f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);",
+                    f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({var_name}_tensor.data_ptr());",
+                ],
+                var_name,
+            )
+        else:
+            return [f"CUdeviceptr {var_name} = 0;"], var_name


 register_device_op_overrides("cuda", CUDADeviceOpOverrides())
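To make the `cpp_scratch` return contract concrete, here is a hedged, self-contained stub (not the real Inductor types or codegen) that reproduces only the zero-size path and the new prefix-based variable naming shown above:

```python
from typing import Optional


def cpp_scratch_stub(idx: int, size: int, prefix: Optional[str] = None):
    # Returns (C++ definition lines, variable name), like the hook above, but only
    # for the zero-size path; the aoti_torch_empty_strided allocation path is omitted.
    prefix = f"{prefix}_" if prefix else ""
    var_name = f"{prefix}scratch_{idx}"
    if size > 0:
        raise NotImplementedError("allocation path omitted in this sketch")
    return [f"CUdeviceptr {var_name} = 0;"], var_name


defs, var = cpp_scratch_stub(3, 0, prefix="profile_scratch")
assert var == "profile_scratch_scratch_3"
assert defs == ["CUdeviceptr profile_scratch_scratch_3 = 0;"]
```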
@@ -58,8 +58,8 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
     def cpp_device_ptr(self) -> str:
         return "void *"

-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
         return None

@@ -63,16 +63,21 @@ class StaticallyLaunchedCudaKernel:
             kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
         )

+        def needs_scratch_arg(scratch_name: str, param_name: str) -> bool:
+            if hasattr(kernel.metadata, param_name):
+                if getattr(kernel.metadata, param_name) > 0:
+                    raise NotImplementedError(
+                        f"{scratch_name} scratch not yet supported"
+                    )
+                return True
+            return False
+
         # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel.
         # Inductor never uses this field or enables it, but we still have to pass
         # an extra None into the set of params if its enabled
-        if hasattr(kernel.metadata, "global_scratch_size"):
-            if kernel.metadata.global_scratch_size > 0:
-                raise NotImplementedError("Global scratch not yet supported")
-            else:
-                self.has_global_scratch = True
-        else:
-            self.has_global_scratch = False
+        self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size")
+        # same situation for profile scratch - triton-lang/triton#7258
+        self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")

         self.arg_tys = self.arg_ty_from_signature(kernel.src)
         self.function: Optional[int] = (
@@ -214,12 +219,12 @@ class StaticallyLaunchedCudaKernel:
         # thing, it should always match.
         # Get rid of constants before passing to cubin launcher

-        # Add a None if triton wants an extra parameter to the cubin
-        if self.has_global_scratch:
-            arg_tys = self.arg_tys + "O"
-            args = (*args, None)
-        else:
-            arg_tys = self.arg_tys
+        # Add a None if triton wants extra parameters for scratch spaces
+        arg_tys = self.arg_tys
+        for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
+            if has_scratch:
+                arg_tys = arg_tys + "O"
+                args = (*args, None)
         assert len(args) == len(arg_tys)

         # TODO: can handle grid functions here or in C++, so
@@ -1061,6 +1061,7 @@ class CachingAutotuner(KernelInterface):
             "def_args": launcher.def_args,
             "call_args": launcher.call_args,
             "global_scratch": launcher.global_scratch,
+            "profile_scratch": launcher.profile_scratch,
         }
         from torch._inductor.codecache import CudaKernelParamCache

@@ -1754,9 +1755,23 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
         launcher.def_args = def_args
         launcher.call_args = call_args
         kernel_metadata = getattr(self.kernel, "metadata", None)
-        launcher.global_scratch = getattr(
-            kernel_metadata, "global_scratch_size", None
+        # for the scratch arguments: None indicates that the kernel doesn't
+        # take any scratch argument; otherwise a number indicates the number
+        # of bytes of scratch that need to be provided.
+
+        # in AMD's Triton backend, the global scratch size is never provided
+        # (but for AMD it's safe to pass an extra null arg, so always include it)
+        global_scratch: Optional[int] = getattr(
+            kernel_metadata,
+            "global_scratch_size",
+            (0 if torch.version.hip else None),
+        )
+        profile_scratch: Optional[int] = getattr(
+            kernel_metadata, "profile_scratch_size", None
         )
+        launcher.global_scratch = global_scratch
+        launcher.profile_scratch = profile_scratch
         return launcher

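As a hedged reading of the cached metadata above (a sketch, not code from this PR): `None` means the compiled kernel takes no such launcher argument, while an integer (possibly 0) means one more pointer argument must be passed, which Inductor fills with a null value; on HIP the global scratch entry defaults to 0, so the extra null argument is always included there.

```python
from typing import Optional


def extra_scratch_args(global_scratch: Optional[int], profile_scratch: Optional[int]) -> list:
    # None -> the kernel takes no such argument; an int (possibly 0) -> the kernel
    # takes one more pointer argument, which Inductor passes as a null value.
    return [None for size in (global_scratch, profile_scratch) if size is not None]


# A hypothetical kernel that declares both scratch spaces with size 0:
assert len(extra_scratch_args(0, 0)) == 2
# A kernel with no scratch metadata at all (the non-HIP default):
assert extra_scratch_args(None, None) == []
```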