[aoti] Fix workspace generation for triton (#135552)

Fixes #131337

- Add an `arg_type` for `workspace_arg`; the type is consistent with the one used in `generate_workspace_allocation()`.
- Do not generate example tensors for `workspace`; use `generate_workspace_allocation()` instead.
- Add workspace allocation code to `kernel_autotune_calls`, e.g.
```python
    workspace = empty_strided_cuda((1280, ), (1, ), torch.uint8)
    workspace.zero_()
    .....
    triton_spl_fused_add_cumprod_0.run(buf2, arg0_1, arg1_1, workspace, 1, 10000, grid=split_scan_grid(1, 10000), stream=stream0)
    del buf2, arg0_1, arg1_1, workspace
```
- Add `empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda` to the header of the Triton autotune code.
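
Without this alias, the `empty_strided_cuda(...)` workspace allocation emitted into `kernel_autotune_calls` would not resolve when the autotune script runs. A minimal sketch of the relevant part of the generated autotune code (abridged; the real header contains more imports), reusing the sizes from the example above:

```python
import torch

# alias added to the autotune header by this PR
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda

# workspace allocation now emitted into the autotune calls as well
workspace = empty_strided_cuda((1280, ), (1, ), torch.uint8)
workspace.zero_()
```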

The generated C++ contains lines like the ones below, so we also implement `zero_()` for `AtenTensorHandle` (via a new `aoti_torch_zero_` shim function).

```cpp
    static constexpr int64_t int_array_0[] = {1280L, };
    static constexpr int64_t int_array_1[] = {1L, };
    AtenTensorHandle workspace_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_uint8, cached_torch_device_type_cuda,  0, &workspace_handle));

        RAIIAtenTensorHandle workspace(workspace_handle);
        workspace.zero_();
```

- Fix grid_fn handling for grid computation: pass `"RBLOCK"` in the block config handed to `split_scan_grid`, which needs it to compute the launch grid (see the sketch below).
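
A simplified sketch of why the grid callable for a split-scan kernel needs `RBLOCK`. This is an illustration, not the actual `triton_heuristics.split_scan_grid` implementation, and `RBLOCK=256` is an assumed launcher config value:

```python
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)

def split_scan_grid_sketch(xnumel: int, rnumel: int):
    # one program per (x element, RBLOCK-sized chunk of the scanned dimension)
    def grid(meta):
        return (xnumel * ceildiv(rnumel, meta["RBLOCK"]), 1, 1)
    return grid

# matching the example kernel above: split_scan_grid(1, 10000) with RBLOCK=256
print(split_scan_grid_sketch(1, 10000)({"RBLOCK": 256}))  # (40, 1, 1)
```
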
- Fix dynamic shapes: without the fix, the Triton autotune code contains lines like `workspace = empty_strided_cuda((32*((255 + s0) // 256), ), (1, ), torch.uint8)`, where `s0` is not defined in the autotune script.

The fix is to use `V.graph.sizevars.size_hint(nbytes)` to realize the workspace size for the Triton autotune code. Note that we only realize it for the Triton autotune code, not for the cpp CUDA wrapper code, which keeps the symbolic expression (see the second `abi_compatible` snippet below).
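
As a concrete illustration, assuming a size hint of `s0 = 10000` (matching the static example above; the real hint comes from the traced inputs), realizing the symbolic byte count reproduces the 1280-byte allocation shown earlier:

```python
# Illustration only: roughly what V.graph.sizevars.size_hint(...) evaluates the
# symbolic workspace size to when the hint for s0 is 10000.
s0_hint = 10000
nbytes = 32 * ((255 + s0_hint) // 256)
assert nbytes == 1280  # matches workspace = empty_strided_cuda((1280, ), ...) above
```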

- We also generate slightly different cpp code depending on whether `abi_compatible` is turned on. With `abi_compatible`:
```cpp
RAIIAtenTensorHandle workspace(workspace_handle);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get()));
```
vs. without:

```cpp
    at::Tensor workspace = at::detail::empty_strided_cuda({8L*(c10::div_floor_integer(static_cast<int64_t>((255L + s0)), static_cast<int64_t>(256L))), }, {1L, }, at::kByte, c10::DeviceType::CUDA);
    workspace.zero_();
```

Test Plan:

```
TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1  python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
python test/inductor/test_cuda_cpp_wrapper.py DynamicShapesCudaWrapperCudaTests.test_consecutive_split_cumprod_cuda_dynamic_shapes_cuda_wrapper
TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
TORCHINDUCTOR_CPP_WRAPPER=1  python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135552
Approved by: https://github.com/desertfire
Author: Shangdi Yu
Date: 2024-09-12 23:53:07 +00:00
Committed by: PyTorch MergeBot
Parent: 00dc7d4356
Commit: d383325392
9 changed files with 65 additions and 11 deletions


@@ -188,6 +188,7 @@ if RUN_CUDA:
BaseTest("test_sum_int"), # bool, int64, int8, uint8
BaseTest("test_transpose"), # multiple outputs, buffer clear
BaseTest("test_unspec_inputs"),
BaseTest("test_consecutive_split_cumprod"),
BaseTest("test_pointwise_hermite_polynomial_he"),
BaseTest("test_pointwise_hermite_polynomial_h"),
BaseTest(


@@ -1420,6 +1420,7 @@ class KernelArgs:
arg_defs.append("ws_ptr")
call_args.append("workspace")
precompile_args.append(self.workspace_arg)
arg_types.append(torch.uint8)
return arg_defs, call_args, precompile_args, arg_types
def aliases(self):


@@ -6,9 +6,9 @@ from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union
import sympy
from torch import dtype as torch_dtype
from torch import dtype as torch_dtype, uint8
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
from torch._inductor.runtime.triton_heuristics import grid as default_grid
from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
from .. import config
from ..codecache import CudaKernelParamCache
@@ -88,11 +88,11 @@ class DeferredGpuDefaultGrid:
grid = self.grid
assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
grid = self._process_grid(grid)
grid_callable = self.grid_callable or default_grid
assert self.grid_callable is not None, "grid_callable can't be None"
if not self.grid_extra_kwargs:
grid_fn = grid_callable(*grid)
grid_fn = self.grid_callable(*grid)
else:
grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)
grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)
params = CudaKernelParamCache.get(self.kernel_name)
assert (
@@ -102,6 +102,7 @@ class DeferredGpuDefaultGrid:
"XBLOCK": params["x_block"],
"YBLOCK": params["y_block"],
"ZBLOCK": params["z_block"],
"RBLOCK": params["r_block"],
}
return grid_fn(block_cfg)
@@ -338,7 +339,7 @@ class CppWrapperGpu(CppWrapperCpu):
kernel_name: str,
grid: List[Any],
gpu: bool = True,
grid_callable: Optional[Callable[..., Any]] = None,
grid_callable: Optional[Callable[..., Any]] = default_grid_fn,
**grid_extra_kwargs,
):
"""
@@ -440,3 +441,22 @@ class CppWrapperGpu(CppWrapperCpu):
),
)
self.writeline("}")
def generate_workspace_allocation(self, nbytes, device, zero_fill):
line = self.make_allocation(
"workspace", device, uint8, shape=(nbytes,), stride=(1,)
)
self.writeline(line)
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.writeline(line)
if zero_fill:
if config.abi_compatible:
# TODO: remove this function to use the default WrapperCodegen behavior after service platform has zero_() symbol
# default behavior is f"workspace.zero_(){self.ending}"
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get())){self.ending}"
)
else:
self.writeline(f"workspace.zero_(){self.ending}")
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}")


@@ -27,6 +27,7 @@ import torch
import torch._logging
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
@@ -2797,9 +2798,12 @@ class TritonKernel(SIMDKernel):
if tree.prefix == "x" and self.no_x_dim:
code.writeline("XBLOCK: tl.constexpr = 1")
def _get_grid_fn(self):
def _get_grid_fn_str(self):
return "grid"
def _get_grid_fn(self):
return default_grid_fn
def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
# TODO(jansel): if there are constants, we shouldn't bother passing them as args
for tree in self.range_trees:
@@ -2828,7 +2832,9 @@ class TritonKernel(SIMDKernel):
ws.nbytes, current_device, ws.zero_fill
)
grid = wrapper.generate_default_grid(name, grid)
grid = wrapper.generate_default_grid(
name, grid, grid_callable=self._get_grid_fn()
)
wrapper.generate_kernel_call(
name,
call_args,
@@ -2837,7 +2843,7 @@ class TritonKernel(SIMDKernel):
gpu=True,
triton=True,
arg_types=arg_types,
grid_fn=self._get_grid_fn(),
grid_fn=self._get_grid_fn_str(),
triton_meta=self.triton_meta,
)


@@ -6,6 +6,7 @@ import torch._inductor.runtime.hints
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._inductor.runtime.triton_heuristics import split_scan_grid
from torch._prims_common import prod
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv
@@ -170,5 +171,8 @@ class TritonSplitScanKernel(TritonKernel):
def _get_heuristic(self):
return "split_scan"
def _get_grid_fn(self):
def _get_grid_fn_str(self):
return "split_scan_grid"
def _get_grid_fn(self):
return split_scan_grid


@@ -595,6 +595,7 @@ class WrapperCodeGen(CodeGen):
async_compile = AsyncCompile()
generate_example_value = AlgorithmSelectorCache.generate_example_value
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
"""
)
@@ -1496,12 +1497,18 @@ class WrapperCodeGen(CodeGen):
return SymbolicCallArg(expr, tree.numel)
def generate_workspace_allocation(self, nbytes, device, zero_fill):
if isinstance(nbytes, sympy.Expr):
nbytes = V.graph.sizevars.size_hint(nbytes)
line = self.make_allocation(
"workspace", device, torch.uint8, shape=(nbytes,), stride=(1,)
)
self.writeline(line)
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.writeline(line)
if zero_fill:
self.writeline(f"workspace.zero_(){self.ending}")
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}")
def wrap_kernel_call(self, name, call_args):
return f"{name}({', '.join(call_args)}){self.ending}"
@@ -1720,7 +1727,12 @@ class WrapperCodeGen(CodeGen):
key, arg = arg.split("=")
if isinstance(arg_type, torch_dtype):
if arg not in tensor_args:
# workspace allocation is already generated by `generate_workspace_allocation()`
# in `TritonKernel.call_kernel()`.
if arg == "workspace":
arg_str = "workspace"
tensor_args[arg] = arg_str
elif arg not in tensor_args:
arg_str = self.generate_example_arg_value(
arg, arg_type, raw_arg, i
)


@@ -753,6 +753,7 @@ class CachingAutotuner(KernelInterface):
"x_block": launcher.config.kwargs.get("XBLOCK", 1),
"y_block": launcher.config.kwargs.get("YBLOCK", None),
"z_block": launcher.config.kwargs.get("ZBLOCK", None),
"r_block": launcher.config.kwargs.get("RBLOCK", None),
"num_warps": launcher.bin.num_warps
if hasattr(launcher.bin, "num_warps")
else launcher.bin.metadata.num_warps,


@@ -528,6 +528,8 @@ aoti_torch_cpu__wrapped_quantized_linear_prepacked(
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
AtenTensorHandle repeats,
int64_t* output_size,


@@ -1166,3 +1166,10 @@ AOTITorchError aoti_torch__alloc_from_pool(
strides));
});
}
AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
t->zero_();
});
}