[inductor][triton] support profile_scratch launcher arg (#159772)
This adds support for Triton after https://github.com/triton-lang/triton/pull/7258 landed. That PR adds a new argument to every Triton kernel: a profile_scratch argument, similar to global_scratch. This PR updates the static CUDA launcher and the AOTI kernel callers to pass these arguments when calling a Triton kernel.

Tests: https://github.com/pytorch/pytorch/pull/159158. I also verified these tests locally with triton 3.2, 3.3, and 3.4.

Fixes:
* static_cuda_launcher (test/repro: `python tools/dynamo/verify_dynamo.py`)
* AOTI calling logic (test/repro: `TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_linalg_vander_cuda_float32`)

Differential Revision: [D79825121](https://our.internmc.facebook.com/intern/diff/D79825121)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159772
Approved by: https://github.com/NikhilAPatel, https://github.com/eellison
committed by PyTorch MergeBot
parent 7f4cb4a3e0
commit 62bac07981
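
Before the diff, a minimal sketch (not part of this commit) of the launcher-side convention the change targets: for every scratch slot a compiled kernel declares, the caller must append one placeholder pointer argument and the matching "O" type code. `pad_scratch_args` is a hypothetical helper name; the loop mirrors the static CUDA launcher change below.

```python
# Hedged sketch, not PyTorch source. Newer Triton appends a global_scratch
# parameter to compiled kernels, and triton-lang/triton#7258 adds a second
# profile_scratch parameter; a caller that builds the argument tuple itself
# has to pad one None per enabled slot.
from typing import Any


def pad_scratch_args(
    args: tuple[Any, ...],
    arg_tys: str,
    has_global_scratch: bool,
    has_profile_scratch: bool,
) -> tuple[tuple[Any, ...], str]:
    # Append a None (type code "O") for each scratch slot the kernel expects;
    # Inductor never allocates these buffers on the static launcher path.
    for has_scratch in (has_global_scratch, has_profile_scratch):
        if has_scratch:
            arg_tys = arg_tys + "O"
            args = (*args, None)
    return args, arg_tys


# A kernel compiled with both scratch slots enabled gets two trailing Nones:
args, tys = pad_scratch_args((3, 4), "ii", True, True)
assert args == (3, 4, None, None) and tys == "iiOO"
```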
@@ -362,8 +362,8 @@ class DeviceOpOverrides:
     def tma_descriptor_helpers(self) -> str:
         raise NotImplementedError

-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
         # optionally return (scratch definition, arg name)
         raise NotImplementedError

@@ -211,12 +211,17 @@ class DeferredTritonCallWrapper:
         ]
         arg_types = [arg_type_loookup[name] for name in call_args]
         arg_signatures = [triton_meta["signature"][name] for name in call_args]
+        scratch_spaces = {
+            name: params[name]
+            for name in ["global_scratch", "profile_scratch"]
+            if params.get(name, None) is not None
+        }
         call_args_str = wrapper.generate_args_decl(
             prefix,
             call_args,
             arg_types,
             arg_signatures,
-            workspace_size=params.get("global_scratch") or 0,
+            scratch_spaces=scratch_spaces,
         )
         prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
         launch_kernel_args = [

@@ -454,7 +459,7 @@ class CppWrapperGpu(CppWrapperCpu):
         arg_types,
         arg_signatures,
         is_triton_kernel=True,
-        workspace_size=0,
+        scratch_spaces: Optional[dict[str, int]] = None,
     ):
         """
         Generates any declarations of args to pass into a kernel call, and then returns the arg names.

@@ -572,22 +577,26 @@ class CppWrapperGpu(CppWrapperCpu):
         ):
             process_args(arg, arg_type, arg_signature)

-        if (
-            is_triton_kernel
-            and (
-                global_scratch := self.device_codegen.cpp_global_scratch(
-                    next(self.arg_var_id),
-                    workspace=TritonScratchWorkspace(
-                        size=workspace_size,
-                        generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)),
-                    ),
-                )
-            )
-            is not None
-        ):
-            global_scratch_def, global_scratch_var = global_scratch
-            code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def])
-            new_args.append(f"&{global_scratch_var}")
+        for scratch_name, workspace_size in (scratch_spaces or {}).items():
+            if (
+                is_triton_kernel
+                and (
+                    scratch := self.device_codegen.cpp_scratch(
+                        next(self.arg_var_id),
+                        workspace=TritonScratchWorkspace(
+                            size=workspace_size,
+                            generate_dtype_str=(
+                                lambda: self.codegen_dtype(torch.uint8)
+                            ),
+                        ),
+                        prefix=scratch_name,
+                    )
+                )
+                is not None
+            ):
+                scratch_def, scratch_var = scratch
+                code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def])
+                new_args.append(f"&{scratch_var}")

         return ", ".join(new_args)

@@ -4,7 +4,6 @@ from typing import Optional

 import torch

-from ...utils import triton_version_uses_attrs_dict
 from ..common import (
     DeviceOpOverrides,
     register_device_op_overrides,

@@ -333,34 +332,33 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
     def cpp_device_ptr(self) -> str:
         return "CUdeviceptr"

-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
-        if triton_version_uses_attrs_dict():
-            var_name = f"global_scratch_{idx}"
-            if workspace.size > 0:
-                size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};"
-                stride_array = f"int64_t {var_name}_stride[] = {{1}};"
-                device_type = "cached_torch_device_type_cuda"
-                device_idx = "device_idx_"
-
-                return (
-                    [
-                        f"{size_array}",
-                        f"{stride_array}",
-                        f"AtenTensorHandle {var_name}_handle;",
-                        (
-                            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, "
-                            f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));"
-                        ),
-                        f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);",
-                        f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({var_name}_tensor.data_ptr());",
-                    ],
-                    var_name,
-                )
-            else:
-                return [f"CUdeviceptr {var_name} = 0;"], var_name
-        return None
+        prefix = f"{prefix}_" if prefix else ""
+        var_name = f"{prefix}scratch_{idx}"
+        if workspace.size > 0:
+            size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};"
+            stride_array = f"int64_t {var_name}_stride[] = {{1}};"
+            device_type = "cached_torch_device_type_cuda"
+            device_idx = "device_idx_"
+
+            return (
+                [
+                    f"{size_array}",
+                    f"{stride_array}",
+                    f"AtenTensorHandle {var_name}_handle;",
+                    (
+                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, "
+                        f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));"
+                    ),
+                    f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);",
+                    f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({var_name}_tensor.data_ptr());",
+                ],
+                var_name,
+            )
+        else:
+            return [f"CUdeviceptr {var_name} = 0;"], var_name


 register_device_op_overrides("cuda", CUDADeviceOpOverrides())

@@ -58,8 +58,8 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
     def cpp_device_ptr(self) -> str:
         return "void *"

-    def cpp_global_scratch(
-        self, idx: int, workspace: TritonScratchWorkspace
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
     ) -> Optional[tuple[list[str], str]]:
         return None

@@ -63,16 +63,21 @@ class StaticallyLaunchedCudaKernel:
             kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
         )

+        def needs_scratch_arg(scratch_name: str, param_name: str) -> bool:
+            if hasattr(kernel.metadata, param_name):
+                if getattr(kernel.metadata, param_name) > 0:
+                    raise NotImplementedError(
+                        f"{scratch_name} scratch not yet supported"
+                    )
+                return True
+            return False
+
         # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel.
         # Inductor never uses this field or enables it, but we still have to pass
         # an extra None into the set of params if its enabled
-        if hasattr(kernel.metadata, "global_scratch_size"):
-            if kernel.metadata.global_scratch_size > 0:
-                raise NotImplementedError("Global scratch not yet supported")
-            else:
-                self.has_global_scratch = True
-        else:
-            self.has_global_scratch = False
+        self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size")
+        # same situation for profile scratch - triton-lang/triton#7258
+        self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")

         self.arg_tys = self.arg_ty_from_signature(kernel.src)
         self.function: Optional[int] = (

@@ -214,12 +219,12 @@ class StaticallyLaunchedCudaKernel:
         # thing, it should always match.
         # Get rid of constants before passing to cubin launcher

-        # Add a None if triton wants an extra parameter to the cubin
-        if self.has_global_scratch:
-            arg_tys = self.arg_tys + "O"
-            args = (*args, None)
-        else:
-            arg_tys = self.arg_tys
+        # Add a None if triton wants extra parameters for scratch spaces
+        arg_tys = self.arg_tys
+        for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
+            if has_scratch:
+                arg_tys = arg_tys + "O"
+                args = (*args, None)
         assert len(args) == len(arg_tys)

         # TODO: can handle grid functions here or in C++, so

@@ -1061,6 +1061,7 @@ class CachingAutotuner(KernelInterface):
             "def_args": launcher.def_args,
             "call_args": launcher.call_args,
             "global_scratch": launcher.global_scratch,
+            "profile_scratch": launcher.profile_scratch,
         }
         from torch._inductor.codecache import CudaKernelParamCache

@@ -1754,9 +1755,23 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
         launcher.def_args = def_args
         launcher.call_args = call_args
         kernel_metadata = getattr(self.kernel, "metadata", None)
-        launcher.global_scratch = getattr(
-            kernel_metadata, "global_scratch_size", None
-        )
+        # for the scratch arguments: None indicates that the kernel doesn't
+        # take any scratch argument; otherwise a number indicates the number
+        # of bytes of scratch that need to be provided.
+
+        # in AMD's Triton backend, the global scratch size is never provided
+        # (but for AMD it's safe to pass an extra null arg, so always include it)
+        global_scratch: Optional[int] = getattr(
+            kernel_metadata,
+            "global_scratch_size",
+            (0 if torch.version.hip else None),
+        )
+        profile_scratch: Optional[int] = getattr(
+            kernel_metadata, "profile_scratch_size", None
+        )
+        launcher.global_scratch = global_scratch
+        launcher.profile_scratch = profile_scratch
         return launcher
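
For reference, a hedged sketch (not PyTorch source) of how the pieces above connect: the autotuner records a per-kernel scratch size for each slot (None meaning the kernel has no such slot), and the C++ wrapper codegen keeps only the non-None entries in the scratch_spaces dict it passes to generate_args_decl. The params values below are illustrative.

```python
# Hedged sketch, not PyTorch source: a kernel with no global scratch slot but
# a 128-byte profile scratch slot ends up with exactly one scratch_spaces entry.
params = {
    "global_scratch": None,    # no global scratch slot compiled into the kernel
    "profile_scratch": 128,    # kernel wants 128 bytes of profile scratch
}

scratch_spaces = {
    name: params[name]
    for name in ["global_scratch", "profile_scratch"]
    if params.get(name, None) is not None
}
assert scratch_spaces == {"profile_scratch": 128}
```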