Fixes #131337
- Add `arg_type` for `workspace_arg`; the type is consistent with the type used in `generate_workspace_allocation()`.
- Do not generate example tensors for `workspace`; use `generate_workspace_allocation()` instead.
- Add workspace allocation generation code to `kernel_autotune_calls`, e.g.
```python
workspace = empty_strided_cuda((1280, ), (1, ), torch.uint8)
workspace.zero_()
.....
triton_spl_fused_add_cumprod_0.run(buf2, arg0_1, arg1_1, workspace, 1, 10000, grid=split_scan_grid(1, 10000), stream=stream0)
del buf2, arg0_1, arg1_1, workspace
```
- Add `empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda` to the header of the Triton autotune code.
The generated cpp has lines like the following, so we also implement a `zero_()` for `AtenTensorHandle`.
```cpp
static constexpr int64_t int_array_0[] = {1280L, };
static constexpr int64_t int_array_1[] = {1L, };
AtenTensorHandle workspace_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_uint8, cached_torch_device_type_cuda, 0, &workspace_handle));
RAIIAtenTensorHandle workspace(workspace_handle);
workspace.zero_();
```
- Fix handling of `grid_fn` for grid computation: pass `"RBLOCK"` to `split_scan_grid`.
- Fix dynamic shapes:
  Without the fix, we generate code like `workspace = empty_strided_cuda((32*((255 + s0) // 256), ), (1, ), torch.uint8)` during Triton autotuning, where `s0` is not defined.
  The fix is to use `V.graph.sizevars.size_hint(nbytes)` to realize the workspace size for Triton autotuning. Note that we only realize it for the Triton autotune code, not for the cpp CUDA code (see the sketch after this list).
- We also generate slightly different cpp code depending on whether `abi_compatible` is turned on:
```cpp
RAIIAtenTensorHandle workspace(workspace_handle);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get()));
```
vs
```cpp
at::Tensor workspace = at::detail::empty_strided_cuda({8L*(c10::div_floor_integer(static_cast<int64_t>((255L + s0)), static_cast<int64_t>(256L))), }, {1L, }, at::kByte, c10::DeviceType::CUDA);
workspace.zero_();
```
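A minimal sketch of the dynamic-shapes handling from the bullet above, assuming the existing `V.graph.sizevars.size_hint` helper; the surrounding autotune codegen is omitted and the function name is illustrative:
```python
from torch._inductor.virtualized import V

def autotune_workspace_nbytes(nbytes):
    # For the Triton autotune code path only: substitute a concrete size hint
    # for symbolic expressions such as 32*((255 + s0) // 256), so the autotune
    # harness never references an undefined symbol like s0.
    return V.graph.sizevars.size_hint(nbytes)
```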
Test Plan:
```
TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
python test/inductor/test_cuda_cpp_wrapper.py DynamicShapesCudaWrapperCudaTests.test_consecutive_split_cumprod_cuda_dynamic_shapes_cuda_wrapper
TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135552
Approved by: https://github.com/desertfire
Summary:
As title. Follow-up to add a stats summary (mean/min/max, etc.) for JIT Inductor tensor value printing as well.
The Inductor Python wrapper code-level printing would look something like this:
{F1859224287}
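A hedged illustration of the kind of stats line such printing produces; the helper, kernel name, and format below are illustrative placeholders, not the exact generated code:
```python
import torch

def print_tensor_stats(kernel_name: str, tensor_name: str, t: torch.Tensor) -> None:
    # Placeholder helper: prints the mean/min/max summary described above.
    t = t.float()
    print(
        f"{kernel_name} -> {tensor_name}: "
        f"mean={t.mean().item():.6f}, min={t.min().item():.6f}, max={t.max().item():.6f}"
    )

print_tensor_stats("triton_poi_fused_0", "buf0", torch.randn(4, 4))
```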
Test Plan: CI
Reviewed By: chenyang78
Differential Revision: D62415575
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135887
Approved by: https://github.com/chenyang78
This PR adds dynamic shapes support to foreach and combo kernels for horizontal fusion.
A flag `combo_kernel_foreach_dynamic_shapes` (default False to avoid disturbing production workflows) is added to `_inductor/config.py`. Setting it to True enables automatic dynamic shapes for foreach kernels; it is always enabled for combo kernel cases. Unit tests are added.
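A minimal usage sketch, assuming only the flag name introduced by this PR (the model code is illustrative):
```python
import torch
import torch._inductor.config as inductor_config

# Opt in to dynamic shapes for foreach kernels; the default is False.
inductor_config.combo_kernel_foreach_dynamic_shapes = True

@torch.compile(dynamic=True)
def fn(xs, ys):
    return torch._foreach_add(xs, ys)

xs = [torch.randn(n, device="cuda") for n in (8, 16, 32)]
ys = [torch.randn(n, device="cuda") for n in (8, 16, 32)]
fn(xs, ys)
```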
This PR also fixes a flaky test case for [T198833257](https://www.internalfb.com/intern/tasks/?t=198833257)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134477
Approved by: https://github.com/mlazos
Fixes #130394
TorchInductor doesn't respect the original strides of outputs. That opens up optimization opportunities like changing the memory layout, but for some cases, such as the one in https://github.com/pytorch/pytorch/issues/130394, we do need the output to match the exact strides required. Correctness is the first-priority goal, so this PR adds a new API `ir.ExternKernel.require_exact_strides(x, exact_strides, allow_padding=False)` to fix the issue. With this PR, both dense and non-dense outputs' strides follow the strides required by the semantics.
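A hedged sketch of how a lowering could call the new API; the signature comes from this PR, while the surrounding function is only illustrative:
```python
from torch._inductor import ir

def lower_with_required_output_strides(x, required_strides):
    # Illustrative only: force `x` to be realized with exactly `required_strides`
    # before handing it to code that depends on the physical layout.
    return ir.ExternKernel.require_exact_strides(
        x, required_strides, allow_padding=False
    )
```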
The comparison between the original code and the code after this fix for the test is shown below.
```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 8
    x1 = (xindex // 8)
-   x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (16*x1)), xmask)
    tmp1 = tmp0 + tmp0
-   tl.store(out_ptr0 + (x2), tmp1, xmask)
+   tl.store(out_ptr0 + (x0 + (16*x1)), tmp1, xmask)

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (16, 8), (16, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
-       buf1 = empty_strided_cuda((16, 8), (8, 1), torch.float32)
+       buf1 = empty_strided_cuda((16, 8), (16, 1), torch.float32)
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_copy_0.run(arg0_1, buf1, 128, grid=grid(128), stream=stream0)
        del arg0_1
    return (buf1, )
```
`buf1` is created with the exact strides required by the user, and its values are written with the same strides as the input.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130956
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/desertfire
If the scalar tensor is an output tensor, it shouldn't be unwrapped (i.e., have `.item()` called on it), since `tl.store` requires a pointer type for outputs. This issue only occurs for mutated buffers, where the input tensor is also used as an output tensor.
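A hedged guess at the kind of pattern described above (not the exact repro from the issue): a 0-d tensor that is mutated in place, so it is both an input and an output of the generated kernel and must stay a pointer rather than be unwrapped with `.item()`.
```python
import torch

@torch.compile
def fn(acc, x):
    # `acc` is a 0-d buffer that is read and then written: a mutated buffer.
    acc.add_(x.sum())
    return acc

acc = torch.zeros((), device="cuda")
x = torch.randn(8, device="cuda")
fn(acc, x)
```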
Fixes #ISSUE_NUMBER
@yanboliang @jansel @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132859
Approved by: https://github.com/jansel
This PR enables dynamic shapes for the CK backend for gemm max autotune (see #125453).
This is achieved by removing the hardcoded problem sizes from the template body and passing them as parameters instead.
We handle passing the problem sizes for the kernel call as well as for the benchmark call.
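A hedged usage sketch, assuming the existing max-autotune config knobs (the CK backend is only available on ROCm builds):
```python
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CK,ATEN,TRITON"  # include CK

@torch.compile(dynamic=True)
def mm(a, b):
    return a @ b

a = torch.randn(64, 128, device="cuda", dtype=torch.float16)
b = torch.randn(128, 32, device="cuda", dtype=torch.float16)
mm(a, b)
```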
# Testing
`pytest test/inductor/test_ck_backend.py [-k dynamic]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133285
Approved by: https://github.com/ColinPeppler
Summary:
Follow-up small diff to fix a couple of issues:
- add a condition for the CUDA/GPU case to only print the kernel name list in the second pass, i.e., when we do the cpp wrapper codegen
- other minor fixes around the `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT` option
Test Plan:
```
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="triton_poi_fused_0" AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_abi_compatible_cuda
```
Differential Revision: D60954888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133016
Approved by: https://github.com/ColinPeppler
Summary:
Reland #124969 by backing out D60397377 "Back out "[1/2] PT2 Inductor ComboKernels - Foreach cases (#124969)""
The original diff D54134695 was reverted because of failures in the ads nightly cogwheel tests.
The root cause: the logic for generating the mask in the Triton kernel needed an update after a recent refactoring of triton.py. This diff includes the fix for the root cause.
See D54134695 or #124969 for more details.
Test Plan:
Originally failed tests
f585704630
f585733786
Diff patched:
f586664028
f586663820
Differential Revision: D60458597
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132182
Approved by: https://github.com/Yuzhen11
## What is sympy fn str arg?
It's a string such as `sqrt` that also happens to name a real sympy function (e.g. `sympy.sqrt`).
## Crash
```
torch/_inductor/sizevars.py", line 468, in symbolic_hint
expr = self.simplify(expr) # where expr is 'sqrt'
torch/_inductor/sizevars.py", line 66, in simplify
return sympy.expand(expr).xreplace(self.replacements)
sympy/core/function.py", line 2816, in expand
return sympify(e).expand(deep=deep, modulus=modulus, **hints)
AttributeError: 'function' object has no attribute 'expand'
```
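The failure can be reproduced directly with sympy; a minimal sketch:
```python
import sympy

# sympify("sqrt") resolves the string to the sympy.sqrt function object rather
# than an expression, and a function object has no .expand() method.
sympy.expand("sqrt")  # AttributeError: 'function' object has no attribute 'expand'
```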
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131253
Approved by: https://github.com/desertfire
Fixes #126338
## Issue Summary
When TorchInductor compiles the combination `functional_collective -> view.dtype -> wait`, a memory leak occurs. This happens because `view.dtype` is compiled into an out-of-place Triton kernel that copies the input data to a new tensor, even if the collective hasn't finished producing the data that the wait operation guards. The tensor used by the `collective` is only freed when the `wait` operation triggers the garbage collector, see [~WorkRegistry](https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/Functional.cpp#L41). However, since `wait` now waits on a new tensor, the previous one is never freed. `view.dtype` should only change the metadata instead of creating a new tensor; the current lowering goes against its semantics and causes memory leaks.
See the discussion in #126338 for more details.
This kind of lowering also generates unnecessary Triton kernels for `view.dtype` when it can't be fused with other operations.
## Fix
The function `aten.view.dtype` is a metadata-only operation that changes how its input's data is interpreted. After discussions with @eellison and @bdhirsh, we decided to change the lowering of `aten.view.dtype` to ensure it falls back properly to the real `aten.view.dtype` instead of generating a Triton kernel in some cases. This approach also preserves the semantics of the view operation.
When the model calls `aten.view.dtype` with a data type whose bit width matches that of the input's original data type, we lower it to the newly added `DtypeView` IR node, which acts like a `ReinterpretView`. When the operation can be fused, its `make_loader` is called to maintain the correct type conversion for each load instruction. When the operation can't be fused, it falls back to `aten.view.dtype` to avoid Triton kernel generation.
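A small, self-contained illustration of the eager semantics the fix preserves: a same-bit-width `view(dtype)` is metadata-only and aliases its input rather than copying it.
```python
import torch

x = torch.randn(4, dtype=torch.bfloat16)
y = x.view(torch.float16)            # bfloat16 -> float16: same 16-bit width
assert y.data_ptr() == x.data_ptr()  # shares storage; no new tensor is allocated
```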
## Example
```python
@torch.compile
def fn(x, y):
    x = x.view(torch.float16)
    y = y.view(torch.float16) + 1
    return x @ y

x = torch.randn((2, 2), device=self.device, dtype=torch.bfloat16)
y = torch.randn((2, 2), device=self.device, dtype=torch.bfloat16)
fn(x, y)
```
The output code generated before this fix is like the following.
```python
triton_poi_fused_add_view_0...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.bfloat16).to(tl.float32, bitcast=True).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)

triton_poi_fused_add_view_1...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.bfloat16).to(tl.float32, bitcast=True).to(tl.float32)
    tmp2 = 1.0
    tmp3 = tmp1 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)

def call(args):
    ...
    triton_poi_fused_view_0.run(arg0_1, buf0, 4, grid=grid(4), stream=stream0)
    del arg0_1
    buf1 = empty_strided_cuda((2, 2), (2, 1), torch.float16)
    # Source Nodes: [view_1, y], Original ATen: [aten.add, aten.view]
    triton_poi_fused_add_view_1.run(arg1_1, buf1, 4, grid=grid(4), stream=stream0)
    del arg1_1
    buf2 = empty_strided_cuda((2, 2), (2, 1), torch.float16)
    # Source Nodes: [matmul, view_1, x, y], Original ATen: [aten.add, aten.mm, aten.view]
    extern_kernels.mm(buf0, buf1, out=buf2)
```
As you can see, the two `view` operations are compiled into two kernels, `triton_poi_fused_view_0` and `triton_poi_fused_add_view_1`. Both of them have a line `tmp1 = tmp0.to(tl.bfloat16).to(tl.float32, bitcast=True).to(tl.float32)`, which does the type conversion.
The main issue is that the first `view` operation doesn't do anything to the actual data, yet it generates a Triton kernel with a new output tensor. Another, smaller issue is that this Triton kernel can't be compiled because `bitcast=True` only supports type conversion between types with the same bit width.
The following is the output code generated after this PR.
```python
triton_poi_fused_add_0...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.bfloat16).to(tl.float32)
    tmp2 = 1.0
    tmp3 = tmp1 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)

def call(args):
    ...
    triton_poi_fused_add_0.run(arg1_1, buf0, 4, grid=grid(4), stream=stream0)
    del arg1_1
    buf1 = empty_strided_cuda((2, 2), (2, 1), torch.float16)
    # Source Nodes: [matmul, y], Original ATen: [aten.add, aten.mm]
    extern_kernels.mm(aten.view.dtype(arg0_1, torch.float16), buf0, out=buf1)
```
The first `view` operation has been replaced with `aten.view.dtype`, and its result is directly passed as an argument. The second one is still there because it is fused with the following add operation. The invalid bitcast operation is removed too.
The following two code snippets are for the upcasts and downcasts: for dtypes in `torch.float16, torch.bfloat16`, each load is upcast to float32 and then downcast to its original dtype so that values are used with the right precision.
7bda23ef84/torch/_inductor/codegen/triton.py (L1725-L1726)
7bda23ef84/torch/_inductor/codegen/triton.py (L629-L642)
Huge thanks to @eellison, @bdhirsh, @shunting314, and @desertfire .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128883
Approved by: https://github.com/eellison
Summary:
We have a cache to guarantee that each `sym` is codegen'd only once; see the following code:
```python
def ensure_size_computed(self, sym: sympy.Symbol):
    if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
        if sym in self.computed_sizes:
            return
        self.computed_sizes.add(sym)
        expr = V.graph.sizevars.inv_precomputed_replacements[sym]
        self.writeline(
            f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}"
        )
```
However, we didn't consider the case where the same `sym`s need to be codegen'd in both branches of a condition (the true branch and the false branch), which caused the `undefined symbols` issue: P1441378833
To fix the issue, we use a stack to capture the state before doing the condition codegen and restore the state after the codegen is done.
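A hedged sketch of the save/restore idea (not the actual Inductor code): snapshot the `computed_sizes` cache before codegen'ing one branch and restore it afterwards, so the other branch re-emits the same precomputed sizes.
```python
import contextlib

@contextlib.contextmanager
def preserved_computed_sizes(wrapper):
    # `wrapper.computed_sizes` is the set used by ensure_size_computed above;
    # the stack discipline is modeled here by with-block nesting.
    saved = set(wrapper.computed_sizes)
    try:
        yield
    finally:
        wrapper.computed_sizes = saved
```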
Test Plan:
TORCH_LOGS="+inductor" buck2 run mode/dev-nosan -c fbcode.nvcc_arch=h100 -c fbcode.enable_gpu_sections=true --config 'cxx.extra_cxxflags=-g1' -c fbcode.platform010_cuda_version=12 //scripts/hhh:repro_cond_torch_compile
PYTORCH_TEST_FBCODE=1 TORCH_COMPILE_DEBUG=1 buck2 run mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true //caffe2/test/inductor:control_flow -- -r test_cond_control_flow_with_precomputed_size
Differential Revision: D58973730
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129492
Approved by: https://github.com/aakhundov
Currently a buffer represents both a tensor with physical storage and a
computation that produces the tensor as a result.
This PR attempts to split these into two different concepts in the scheduler.
This should allow us to have multiple outputs from a single operation.
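An illustrative sketch of the split (the class names here are placeholders, not the scheduler's actual types): a buffer models storage, an operation models the computation, and one operation may now produce several buffers.
```python
from dataclasses import dataclass, field

@dataclass
class Buffer:
    """A tensor with physical storage."""
    name: str

@dataclass
class Operation:
    """The computation that produces one or more buffers."""
    name: str
    outputs: list[Buffer] = field(default_factory=list)

op = Operation("fused_kernel_0", outputs=[Buffer("buf0"), Buffer("buf1")])
```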
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128893
Approved by: https://github.com/lezcano