Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[aoti] Fix workspace generation for triton (#135552)
Fixes #131337

- Add an `arg_type` for `workspace_arg`; the type is consistent with the type used in `generate_workspace_allocation()`.
- Do not generate example tensors for `workspace`; use `generate_workspace_allocation()` instead.
- Add workspace allocation generation code to `kernel_autotune_calls`, e.g.

  ```python
  workspace = empty_strided_cuda((1280, ), (1, ), torch.uint8)
  workspace.zero_()
  .....
  triton_spl_fused_add_cumprod_0.run(buf2, arg0_1, arg1_1, workspace, 1, 10000, grid=split_scan_grid(1, 10000), stream=stream0)
  del buf2, arg0_1, arg1_1, workspace
  ```

- Add `empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda` to the header of the Triton autotune code.
- The generated cpp has lines like the ones below, so we also implement a `zero_()` for `AtenTensorHandle`:

  ```cpp
  static constexpr int64_t int_array_0[] = {1280L, };
  static constexpr int64_t int_array_1[] = {1L, };
  AtenTensorHandle workspace_handle;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_uint8, cached_torch_device_type_cuda, 0, &workspace_handle));
  RAIIAtenTensorHandle workspace(workspace_handle);
  workspace.zero_();
  ```

- Fix grid_fn handling for grid computation: pass "RBLOCK" through to `split_scan_grid` (a small illustrative sketch follows below).
- Fix dynamic shapes: without the fix, the Triton autotune code looks like `workspace = empty_strided_cuda((32*((255 + s0) // 256), ), (1, ), torch.uint8)`, where `s0` is not defined. The fix is to use `V.graph.sizevars.size_hint(nbytes)` to realize the workspace size for Triton autotuning. Note that we only realize it for the Triton autotune code, not for the cpp CUDA code.
- We also generate slightly different cpp code depending on whether `abi_compatible` is turned on:

  ```cpp
  RAIIAtenTensorHandle workspace(workspace_handle);
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get()));
  ```

  vs.

  ```cpp
  at::Tensor workspace = at::detail::empty_strided_cuda({8L*(c10::div_floor_integer(static_cast<int64_t>((255L + s0)), static_cast<int64_t>(256L))), }, {1L, }, at::kByte, c10::DeviceType::CUDA);
  workspace.zero_();
  ```

Test Plan:

```
TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
python test/inductor/test_cuda_cpp_wrapper.py DynamicShapesCudaWrapperCudaTests.test_consecutive_split_cumprod_cuda_dynamic_shapes_cuda_wrapper
TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135552
Approved by: https://github.com/desertfire
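The "RBLOCK" plumbing is easiest to see with a small example. Below is a minimal sketch (not the actual Inductor source; `ceildiv` and `split_scan_grid_sketch` are illustrative names) of a split-scan style grid callable: the launch grid depends on the reduction block size, so the block config that `DeferredGpuDefaultGrid` rebuilds from `CudaKernelParamCache` must carry an `"RBLOCK"` entry for the grid function to be evaluated at wrapper-codegen time.

```python
# Minimal sketch, assuming a split-scan grid callable follows the same shape as
# Inductor's grid helpers: a factory that captures the problem size and returns
# a function mapping a block-size config to a CUDA launch grid.
def ceildiv(a: int, b: int) -> int:
    return -(a // -b)


def split_scan_grid_sketch(xnumel: int, rnumel: int):
    def grid_fn(meta):
        # The reduction dimension is tiled by RBLOCK, so the first grid
        # dimension cannot be computed unless the cached kernel params
        # (x_block / y_block / z_block / r_block) also include RBLOCK.
        assert meta.get("XBLOCK", 1) == 1
        return (ceildiv(rnumel, meta["RBLOCK"]), xnumel, 1)

    return grid_fn


# e.g. a block config rebuilt from the cached kernel params might look like:
block_cfg = {"XBLOCK": 1, "YBLOCK": None, "ZBLOCK": None, "RBLOCK": 256}
print(split_scan_grid_sketch(1, 10000)(block_cfg))  # (40, 1, 1)
```

This is why the change records `"r_block"` in the params saved by `CachingAutotuner` and adds `"RBLOCK"` to `block_cfg` in `DeferredGpuDefaultGrid`.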
Committed by: PyTorch MergeBot
Parent: 00dc7d4356
Commit: d383325392
```diff
@@ -188,6 +188,7 @@ if RUN_CUDA:
         BaseTest("test_sum_int"),  # bool, int64, int8, uint8
         BaseTest("test_transpose"),  # multiple outputs, buffer clear
         BaseTest("test_unspec_inputs"),
+        BaseTest("test_consecutive_split_cumprod"),
         BaseTest("test_pointwise_hermite_polynomial_he"),
         BaseTest("test_pointwise_hermite_polynomial_h"),
         BaseTest(
```
```diff
@@ -1420,6 +1420,7 @@ class KernelArgs:
             arg_defs.append("ws_ptr")
             call_args.append("workspace")
             precompile_args.append(self.workspace_arg)
+            arg_types.append(torch.uint8)
         return arg_defs, call_args, precompile_args, arg_types

     def aliases(self):
```
```diff
@@ -6,9 +6,9 @@ from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union

 import sympy

-from torch import dtype as torch_dtype
+from torch import dtype as torch_dtype, uint8
 from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
-from torch._inductor.runtime.triton_heuristics import grid as default_grid
+from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn

 from .. import config
 from ..codecache import CudaKernelParamCache
```
```diff
@@ -88,11 +88,11 @@ class DeferredGpuDefaultGrid:
         grid = self.grid
         assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
         grid = self._process_grid(grid)
-        grid_callable = self.grid_callable or default_grid
+        assert self.grid_callable is not None, "grid_callable can't be None"
         if not self.grid_extra_kwargs:
-            grid_fn = grid_callable(*grid)
+            grid_fn = self.grid_callable(*grid)
         else:
-            grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)
+            grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)

         params = CudaKernelParamCache.get(self.kernel_name)
         assert (
```
```diff
@@ -102,6 +102,7 @@ class DeferredGpuDefaultGrid:
             "XBLOCK": params["x_block"],
             "YBLOCK": params["y_block"],
             "ZBLOCK": params["z_block"],
+            "RBLOCK": params["r_block"],
         }
         return grid_fn(block_cfg)

```
```diff
@@ -338,7 +339,7 @@ class CppWrapperGpu(CppWrapperCpu):
         kernel_name: str,
         grid: List[Any],
         gpu: bool = True,
-        grid_callable: Optional[Callable[..., Any]] = None,
+        grid_callable: Optional[Callable[..., Any]] = default_grid_fn,
         **grid_extra_kwargs,
     ):
         """
```
```diff
@@ -440,3 +441,22 @@ class CppWrapperGpu(CppWrapperCpu):
                 ),
             )
             self.writeline("}")
+
+    def generate_workspace_allocation(self, nbytes, device, zero_fill):
+        line = self.make_allocation(
+            "workspace", device, uint8, shape=(nbytes,), stride=(1,)
+        )
+        self.writeline(line)
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_calls.writeline(line)
+        if zero_fill:
+            if config.abi_compatible:
+                # TODO: remove this function to use the default WrapperCodegen behavior after service platform has zero_() symbol
+                # default behavior is f"workspace.zero_(){self.ending}"
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get())){self.ending}"
+                )
+            else:
+                self.writeline(f"workspace.zero_(){self.ending}")
+            if config.triton.autotune_at_compile_time:
+                self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}")
```
```diff
@@ -27,6 +27,7 @@ import torch
 import torch._logging
 from torch._dynamo.utils import preserve_rng_state
 from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
+from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
 from torch._prims_common import is_integer_dtype
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
```
```diff
@@ -2797,9 +2798,12 @@ class TritonKernel(SIMDKernel):
             if tree.prefix == "x" and self.no_x_dim:
                 code.writeline("XBLOCK: tl.constexpr = 1")

-    def _get_grid_fn(self):
+    def _get_grid_fn_str(self):
         return "grid"

+    def _get_grid_fn(self):
+        return default_grid_fn
+
     def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
         # TODO(jansel): if there are constants, we shouldn't bother passing them as args
         for tree in self.range_trees:
```
```diff
@@ -2828,7 +2832,9 @@ class TritonKernel(SIMDKernel):
                 ws.nbytes, current_device, ws.zero_fill
             )

-        grid = wrapper.generate_default_grid(name, grid)
+        grid = wrapper.generate_default_grid(
+            name, grid, grid_callable=self._get_grid_fn()
+        )
         wrapper.generate_kernel_call(
             name,
             call_args,
```
```diff
@@ -2837,7 +2843,7 @@ class TritonKernel(SIMDKernel):
             gpu=True,
             triton=True,
             arg_types=arg_types,
-            grid_fn=self._get_grid_fn(),
+            grid_fn=self._get_grid_fn_str(),
             triton_meta=self.triton_meta,
         )

```
```diff
@@ -6,6 +6,7 @@ import torch._inductor.runtime.hints
 from torch._inductor import config
 from torch._inductor.codegen.simd import IterationRangesRoot
 from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
+from torch._inductor.runtime.triton_heuristics import split_scan_grid
 from torch._prims_common import prod
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv
```
```diff
@@ -170,5 +171,8 @@ class TritonSplitScanKernel(TritonKernel):
     def _get_heuristic(self):
         return "split_scan"

-    def _get_grid_fn(self):
+    def _get_grid_fn_str(self):
         return "split_scan_grid"

+    def _get_grid_fn(self):
+        return split_scan_grid
```
```diff
@@ -595,6 +595,7 @@ class WrapperCodeGen(CodeGen):

             async_compile = AsyncCompile()
             generate_example_value = AlgorithmSelectorCache.generate_example_value
+            empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
             """
         )

```
```diff
@@ -1496,12 +1497,18 @@ class WrapperCodeGen(CodeGen):
         return SymbolicCallArg(expr, tree.numel)

     def generate_workspace_allocation(self, nbytes, device, zero_fill):
+        if isinstance(nbytes, sympy.Expr):
+            nbytes = V.graph.sizevars.size_hint(nbytes)
         line = self.make_allocation(
             "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,)
         )
         self.writeline(line)
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_calls.writeline(line)
         if zero_fill:
             self.writeline(f"workspace.zero_(){self.ending}")
+            if config.triton.autotune_at_compile_time:
+                self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}")

     def wrap_kernel_call(self, name, call_args):
         return f"{name}({', '.join(call_args)}){self.ending}"
```
```diff
@@ -1720,7 +1727,12 @@ class WrapperCodeGen(CodeGen):
                     key, arg = arg.split("=")

                 if isinstance(arg_type, torch_dtype):
-                    if arg not in tensor_args:
+                    # workspace allocation is already generated by `generate_workspace_allocation()`
+                    # in `TritonKernel.call_kernel()`.
+                    if arg == "workspace":
+                        arg_str = "workspace"
+                        tensor_args[arg] = arg_str
+                    elif arg not in tensor_args:
                         arg_str = self.generate_example_arg_value(
                             arg, arg_type, raw_arg, i
                         )
```
```diff
@@ -753,6 +753,7 @@ class CachingAutotuner(KernelInterface):
             "x_block": launcher.config.kwargs.get("XBLOCK", 1),
             "y_block": launcher.config.kwargs.get("YBLOCK", None),
             "z_block": launcher.config.kwargs.get("ZBLOCK", None),
+            "r_block": launcher.config.kwargs.get("RBLOCK", None),
             "num_warps": launcher.bin.num_warps
             if hasattr(launcher.bin, "num_warps")
             else launcher.bin.metadata.num_warps,
```
```diff
@@ -528,6 +528,8 @@ aoti_torch_cpu__wrapped_quantized_linear_prepacked(
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out);

+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self);
+
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
     AtenTensorHandle repeats,
     int64_t* output_size,
```
```diff
@@ -1166,3 +1166,10 @@ AOTITorchError aoti_torch__alloc_from_pool(
         strides));
   });
 }
+
+AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
+    t->zero_();
+  });
+}
```