Compare commits

...

1 Commits

Author SHA1 Message Date
2d757f6517 apply 2025-02-18 11:24:34 -08:00
2 changed files with 51 additions and 28 deletions

View File

@ -13,7 +13,13 @@ from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
from .. import config
from ..codecache import CudaKernelParamCache
from ..ir import IRNode, TensorBox
from ..utils import cache_on_self, DeferredLineBase, get_gpu_type, GPU_ALIGN_BYTES
from ..utils import (
cache_on_self,
DeferredLineBase,
get_gpu_type,
GPU_ALIGN_BYTES,
triton_version_uses_attrs_dict,
)
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import get_device_op_overrides
@ -448,6 +454,8 @@ class CppWrapperGpu(CppWrapperCpu):
self.writeline(
f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};"
)
elif arg is None:
self.writeline(f"auto {var_name} = nullptr;")
else:
self.writeline(f"auto {var_name} = {cexpr(arg)};")
new_args.append(f"&{var_name}")
@ -554,26 +562,36 @@ class CppWrapperGpu(CppWrapperCpu):
)
kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)
# args with value 1 are added into equal_to_1 and constants
# in triton_meta (in the Python codegen) which makes them
# inlined in the PTX and compiled CUBIN
arg_signatures = []
if (
triton_meta is not None
and triton_meta.get("configs")
and triton_meta.get("signature")
):
equal_to_1 = triton_meta["configs"][0].equal_to_1
call_args = [
arg for i, arg in enumerate(call_args) if i not in equal_to_1
]
arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]
# extract the arg signatures from triton_meta
arg_signatures = triton_meta["signature"].values()
arg_signatures = [
v for i, v in enumerate(arg_signatures) if i not in equal_to_1
]
if triton_version_uses_attrs_dict():
arg_signatures = [val for key, val in triton_meta["signature"].items()]
# arg_signatures = [
# val
# for key, val in triton_meta["signature"].items()
# if val != "constexpr"
# ]
else:
# args with value 1 are added into equal_to_1 and constants
# in triton_meta (in the Python codegen) which makes them
# inlined in the PTX and compiled CUBIN
arg_signatures = []
if (
triton_meta is not None
and triton_meta.get("configs")
and triton_meta.get("signature")
):
equal_to_1 = triton_meta["configs"][0].equal_to_1
call_args = [
arg for i, arg in enumerate(call_args) if i not in equal_to_1
]
arg_types = [
t for i, t in enumerate(arg_types) if i not in equal_to_1
]
# extract the arg signatures from triton_meta
arg_signatures = triton_meta["signature"].values()
arg_signatures = [
v for i, v in enumerate(arg_signatures) if i not in equal_to_1
]
call_args_str = self.generate_args_decl(
call_args, arg_types, arg_signatures
)

View File

@ -5775,14 +5775,6 @@ class UserDefinedTritonKernel(ExternKernel):
self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
]
# NOTE: raw_args doesn't include autotuned args.
# But, kernel.constexprs includes indices of autotuned args.
# So, let's recalculate constexpr indices wrt to raw_args.
constexpr_indices = []
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
if kernel.arg_names.index(kwarg) in kernel.constexprs:
constexpr_indices.append(idx)
if not triton_version_uses_attrs_dict():
"""
Filter out None args.
@ -5793,6 +5785,14 @@ class UserDefinedTritonKernel(ExternKernel):
1. The arg is already tl.constexpr, so leave it in
2. The arg is not tl.constexpr so we have to remove it
"""
# NOTE: raw_args doesn't include autotuned args.
# But, kernel.constexprs includes indices of autotuned args.
# So, let's recalculate constexpr indices wrt to raw_args.
constexpr_indices = []
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
if kernel.arg_names.index(kwarg) in kernel.constexprs:
constexpr_indices.append(idx)
constexpr_indices_set = OrderedSet(constexpr_indices)
REMOVED = object()
raw_args = [
@ -5831,6 +5831,11 @@ class UserDefinedTritonKernel(ExternKernel):
equal_to_1.append(idx - index_shift)
triton_meta["configs"][0].equal_to_1 = equal_to_1
else:
constexpr_indices = []
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
if triton_meta["signature"][kwarg] == "constexpr":
constexpr_indices.append(idx)
# Call to kernel
self.codegen_comment(wrapper)