Prior to this PR, constexprs were appearing in signatures as `{.. "XBLOCK : tl.constexpr": "constexpr"}` when they really should appear as `{.. "XBLOCK": "constexpr"}`.
This PR represents the argument names as ArgName objects, which can optionally be marked as constexpr.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145583
Approved by: https://github.com/jansel
# Issue
https://github.com/pytorch/pytorch/pull/137243 introduced a feature where the ND tiling algorithm analyzes memory dependencies. It iterates over all `Dep`'s of the kernel. However, the analysis is only applicable to `MemoryDep` instances, which are a subclass of `Dep`. In particular, it doesn't work for `StarDep`'s, for the reasons described here: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/simd.py#L1653
# Fix
This PR changes the algorithm to only iterate over `MemoryDep` instances.
# Testing
Parameterized an existing test for `torch.bucketize` to also run with ND tiling. This test emits a node with `StarDep`'s. Without this PR, the compiler would crash on this test case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144497
Approved by: https://github.com/eellison
# Issue
This PR cleans up an edge case that wasn't handled by https://github.com/pytorch/pytorch/pull/137243. The existing tiling code assumes that `node.get_ranges()` is a reliable source of pointwise and reduction numels. This is true for pointwise kernels, but the situation is more complicated with reductions. Since reductions change the number of elements in a tensor, not all ops within a reduction kernel will have the same number of iterations. For example, `var_mean` fuses pointwise division with the output of reduction sum, and the division lacks the corresponding reduction ranges.
# Fix
Instead of getting numels from `node.get_ranges()`, explicitly pass the global pointwise and reduction numels to the relevant tiling functions. In `SIMDKernel.complete_partial_tiling`, we solve for the missing numel by diving the global numel by the partial tiling's numel. This ensures all tilings have the correct global numel.
Also, in `SIMDKernel.is_compatible`, add the global reduction numel to node ranges that are missing it. For example, `{"x": 8, "r0_": 8}` is compatible with a node of ranges `([8], [])` when we have `reduction_numel=8`.
Finally, this PR generalizes some of the existing codegen to handle multiple reduction dims. We already had code to ignore reduction splits for pointwise kernels, but it only worked for 1D reductions. Now it can handle ND.
# Test plan
This PR parametrizes the existing CI test for `var_mean` to also run with tiled reductions. It also adds a new test checking that `var_mean` generates 2D tilings (with tiled reduction enabled). These new tests would fail on the current main branch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144041
Approved by: https://github.com/jansel
Fixes#134277 and https://github.com/pytorch/pytorch/issues/142317.
Sub-PRs containing refactors from this one:
- https://github.com/pytorch/pytorch/pull/141733
- https://github.com/pytorch/pytorch/pull/141738
- https://github.com/pytorch/pytorch/pull/141751 (based off the former)
- https://github.com/pytorch/pytorch/pull/142249
- https://github.com/pytorch/pytorch/pull/142020
- https://github.com/pytorch/pytorch/pull/143135
These refactor PRs should land before the main one.
# Feature
*Note: to minimize risk, multi-dimensional reductions are gated by the flag `config.triton.tile_reductions`, which defaults to False.*
Instead of having a single reduction dimension called `"r"`, we can now support 2D reductions with `"r0_"` and `"r1_"` dimensions. 2D reductions generate two nested loops, with different block pointer advancements in each loop body. Most of the implementation is generic to ND reductions, but for now the tiling algorithm sets a hard limit at 2D.
Here's an example of a 2D persistent reduction kernel:
```
@triton.jit
def triton_per_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr):
xnumel = 1
r0_numel = 15
R0_BLOCK: tl.constexpr = 16
r1_numel = 15
R1_BLOCK: tl.constexpr = 16
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
xmask = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[None, :, None]
r0_offset = 0
r0_mask = r0_index < r0_numel
r1_index = tl.arange(0, R1_BLOCK)[None, None, :]
r1_offset = 0
r1_mask = r1_index < r1_numel
rnumel = r0_numel * r1_numel
RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
roffset = r1_offset + (r0_offset*r1_numel)
rindex = r1_index + (r0_index*r1_numel)
r0_0 = r0_index
r1_1 = r1_index
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[15, 15], strides=[30, 1], block_shape=[R0_BLOCK, R1_BLOCK], order=[1, 0], offsets=[r0_offset, r1_offset]), boundary_check=[0, 1], padding_option='zero')[None, :, :]
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
tmp3 = tl.where(r0_mask & r1_mask, tmp1, 0)
tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])
tmp5 = tl.sum(tmp4, 1)[:, None, None]
tl.store(out_ptr0 + (tl.full([XBLOCK, 1, 1], 0, tl.int32)), tmp5, None)
''', device_str='cuda')
```
There are a few main differences between this kernel and what Inductor would generate without this PR.
- Instead of an `r`/`RBLOCK` dimension, we have two reduction dimensions: `r0_`/`R0_BLOCK` and `r1_`/`R1_BLOCK`.
- There are special size and indexing variables for reductions, which don't directly correspond to any kernel dimension. (`rindex`, `rnumel`, `RBLOCK`, and `roffset`.) These collapse N-D reduction sizes and indices indices into 1D. This simplifies the codegen for reductions, which sometimes want to access linear indices instead of N-dimensional ones. Doing things this way allows us to generate N-D loads and stores, but access this data as if it were 1D, minimizing the blast radius of this PR. Although this makes the code more verbose, it shouldn't have a perf impact because the triton compiler eliminates dead code.
- We generate the line `tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])` before performing the actual reduction. This reshapes N reduction dimensions into 1D. This allows us to reduce over all N dimensions at once, simplifying the codegen and allowing the Triton complier to decide the order of processing under the hood.
Here's an example of a looped reduction:
```
@triton.jit
def triton_red_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr, R1_BLOCK : tl.constexpr):
xnumel = 3
r0_numel = 43
r1_numel = 129
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :, None]
r1_base = tl.arange(0, R1_BLOCK)[None, None, :]
rnumel = r0_numel * r1_numel
RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
rbase = r1_base + (r0_base*r1_numel)
x0 = xindex
block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[3, 43, 129], strides=[11094, 258, 1], block_shape=[XBLOCK, R0_BLOCK, R1_BLOCK], order=[2, 1, 0], offsets=[xoffset, 0, 0])
_tmp2 = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
for r1_offset in range(0, r1_numel, R1_BLOCK):
r1_index = r1_offset + r1_base
r1_mask = r1_index < r1_numel
roffset = r1_offset + (r0_offset*r1_numel)
rindex = r1_index + (r0_index*r1_numel)
r0_1 = r0_index
r1_2 = r1_index
tmp0 = tl.load(block_ptr0, boundary_check=[0, 1, 2], padding_option='zero', eviction_policy='evict_first')
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask & r1_mask & xmask, tmp3, _tmp2)
block_ptr0 = tl.advance(block_ptr0, [0, 0, R1_BLOCK])
block_ptr0 = tl.advance(block_ptr0, [0, R0_BLOCK, (-1)*R1_BLOCK*((128 + R1_BLOCK) // R1_BLOCK)])
tmp4 = tl.reshape(_tmp2, [XBLOCK, RBLOCK])
tmp2 = tl.sum(tmp4, 1)[:, None, None]
tl.store(tl.make_block_ptr(out_ptr0, shape=[3], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.reshape(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```
In addition to the aforementioned changes to the persistent reduction, multidimensional looped reductions have a few more lines of code:
- They calculate indices inside the loop using `r0_base` and `r1_base`. For compatibility with existing codegen, these are collapsed to the 1D variant `rbase`.
- Block pointer advancements are more nuanced for multidimensional loops. At the end of each loop body, we emit a `tl.advance` line which not only increments the pointer in its own dimension, but also undoes the cumulative increments of the previous loop level. This is equivalent to the usual practice in nested loops of starting with a fresh iteration variable at each level. Implementing this required refactoring the way we generate pointer advancements into a new `self.pointer_advancements` field of the kernel, which categorizes advancements by dimension.
The biggest difficulty in implementing this feature was that we represented tiling with a tuple like `(5,2)`. In the existing codebase, the compiler can infer that the reduction dimension of `(5,2)` is `2`, since reductions are always the last dimension. This became cumbersome now that we have to support multiple reduction dimensions, so I refactored tiling into a dict like `{"x": 5, "r0_": 2, "r1_": 4}`. This required quite a few code changes, but I don't think it makes the underlying logic much more complex. This will also make it easier to eventually support simultaneous pointwise and reduction tiling, like `{"x": 5, "y": 5, "r0_": 2, "r1_": 4}`. (This is not supported today, but we might want to do it eventually.)
The existing tiling algorithm generalized naturally to support reductions. For pointwise kernels, we tile the pointwise dimensions (`"x"`, `"y"`) as is. For reduction kernels, we never tile the `"x"` dimension, and only tile the reduction dimensions (`"r0_"`, `"r1_"`). Thus we only ever tile pointwise OR reduction dimensions, but not both. In principle it seems possible to support both, but it would likely require changes to the kernel fusion and autotuning logic. I thought it best to keep this PR as minimal as possible since it already touched a lot of different files.
Unfortunately, these changes weren't enough to get block pointers in some seemingly simple test cases. In some tests for `argmax` and `var_mean`, we already collapse reduction dimensions into 1D and generate modular indexing expressions, prior to tiling. So it's not trivial to figure out how to expand the collapsed reduction dimension back to a shape that would simplify the indexing.
To address these cases, this PR adds a new feature to the `config.prefer_nd_tiling` option, which analyzes reads and writes in the kernel, using the same mod-div pattern matching logic that generates block pointers later on. By matching this pattern, we can solve for the tiling splits which *would* simplify the indexing expression, and use then use that tiling to eliminate the modular indexing and emit a block pointer. This tiling mode is still off by default, but it's important for certain applications where we need to get as many block pointers as possible.
# Test plan
This touches pretty much anything that uses the Triton and Halide backends, so the existing CI provides good coverage. However, 2D reductions are gated behind a few feature flags like `config.prefer_nd_tiling` and `config.tile_reductions`, so this really only checks that the PR doesn't break 1D reductions.
In addition to existing CI tests, this PR also adds some new tests that specifically stress 2D reductions:
- `test_2d_reduction_odd_shapes`: test 2D reductions with a variety of ops and sizes. This covers the typical persistent and looped reductions.
- `test_2d_reduce_no_x_dim`: test 2D reductions with no x dimension.
- `test_2d_welford_reduction`: test 2D welford reductions with block pointers.
- `test_welford_non_block_pointer`: test a 2D welford reduction when block pointer analysis fails.
- `test_reduction_multiple_discontiguous_dims`: test reducing over more than one discontiguous dimension. We won't get a block pointer for this case, since that would require 3D tiling, but we're currently limited to 2D.
- `test_2d_reduction_multi_kernel`: test multi kernel autotuning on a 2D softmax kernel.
- `test_enable_tiled_reductions`: test that `config.triton.tile_reductions` enables/disables this feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137243
Approved by: https://github.com/jansel
Co-authored-by: Yueming Hao <yhao@meta.com>
Co-authored-by: Jason Ansel <jansel@meta.com>
Summary:
This diff mainly adds code changes to dump `inductor_triton_kernel_to_post_grad_nodes.json` artifact which contains mapping info from post_grad -> inductor kernel code:
`{"inductor_triton_kernel_name": [post_grad_node_0, post_grad_node_1, ..., ], "..."}.`
Example paste: P1695235000 verified on the test model. See "Test Plan":
We use this artifact to demonstrate provenance tracking in the frontend 3-tab highlighter tool:
https://github.com/YUNQIUGUO/compiler_explorer (copy/pasted the input files for demo purpose for now and will integrate with Shangdi's tool to 4-tab)
https://pxl.cl/66BzK
Note: Currently only supports mapping for inductor's`TritonKernel` type. TODO for enhancing more support for `ExternKernel` and other inductor generated kernel type, etc.
Test Plan:
test_model_coverage.sh:
```
#!/bin/sh
MODEL_ENTITY_ID=644688112
SNAPSHOT_ID=32
MODULE=merge
# buck2 build --show-output mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true -c fbcode.nvcc_arch=a100,h100 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark
TORCH_COMPILE_DEBUG=1 CUDA_VISIBLE_DEVICES=0 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCH_LOGS="+inductor, schedule, fusion, output_code" TORCH_TRACE="tmp/guorachel_tt" TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 ../buck-out/v2/gen/fbcode/d29ee94b913014f1/caffe2/torch/fb/model_transform/experimental/benchmark/__mts_gpu_benchmark__/mts_gpu_benchmark.par --model-path manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend AOT_INDUCTOR_EP --gpu-trace --aot-inductor-config="{'max_autotune': True}" 2>&1 | tee output.txt
```
{F1973765026}
```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:provenance_tracing -- --exact 'caffe2/test/inductor:provenance_tracing - test_triton_kernel_post_grad_mapping_aot_inductor (caffe2.test.inductor.test_provenance_tracing.TestProvenanceTracingArtifact)'
```
```
TORCH_LOGS="+inductor, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:provenance_tracing -- -r test_triton_kernel_post_grad_mapping_aot_inductor
```
Differential Revision: D66967510
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143055
Approved by: https://github.com/chenyang78
For prologues which only do either loads like gathers or dtype conversions, and no actual arithmetic on lower-precision types, we can codegen them without upcasting to fp32 without changing numerics.
Prologues that actually do arithmetic will need to use invoke quant. But I would like to to support upcasts/gathers out of the box.
We could potentially extend this in the future to avoid upcasting max pooling operations as well, if there were perf benefits to be had (less likely).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142402
Approved by: https://github.com/jansel
ghstack dependencies: #142401
We load inputs to prologue fusion with a mask. That mask must still be zero before we run `tl.dot`. Previously, we would always apply the mask:
```
tmp0 = tl.load(in_ptr1 + (tl.broadcast_to(xindex, xindex.shape)), a_mask, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
a = tl.where(a_mask, tmp1, 0.0)
```
now we do not need to ->
```
tmp0 = tl.load(in_ptr1 + (tl.broadcast_to(xindex, xindex.shape)), a_mask, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
a = tmp1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142401
Approved by: https://github.com/jansel
This PR extends our ability to fuse pointwise nodes onto triton templates with the ability to fuse pointwise nodes into triton templates - prologue fusion.
Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`
And the modification api:
```
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
out="qk"
) | indent_except_first(1) }}
```
We have:
```{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}```
Because we are now loading the input with explicit indices and mask, I needed to rewrite the mm kernel to no longer update the [pointers by BLOCK_K](bb03ef7aca/torch/_inductor/kernel/mm.py (L110-L111)) on every iteration and instead on each iteration compute indices from the the k_idx of each loop. This did not have any perf difference.
There are a couple main use cases for prologue fusion:
- Fusing dequants into a matmul. particularly for more bandwidth bound scenarios.
- Fusing gather into a matmul. This is useful particularly in MOE. See https://github.com/pytorch/pytorch/issues/134535 for more details.
Prologue fusion is generally much less profitable than epilogue fusion, because it must be applied to an element of an input on each loop of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting to fuse prologue fusion. We only attempt fusion if it does not increase the number of memory bytes read instead the triton template, multipled by a small factor to allow gathers. This restricts reliably unprofitable fusions like fp32->fp16 inside kernel. In future pr we could potentially have api of being more aggressive if we know we are in a bandwidth bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066
Other notes:
By default we will upcast to fp32 inside every kernel. This matches eager numerics. This is fine enough for epilogue because it is only done once (although it is probably unnecessary for say a relu) but tanks perf for prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need https://github.com/pytorch/pytorch/pull/136778/ and dtype-aware codegen to upcast fp16 ops into libdevice calls.
With prologue fusion, we now have essentially separate kernels for each input, and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also update the fusion logic because the inputs will have a different group than the outputs. Maybe as part of enabling multiple outputs, this could get cleaned up a bit so..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134532
Approved by: https://github.com/jansel
Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243.
# Feature
This PR changes the `RINDEX` / `"r"` symbol type to `(R0_INDEX, R1_INDEX)` and `("r0_", "r1_")`, respectively. This allows the relevant code to support 2D (often ND) reductions. Unlike the parent PR, this one does not change the tiling algorithm, so `"r1_"` is never used. However, it prepares other parts of the system to handle `"r1_"` once we start using it. This should significantly reduce the chances of hitting merge conflicts, making the parent PR much easier to land.
The only change to the generated triton code is to rename `"rindex"` -> `"r0_index"`, `"RBLOCK"` -> `"R0_BLOCK"`, etc. To maintain compatibilty with existing codegen, this also generates aliases to the old reduction variables like `rindex = r0_index`. If we generated 2D reductions (which this PR will not do), the aliases would be more complicated and would collapse 2D multi-indices to linear indices. See some example kernels in the parent PR.
These aliases can be eliminated by the Triton compiler, and should not impact the final machine code running on the GPU. See the perf testing in the parent PR which confirms the aliases do not impact perf.
# Test plan
The existing CI provides good coverage. This PR modifies the expected code in a few places, renaming reduction variables from `r.*` to `r0_.*`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142020
Approved by: https://github.com/jansel
Co-authored-by: Jason Ansel <jansel@meta.com>
For prologues which only do either loads like gathers or dtype conversions, and no actual arithmetic on lower-precision types, we can codegen them without upcasting to fp32 without changing numerics.
Prologues that actually do arithmetic will need to use invoke quant. But I would like to to support upcasts/gathers out of the box.
We could potentially extend this in the future to avoid upcasting max pooling operations as well, if there were perf benefits to be had (less likely).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142402
Approved by: https://github.com/jansel
ghstack dependencies: #134532, #142350, #142400, #142401
We load inputs to prologue fusion with a mask. That mask must still be zero before we run `tl.dot`. Previously, we would always apply the mask:
```
tmp0 = tl.load(in_ptr1 + (tl.broadcast_to(xindex, xindex.shape)), a_mask, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
a = tl.where(a_mask, tmp1, 0.0)
```
now we do not need to ->
```
tmp0 = tl.load(in_ptr1 + (tl.broadcast_to(xindex, xindex.shape)), a_mask, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
a = tmp1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142401
Approved by: https://github.com/jansel
ghstack dependencies: #134532, #142350, #142400
This PR extends our ability to fuse pointwise nodes onto triton templates with the ability to fuse pointwise nodes into triton templates - prologue fusion.
Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`
And the modification api:
```
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
out="qk"
) | indent_except_first(1) }}
```
We have:
```{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}```
Because we are now loading the input with explicit indices and mask, I needed to rewrite the mm kernel to no longer update the [pointers by BLOCK_K](bb03ef7aca/torch/_inductor/kernel/mm.py (L110-L111)) on every iteration and instead on each iteration compute indices from the the k_idx of each loop. This did not have any perf difference.
There are a couple main use cases for prologue fusion:
- Fusing dequants into a matmul. particularly for more bandwidth bound scenarios.
- Fusing gather into a matmul. This is useful particularly in MOE. See https://github.com/pytorch/pytorch/issues/134535 for more details.
Prologue fusion is generally much less profitable than epilogue fusion, because it must be applied to an element of an input on each loop of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting to fuse prologue fusion. We only attempt fusion if it does not increase the number of memory bytes read instead the triton template, multipled by a small factor to allow gathers. This restricts reliably unprofitable fusions like fp32->fp16 inside kernel. In future pr we could potentially have api of being more aggressive if we know we are in a bandwidth bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066
Other notes:
By default we will upcast to fp32 inside every kernel. This matches eager numerics. This is fine enough for epilogue because it is only done once (although it is probably unnecessary for say a relu) but tanks perf for prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need https://github.com/pytorch/pytorch/pull/136778/ and dtype-aware codegen to upcast fp16 ops into libdevice calls.
With prologue fusion, we now have essentially separate kernels for each input, and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also update the fusion logic because the inputs will have a different group than the outputs. Maybe as part of enabling multiple outputs, this could get cleaned up a bit so..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134532
Approved by: https://github.com/jansel
Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243. Previously, we would typically check for reductions by `tree.prefix == "r"`. This PR moves the check into a helper function. This makes it easier to generalize the code to multi-dimensional reductions, which could have multiple prefixes like `("r0_", "r1_")`.
Tested by the existing CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141738
Approved by: https://github.com/jansel
Fix https://github.com/pytorch/pytorch/issues/128063 .
Now for this snippet
```
def f(x):
y = torch.sum(torch.sum(x, dim=-1))
z = x / 10.0
z_t = z.t().contiguous().t()
return y, z, z_t
```
Inductor could generate a single kernel for the first reduction and the two ponitwise kernels (if loop-ordering after fusion is enabled). And the generated kernel read `x` only ONCE. (with no proper handling, the two pointwise's may each access x once even if they are fused).
The PR needs fix 2 subtile bugs regarding LOAF .
1. when we reorder loops for a FusedSchedulerNode, we check if each sub-node's sizes matches. But some node has sizes in `list` type (if its loop is not reordered) while others have its sizes in `tuple` type (if its loop is reordered). I could change the upstream code to uniformly use either `list` or `tuple`. But without strong enforcement, future code could break this. So I just convert sizes to uniform type before comparison.
2. We have a cache for tiling decisions of a BaseSchedulerNode. If we reorder loops for the node, we should invalidate the cache. Otherwise, a stale tiling decision can result in (very) bad kernel.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139376
Approved by: https://github.com/jansel, https://github.com/eellison