diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py
index 9b8b80206559..241bad352ccb 100644
--- a/test/inductor/test_cuda_cpp_wrapper.py
+++ b/test/inductor/test_cuda_cpp_wrapper.py
@@ -188,6 +188,7 @@ if RUN_CUDA:
         BaseTest("test_sum_int"),  # bool, int64, int8, uint8
         BaseTest("test_transpose"),  # multiple outputs, buffer clear
         BaseTest("test_unspec_inputs"),
+        BaseTest("test_consecutive_split_cumprod"),
         BaseTest("test_pointwise_hermite_polynomial_he"),
         BaseTest("test_pointwise_hermite_polynomial_h"),
         BaseTest(
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 95183c1484fe..02f13667b70c 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -1420,6 +1420,7 @@ class KernelArgs:
             arg_defs.append("ws_ptr")
             call_args.append("workspace")
             precompile_args.append(self.workspace_arg)
+            arg_types.append(torch.uint8)
         return arg_defs, call_args, precompile_args, arg_types
 
     def aliases(self):
diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py
index 5719d3eba589..13f13f44be9e 100644
--- a/torch/_inductor/codegen/cpp_wrapper_gpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py
@@ -6,9 +6,9 @@ from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import sympy
 
-from torch import dtype as torch_dtype
+from torch import dtype as torch_dtype, uint8
 from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
-from torch._inductor.runtime.triton_heuristics import grid as default_grid
+from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
 
 from .. import config
 from ..codecache import CudaKernelParamCache
@@ -88,11 +88,11 @@ class DeferredGpuDefaultGrid:
         grid = self.grid
         assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
         grid = self._process_grid(grid)
-        grid_callable = self.grid_callable or default_grid
+        assert self.grid_callable is not None, "grid_callable can't be None"
         if not self.grid_extra_kwargs:
-            grid_fn = grid_callable(*grid)
+            grid_fn = self.grid_callable(*grid)
         else:
-            grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)
+            grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)
 
         params = CudaKernelParamCache.get(self.kernel_name)
         assert (
@@ -102,6 +102,7 @@ class DeferredGpuDefaultGrid:
             "XBLOCK": params["x_block"],
             "YBLOCK": params["y_block"],
             "ZBLOCK": params["z_block"],
+            "RBLOCK": params["r_block"],
         }
         return grid_fn(block_cfg)
 
@@ -338,7 +339,7 @@ class CppWrapperGpu(CppWrapperCpu):
         kernel_name: str,
         grid: List[Any],
         gpu: bool = True,
-        grid_callable: Optional[Callable[..., Any]] = None,
+        grid_callable: Optional[Callable[..., Any]] = default_grid_fn,
         **grid_extra_kwargs,
     ):
         """
@@ -440,3 +441,22 @@ class CppWrapperGpu(CppWrapperCpu):
                 ),
             )
             self.writeline("}")
+
+    def generate_workspace_allocation(self, nbytes, device, zero_fill):
+        line = self.make_allocation(
+            "workspace", device, uint8, shape=(nbytes,), stride=(1,)
+        )
+        self.writeline(line)
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_calls.writeline(line)
+        if zero_fill:
+            if config.abi_compatible:
+                # TODO: remove this function to use the default WrapperCodegen behavior once the service platform has the zero_() symbol
+                # default behavior is f"workspace.zero_(){self.ending}"
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get())){self.ending}"
+                )
+            else:
+                self.writeline(f"workspace.zero_(){self.ending}")
+            if config.triton.autotune_at_compile_time:
+                self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}")
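For illustration, a minimal sketch (not code from this patch; the helper names are made up) of why the cached kernel params now carry "r_block": a split-scan style grid callable sizes its launch grid from RBLOCK, so the deferred grid resolution in DeferredGpuDefaultGrid.__call__ has to be able to hand it over along with the X/Y/Z block sizes.

# Hedged sketch, not PR code: mimics how the deferred grid is resolved from
# the cached kernel params. `fake_split_scan_grid` only imitates a grid
# callable that, like split_scan_grid, depends on RBLOCK.
def ceildiv(a, b):
    return -(a // -b)

def fake_split_scan_grid(xnumel, rnumel):
    def grid_fn(meta):
        # The launch grid depends on RBLOCK, which is why "r_block" must now
        # be cached next to the x/y/z block sizes.
        return (ceildiv(rnumel, meta["RBLOCK"]), xnumel, 1)
    return grid_fn

params = {"x_block": 1, "y_block": None, "z_block": None, "r_block": 2048}
block_cfg = {
    "XBLOCK": params["x_block"],
    "YBLOCK": params["y_block"],
    "ZBLOCK": params["z_block"],
    "RBLOCK": params["r_block"],
}
print(fake_split_scan_grid(1, 2**20)(block_cfg))  # (512, 1, 1)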
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 629b6286b224..25b267dbc3ee 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -27,6 +27,7 @@ import torch
 import torch._logging
 from torch._dynamo.utils import preserve_rng_state
 from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
+from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
 from torch._prims_common import is_integer_dtype
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
@@ -2797,9 +2798,12 @@ class TritonKernel(SIMDKernel):
             if tree.prefix == "x" and self.no_x_dim:
                 code.writeline("XBLOCK: tl.constexpr = 1")
 
-    def _get_grid_fn(self):
+    def _get_grid_fn_str(self):
         return "grid"
 
+    def _get_grid_fn(self):
+        return default_grid_fn
+
     def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
         # TODO(jansel): if there are constants, we shouldn't bother passing them as args
         for tree in self.range_trees:
@@ -2828,7 +2832,9 @@ class TritonKernel(SIMDKernel):
                 ws.nbytes, current_device, ws.zero_fill
             )
 
-        grid = wrapper.generate_default_grid(name, grid)
+        grid = wrapper.generate_default_grid(
+            name, grid, grid_callable=self._get_grid_fn()
+        )
         wrapper.generate_kernel_call(
             name,
             call_args,
@@ -2837,7 +2843,7 @@ class TritonKernel(SIMDKernel):
             gpu=True,
             triton=True,
             arg_types=arg_types,
-            grid_fn=self._get_grid_fn(),
+            grid_fn=self._get_grid_fn_str(),
             triton_meta=self.triton_meta,
         )
 
diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py
index 9eea00fbb8d6..31dab3992364 100644
--- a/torch/_inductor/codegen/triton_split_scan.py
+++ b/torch/_inductor/codegen/triton_split_scan.py
@@ -6,6 +6,7 @@ import torch._inductor.runtime.hints
 from torch._inductor import config
 from torch._inductor.codegen.simd import IterationRangesRoot
 from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
+from torch._inductor.runtime.triton_heuristics import split_scan_grid
 from torch._prims_common import prod
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv
@@ -170,5 +171,8 @@ class TritonSplitScanKernel(TritonKernel):
     def _get_heuristic(self):
         return "split_scan"
 
-    def _get_grid_fn(self):
+    def _get_grid_fn_str(self):
         return "split_scan_grid"
+
+    def _get_grid_fn(self):
+        return split_scan_grid
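A minimal sketch of the string-vs-callable split introduced above, with made-up names; it only illustrates why both hooks exist after this change: the string is spliced into the generated Python wrapper source, while the callable is passed to generate_default_grid() so the C++ wrapper can evaluate a concrete launch grid at compile time.

# Hedged illustration, not PR code. `my_grid` and `FakeKernel` are stand-ins.
def my_grid(xnumel):  # imitates a triton_heuristics-style grid factory
    def grid_fn(meta):
        return ((xnumel + meta["XBLOCK"] - 1) // meta["XBLOCK"], 1, 1)
    return grid_fn

class FakeKernel:
    def _get_grid_fn_str(self) -> str:
        # Name emitted verbatim into the generated wrapper source.
        return "my_grid"

    def _get_grid_fn(self):
        # Callable used eagerly once the tuned block sizes are known.
        return my_grid

k = FakeKernel()
src_line = f"kernel.run(args, grid={k._get_grid_fn_str()}(8192), stream=stream0)"
launch_grid = k._get_grid_fn()(8192)({"XBLOCK": 256})
print(src_line)     # generated-source flavour
print(launch_grid)  # (32, 1, 1), computed at compile time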
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 19e4bb293c7e..5422198194bf 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -595,6 +595,7 @@ class WrapperCodeGen(CodeGen):
 
                 async_compile = AsyncCompile()
                 generate_example_value = AlgorithmSelectorCache.generate_example_value
+                empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
             """
         )
 
@@ -1496,12 +1497,18 @@ class WrapperCodeGen(CodeGen):
         return SymbolicCallArg(expr, tree.numel)
 
     def generate_workspace_allocation(self, nbytes, device, zero_fill):
+        if isinstance(nbytes, sympy.Expr):
+            nbytes = V.graph.sizevars.size_hint(nbytes)
         line = self.make_allocation(
             "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,)
         )
         self.writeline(line)
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_calls.writeline(line)
         if zero_fill:
             self.writeline(f"workspace.zero_(){self.ending}")
+            if config.triton.autotune_at_compile_time:
+                self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}")
 
     def wrap_kernel_call(self, name, call_args):
         return f"{name}({', '.join(call_args)}){self.ending}"
@@ -1720,7 +1727,12 @@ class WrapperCodeGen(CodeGen):
                     key, arg = arg.split("=")
 
                 if isinstance(arg_type, torch_dtype):
-                    if arg not in tensor_args:
+                    # workspace allocation is already generated by `generate_workspace_allocation()`
+                    # in `TritonKernel.call_kernel()`.
+                    if arg == "workspace":
+                        arg_str = "workspace"
+                        tensor_args[arg] = arg_str
+                    elif arg not in tensor_args:
                         arg_str = self.generate_example_arg_value(
                             arg, arg_type, raw_arg, i
                         )
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index bdc6a1e10bca..ce5ac0b920e5 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -753,6 +753,7 @@ class CachingAutotuner(KernelInterface):
                 "x_block": launcher.config.kwargs.get("XBLOCK", 1),
                 "y_block": launcher.config.kwargs.get("YBLOCK", None),
                 "z_block": launcher.config.kwargs.get("ZBLOCK", None),
+                "r_block": launcher.config.kwargs.get("RBLOCK", None),
                 "num_warps": launcher.bin.num_warps
                 if hasattr(launcher.bin, "num_warps")
                 else launcher.bin.metadata.num_warps,
diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h
index b470b5f10061..c204fc27460b 100644
--- a/torch/csrc/inductor/aoti_torch/c/shim.h
+++ b/torch/csrc/inductor/aoti_torch/c/shim.h
@@ -528,6 +528,8 @@ aoti_torch_cpu__wrapped_quantized_linear_prepacked(
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out);
 
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self);
+
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
     AtenTensorHandle repeats,
     int64_t* output_size,
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index f49bf23b9ce4..2860972f88da 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -1166,3 +1166,10 @@ AOTITorchError aoti_torch__alloc_from_pool(
         strides));
   });
 }
+
+AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
+    t->zero_();
+  });
+}
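A hedged usage sketch of the end-to-end path these changes enable (this is not the enabled test itself). It requires a CUDA device, and the tensor size needed to actually hit the split-scan lowering is an assumption rather than something taken from the test; the function and variable names are made up.

# Hedged sketch: compile a consecutive cumprod with the C++ wrapper, roughly
# what test_consecutive_split_cumprod exercises under cpp wrapper codegen.
import torch

@torch.compile(options={"cpp_wrapper": True})
def consecutive_cumprod(x):
    # Two back-to-back scans exercise the workspace allocation path twice.
    return x.cumprod(0).cumprod(0)

# Values near 1.0 keep the running product numerically tame; the long 1-D
# scan is intended (but not guaranteed here) to trigger split-scan kernels.
x = torch.rand(2**20, device="cuda") * 0.01 + 0.995
print(consecutive_cumprod(x)[:4])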