Compare commits

1 Commit

297250166b [inductor] getting AOT inductor to treat None args correctly
linter

remove import

ghstack-source-id: 72ceaf4a8e8c5bb2c465cf293c1e436876186645
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138910

lint

address feedback

lint

nit
2024-10-28 14:46:45 -07:00
4 changed files with 99 additions and 7 deletions

View File

@@ -49,6 +49,7 @@ from torch.utils import _pytree as pytree
 if HAS_CUDA:
     import triton  # @manual
+    from triton import language as tl
 
     from torch.testing._internal.triton_utils import (
         add_kernel,
@@ -3696,6 +3697,56 @@ class AOTInductorTestsTemplate:
         self.check_model(Model(), example_inputs)
 
+    def test_none_args_aot_codegen(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        @triton.autotune(
+            configs=[
+                triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
+                triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+            ],
+            key=["n_elements"],
+        )
+        @triton.jit
+        def sin_kernel(
+            in_ptr0,
+            out_ptr,
+            # We want to include an arg known to be 1 at compile time
+            # This is because we remove None args from the arg list; changing the eq_1/constexpr arg indices.
+            # We want to make sure we recompute these correctly
+            EQ_1_ARG,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            if in_ptr0 is not None:
+                x = tl.load(in_ptr0 + offsets, mask=mask)
+            else:
+                x = 0.0
+            output = tl.sin(x) + EQ_1_ARG
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def sin_triton(x, out):
+            n_elements = out.numel()
+            sin_kernel[(n_elements,)](x, out, 1, n_elements)
+            return out
+
+        x = torch.randn(65, device="cuda")
+        out = torch.empty_like(x)
+        not_none_inputs = (x, out)
+        none_inputs = (None, out)
+
+        # AOTI compilation specializes on either None or non-None inputs
+        # So we have to check twice here
+        self.check_model(sin_triton, none_inputs)
+        self.check_model(sin_triton, not_none_inputs)
+
 class AOTInductorLoggingTest(LoggingTestCase):
     @make_logging_test(dynamic=logging.DEBUG)
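
The comment at the end of the new test notes that AOTI compiles a separate specialization depending on whether in_ptr0 is None, which is why check_model runs twice. Below is a minimal sketch of exercising those two specializations, assuming sin_kernel and sin_triton are defined at module scope exactly as in the test body above; substituting the JIT torch.compile path for the AOTI check_model harness is my assumption, not part of the PR.

    # Sketch only: drive the None / non-None specializations through torch.compile
    # instead of AOTI's check_model. Assumes sin_kernel / sin_triton from the test above.
    import torch

    compiled = torch.compile(sin_triton, fullgraph=True)

    x = torch.randn(65, device="cuda")
    out = torch.empty_like(x)

    compiled(x, out)  # specialization 1: in_ptr0 is a tensor
    torch.testing.assert_close(out, torch.sin(x) + 1)

    out.zero_()
    compiled(None, out)  # specialization 2: in_ptr0 is None, guard change forces a recompile
    torch.testing.assert_close(out, torch.full_like(x, 1.0))  # sin(0.0) + EQ_1_ARG == 1.0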

View File

@@ -836,7 +836,11 @@ class PythonWrapperCodegen(CodeGen):
             for arg in raw_args
         ]
         self.generate_kernel_call(
-            kernel_name, args, grid_fn=grid_fn, arg_types=arg_types, raw_args=raw_args
+            kernel_name,
+            args,
+            grid_fn=grid_fn,
+            arg_types=arg_types,
+            raw_args=raw_args,
         )
 
     def generate_tma_descriptor(self, desc):

View File

@@ -5414,17 +5414,51 @@ class UserDefinedTritonKernel(ExternKernel):
         2. The arg is not tl.constexpr so we have to remove it
         """
         constexpr_indices_set = set(constexpr_indices)
+        REMOVED = object()
         raw_args = [
-            arg
-            for idx, arg in enumerate(raw_args)
+            (idx, arg)
             if (arg is not None) or (arg is None and idx in constexpr_indices_set)
+            else (idx, REMOVED)
+            for idx, arg in enumerate(raw_args)
         ]
+        removed_none_args = [idx for idx, val in raw_args if val == REMOVED]
+        raw_args = list(filter(lambda tup: tup[1] != REMOVED, raw_args))
+        raw_args = [val for idx, val in raw_args]
+
+        # We have to compute the constexpr indices for the new, filtered raw_args
+        # We also have to adjust equal_to_1.
+        eq1_indices_set = set(triton_meta["configs"][0].equal_to_1)
+        constexpr_indices = []
+        equal_to_1 = []
+        index_shift = 0
+        for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
+            # every time we encounter an idx we removed, adjust by one to account for it
+            # So for example if we had [None, const X]
+            # iter 1:
+            #   None was removed, adjust=1
+            # iter 2:
+            #   X is const at idx=1, but the adjusted idx is 0 now, because None was removed
+            if idx in removed_none_args:
+                index_shift += 1
+                continue
+            arg_index = kernel.arg_names.index(kwarg)
+            if arg_index in kernel.constexprs:
+                constexpr_indices.append(idx - index_shift)
+            if arg_index in eq1_indices_set:
+                equal_to_1.append(idx - index_shift)
+        triton_meta["configs"][0].equal_to_1 = equal_to_1
+
         # Call to kernel
         self.codegen_comment(wrapper)
         wrapper.generate_user_defined_triton_kernel(
-            new_name, raw_args, self.grid, configs, triton_meta, constexpr_indices
+            new_name,
+            raw_args,
+            self.grid,
+            configs,
+            triton_meta,
+            constexpr_indices,
         )
 
     def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
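
The loop above re-derives constexpr_indices and equal_to_1 against the filtered argument list. Below is a toy, self-contained sketch of that index-shift bookkeeping, replaying the [None, const X] example from the inline comment; all names in it are hypothetical and not part of the PR.

    # Toy illustration of the index_shift logic in UserDefinedTritonKernel above.
    # All names here (shift_indices, constexpr_positions, ...) are hypothetical.
    def shift_indices(raw_args, constexpr_positions, equal_to_1_positions):
        # Drop None args that are not constexpr, then remap the surviving indices.
        removed = [i for i, a in enumerate(raw_args)
                   if a is None and i not in constexpr_positions]
        new_constexpr, new_equal_to_1, shift = [], [], 0
        for i, _ in enumerate(raw_args):
            if i in removed:
                shift += 1  # every removed arg shifts later indices left by one
                continue
            if i in constexpr_positions:
                new_constexpr.append(i - shift)
            if i in equal_to_1_positions:
                new_equal_to_1.append(i - shift)
        kept = [a for i, a in enumerate(raw_args) if i not in removed]
        return kept, new_constexpr, new_equal_to_1

    # [None, X] where X is equal-to-1: after dropping None, X moves from index 1 to index 0.
    print(shift_indices([None, 1], constexpr_positions={1}, equal_to_1_positions={1}))
    # -> ([1], [0], [0])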

View File

@@ -499,11 +499,14 @@ class CachingAutotuner(KernelInterface):
            so we use self.fn.constexprs instead.
         3. It isn't in the compile_meta signature
         """
-        none_args = set(compile_meta["constants"].keys())
         known_constants = {
             arg for i, arg in enumerate(self.fn.arg_names) if i in self.fn.constexprs
         }
-        none_args = none_args.difference(known_constants)
+        none_args = {
+            k
+            for k, v in compile_meta["constants"].items()
+            if v is None and k not in known_constants
+        }
         none_args = none_args.difference(set(compile_meta["signature"].keys()))
 
         call_args = [
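
The new comprehension narrows none_args to constants whose value is literally None, excluding known tl.constexpr parameters and anything still present in the compiled signature; the old set(compile_meta["constants"].keys()) could also pick up non-None folded constants. Below is a standalone sketch with made-up metadata for a kernel shaped like sin_kernel; the dict contents are illustrative only, not taken from the PR.

    # Standalone sketch: which compile-time constants count as removable "None args"?
    # All metadata below is invented for illustration.
    arg_names = ["in_ptr0", "out_ptr", "EQ_1_ARG", "n_elements", "BLOCK_SIZE"]
    constexprs = [4]                                                # BLOCK_SIZE is tl.constexpr
    constants = {"in_ptr0": None, "EQ_1_ARG": 1, "BLOCK_SIZE": 64}  # values Triton folded in
    signature = {"out_ptr": "*fp32", "n_elements": "i32"}           # args kept in the launch

    known_constants = {arg for i, arg in enumerate(arg_names) if i in constexprs}

    # Only values that are literally None qualify; constexprs and signature args are kept.
    none_args = {k for k, v in constants.items() if v is None and k not in known_constants}
    none_args -= set(signature.keys())

    print(none_args)  # {'in_ptr0'}: only the genuinely-None pointer arg is dropped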