Compare commits

...

1 Commits

Author SHA1 Message Date
d2f5c6f701 save 2025-02-20 11:11:09 -08:00
2 changed files with 60 additions and 30 deletions

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import os
from itertools import chain, count, zip_longest
from typing import Any, Callable, Hashable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import sympy
@ -13,7 +13,13 @@ from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
from .. import config
from ..codecache import CudaKernelParamCache
from ..ir import IRNode, TensorBox
from ..utils import cache_on_self, DeferredLineBase, get_gpu_type, GPU_ALIGN_BYTES
from ..utils import (
cache_on_self,
DeferredLineBase,
get_gpu_type,
GPU_ALIGN_BYTES,
triton_version_uses_attrs_dict,
)
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import get_device_op_overrides
@ -24,6 +30,8 @@ from .wrapper import PythonWrapperCodegen, SymbolicCallArg
if TYPE_CHECKING:
from collections.abc import Hashable
from ..graph import GraphLowering
@ -553,27 +561,43 @@ class CppWrapperGpu(CppWrapperCpu):
device_index, call_args
)
kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)
# args with value 1 are added into equal_to_1 and constants
# in triton_meta (in the Python codegen) which makes them
# inlined in the PTX and compiled CUBIN
arg_signatures = []
if (
triton_meta is not None
and triton_meta.get("configs")
and triton_meta.get("signature")
):
equal_to_1 = triton_meta["configs"][0].equal_to_1
call_args = [
arg for i, arg in enumerate(call_args) if i not in equal_to_1
]
arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]
# extract the arg signatures from triton_meta
arg_signatures = triton_meta["signature"].values()
if triton_version_uses_attrs_dict():
signature = triton_meta["signature"]
arg_signatures = [
v for i, v in enumerate(arg_signatures) if i not in equal_to_1
val for val in signature.values() if val != "constexpr"
]
call_args = [
call_arg
for call_arg, arg_name in zip(call_args, signature)
if signature[arg_name] != "constexpr"
]
arg_types = [
arg_type
for arg_type, arg_name in zip(arg_types, signature)
if signature[arg_name] != "constexpr"
]
else:
# args with value 1 are added into equal_to_1 and constants
# in triton_meta (in the Python codegen) which makes them
# inlined in the PTX and compiled CUBIN
arg_signatures = []
if (
triton_meta is not None
and triton_meta.get("configs")
and triton_meta.get("signature")
):
equal_to_1 = triton_meta["configs"][0].equal_to_1
call_args = [
arg for i, arg in enumerate(call_args) if i not in equal_to_1
]
arg_types = [
t for i, t in enumerate(arg_types) if i not in equal_to_1
]
# extract the arg signatures from triton_meta
arg_signatures = triton_meta["signature"].values()
arg_signatures = [
v for i, v in enumerate(arg_signatures) if i not in equal_to_1
]
call_args_str = self.generate_args_decl(
call_args, arg_types, arg_signatures
)

View File

@ -5775,14 +5775,6 @@ class UserDefinedTritonKernel(ExternKernel):
self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
]
# NOTE: raw_args doesn't include autotuned args.
# But, kernel.constexprs includes indices of autotuned args.
# So, let's recalculate constexpr indices wrt to raw_args.
constexpr_indices = []
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
if kernel.arg_names.index(kwarg) in kernel.constexprs:
constexpr_indices.append(idx)
if not triton_version_uses_attrs_dict():
"""
Filter out None args.
@ -5793,7 +5785,16 @@ class UserDefinedTritonKernel(ExternKernel):
1. The arg is already tl.constexpr, so leave it in
2. The arg is not tl.constexpr so we have to remove it
"""
constexpr_indices_set = OrderedSet(constexpr_indices)
# NOTE: raw_args doesn't include autotuned args.
# But, kernel.constexprs includes indices of autotuned args.
# So, let's recalculate constexpr indices wrt to raw_args.
constexpr_indices = []
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
if kernel.arg_names.index(kwarg) in kernel.constexprs:
constexpr_indices.append(idx)
constexpr_indices_set = OrderedSet(constexpr_indices)
REMOVED = object()
raw_args = [
(
@ -5831,6 +5832,11 @@ class UserDefinedTritonKernel(ExternKernel):
equal_to_1.append(idx - index_shift)
triton_meta["configs"][0].equal_to_1 = equal_to_1
else:
constexpr_indices = []
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
if triton_meta["signature"][kwarg] == "constexpr":
constexpr_indices.append(idx)
# Call to kernel
self.codegen_comment(wrapper)