**Summary**
In this PR, we enable epilogue fusion and code generation for Grouped GEMM. Here is a high-level description of how we implement it.
**Fusion**
- The Grouped GEMM Template produces a `Template Buffer` with a `MultiOutputLayout` and a set of `MultiOutput Buffers`, where each buffer corresponds to a specific GEMM.
- During the initial round of fusion, the `Template Buffer` and all associated `MultiOutput Buffers` are fused into a `FusedSchedulerNode` by extending the existing fusion design.
- In subsequent fusion rounds, this `FusedSchedulerNode` can further fuse with its epilogues, following the original fusion design principles.
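Sketched below in pseudocode (the names are illustrative, not the exact Inductor API):
```py
# Pseudocode sketch of the two-round fusion flow (illustrative names only).
# Round 1: fuse the grouped GEMM template buffer with its MultiOutput buffers.
template_node = scheduler.get_node(grouped_gemm_template_buffer)
output_nodes = [scheduler.get_node(buf) for buf in multi_output_buffers]
grouped = FusedSchedulerNode.fuse_nodes([template_node, *output_nodes])

# Subsequent rounds: the fused node takes part in ordinary epilogue fusion.
for candidate in scheduler.fusion_candidates(grouped):
    if scheduler.can_fuse(grouped, candidate):
        grouped = scheduler.fuse(grouped, candidate)
```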
**Code Gen**
We maintain a list of epilogues and codegen them one by one.
- If any of the GEMMs has a bias, we create an extra `bias_add` epilogue and prepend it to the front of the epilogue list.
- If any of the GEMMs has no epilogue, we create a `to_bf16` copy epilogue and append it to the end of the epilogue list.
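A minimal sketch of this ordering, with hypothetical helper names:
```py
# Minimal sketch of assembling the epilogue list (hypothetical helpers):
epilogues = list(fused_epilogues)
if any(gemm.has_bias for gemm in grouped_gemms):
    # bias must be applied before any user epilogue reads the GEMM output
    epilogues.insert(0, make_bias_add_epilogue(grouped_gemms))
if any(not gemm.epilogues for gemm in grouped_gemms):
    # a GEMM without epilogues still needs its output stored as bf16
    epilogues.append(make_to_bf16_copy_epilogue(grouped_gemms))
for epilogue in epilogues:
    codegen_epilogue(epilogue)
```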
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_epilogue
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143897
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #143796
**Summary**
Enable the CPP Grouped GEMM fusion, lowering, and Grouped GEMM template, following the RFC: https://github.com/pytorch/pytorch/issues/144012
- Support a flexible number of GEMMs
- Share the activation across GEMMs
  - The Grouped GEMM template supports independent activations
  - However, the pattern matcher requires an anchor node, for which we use the activation shared across the GEMMs
- Each GEMM can have a unique weight, but all weights must have the same sizes
- Each GEMM can have a unique bias, or none
  - The current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM can have its own epilogues
  - Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up PR
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```
**Example**
Here is an example along with the generated code:
```
import torch

batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
bias = False

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=bias)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=bias).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```
Generated Code: https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py
**Next Step**
- Support Epilogue fusion
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
Fixes #134277 and https://github.com/pytorch/pytorch/issues/142317.
Sub-PRs containing refactors from this one:
- https://github.com/pytorch/pytorch/pull/141733
- https://github.com/pytorch/pytorch/pull/141738
- https://github.com/pytorch/pytorch/pull/141751 (based off the former)
- https://github.com/pytorch/pytorch/pull/142249
- https://github.com/pytorch/pytorch/pull/142020
- https://github.com/pytorch/pytorch/pull/143135
These refactor PRs should land before the main one.
# Feature
*Note: to minimize risk, multi-dimensional reductions are gated by the flag `config.triton.tile_reductions`, which defaults to False.*
Instead of having a single reduction dimension called `"r"`, we can now support 2D reductions with `"r0_"` and `"r1_"` dimensions. 2D reductions generate two nested loops, with different block pointer advancements in each loop body. Most of the implementation is generic to ND reductions, but for now the tiling algorithm sets a hard limit at 2D.
Here's an example of a 2D persistent reduction kernel:
```
@triton.jit
def triton_per_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr):
xnumel = 1
r0_numel = 15
R0_BLOCK: tl.constexpr = 16
r1_numel = 15
R1_BLOCK: tl.constexpr = 16
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
xmask = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[None, :, None]
r0_offset = 0
r0_mask = r0_index < r0_numel
r1_index = tl.arange(0, R1_BLOCK)[None, None, :]
r1_offset = 0
r1_mask = r1_index < r1_numel
rnumel = r0_numel * r1_numel
RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
roffset = r1_offset + (r0_offset*r1_numel)
rindex = r1_index + (r0_index*r1_numel)
r0_0 = r0_index
r1_1 = r1_index
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[15, 15], strides=[30, 1], block_shape=[R0_BLOCK, R1_BLOCK], order=[1, 0], offsets=[r0_offset, r1_offset]), boundary_check=[0, 1], padding_option='zero')[None, :, :]
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
tmp3 = tl.where(r0_mask & r1_mask, tmp1, 0)
tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])
tmp5 = tl.sum(tmp4, 1)[:, None, None]
tl.store(out_ptr0 + (tl.full([XBLOCK, 1, 1], 0, tl.int32)), tmp5, None)
''', device_str='cuda')
```
There are a few main differences between this kernel and what Inductor would generate without this PR.
- Instead of an `r`/`RBLOCK` dimension, we have two reduction dimensions: `r0_`/`R0_BLOCK` and `r1_`/`R1_BLOCK`.
- There are special size and indexing variables for reductions which don't directly correspond to any kernel dimension (`rindex`, `rnumel`, `RBLOCK`, and `roffset`). These collapse N-D reduction sizes and indices into 1D. This simplifies the codegen for reductions, which sometimes want to access linear indices instead of N-dimensional ones. Doing things this way allows us to generate N-D loads and stores, but access this data as if it were 1D, minimizing the blast radius of this PR. Although this makes the code more verbose, it shouldn't have a perf impact because the Triton compiler eliminates dead code.
- We generate the line `tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])` before performing the actual reduction. This reshapes the N reduction dimensions into 1D, which allows us to reduce over all of them at once, simplifying the codegen and letting the Triton compiler decide the order of processing under the hood.
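Concretely, the collapsed variables are computed with a row-major linearization of the per-dimension variables, mirroring the generated lines in the kernel above:
```py
rnumel = r0_numel * r1_numel
RBLOCK = R0_BLOCK * R1_BLOCK
roffset = r1_offset + r0_offset * r1_numel  # linearized block offset
rindex = r1_index + r0_index * r1_numel     # linearized element indices
```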
Here's an example of a looped reduction:
```
@triton.jit
def triton_red_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr, R1_BLOCK : tl.constexpr):
xnumel = 3
r0_numel = 43
r1_numel = 129
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :, None]
r1_base = tl.arange(0, R1_BLOCK)[None, None, :]
rnumel = r0_numel * r1_numel
RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
rbase = r1_base + (r0_base*r1_numel)
x0 = xindex
block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[3, 43, 129], strides=[11094, 258, 1], block_shape=[XBLOCK, R0_BLOCK, R1_BLOCK], order=[2, 1, 0], offsets=[xoffset, 0, 0])
_tmp2 = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
for r1_offset in range(0, r1_numel, R1_BLOCK):
r1_index = r1_offset + r1_base
r1_mask = r1_index < r1_numel
roffset = r1_offset + (r0_offset*r1_numel)
rindex = r1_index + (r0_index*r1_numel)
r0_1 = r0_index
r1_2 = r1_index
tmp0 = tl.load(block_ptr0, boundary_check=[0, 1, 2], padding_option='zero', eviction_policy='evict_first')
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask & r1_mask & xmask, tmp3, _tmp2)
block_ptr0 = tl.advance(block_ptr0, [0, 0, R1_BLOCK])
block_ptr0 = tl.advance(block_ptr0, [0, R0_BLOCK, (-1)*R1_BLOCK*((128 + R1_BLOCK) // R1_BLOCK)])
tmp4 = tl.reshape(_tmp2, [XBLOCK, RBLOCK])
tmp2 = tl.sum(tmp4, 1)[:, None, None]
tl.store(tl.make_block_ptr(out_ptr0, shape=[3], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.reshape(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```
In addition to the aforementioned changes to the persistent reduction, multidimensional looped reductions have a few more lines of code:
- They calculate indices inside the loop using `r0_base` and `r1_base`. For compatibility with existing codegen, these are collapsed to the 1D variant `rbase`.
- Block pointer advancements are more nuanced for multidimensional loops. At the end of each loop body, we emit a `tl.advance` line which not only increments the pointer in its own dimension, but also undoes the cumulative increments of the previous loop level. This is equivalent to the usual practice in nested loops of starting with a fresh iteration variable at each level. Implementing this required refactoring the way we generate pointer advancements into a new `self.pointer_advancements` field of the kernel, which categorizes advancements by dimension.
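To make the advancement arithmetic concrete, here is the pattern from the kernel above in isolation (`R1_STEPS` is shorthand for the number of inner-loop iterations, `ceil(r1_numel / R1_BLOCK)`; it is not a variable in the generated code):
```py
for r0_offset in range(0, r0_numel, R0_BLOCK):
    for r1_offset in range(0, r1_numel, R1_BLOCK):
        ...  # load through block_ptr0 and accumulate
        # step the innermost (r1_) dimension by one block
        block_ptr0 = tl.advance(block_ptr0, [0, 0, R1_BLOCK])
    # step the outer (r0_) dimension and rewind the cumulative r1_ advance
    block_ptr0 = tl.advance(block_ptr0, [0, R0_BLOCK, -R1_BLOCK * R1_STEPS])
```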
The biggest difficulty in implementing this feature was that we represented tiling with a tuple like `(5,2)`. In the existing codebase, the compiler can infer that the reduction dimension of `(5,2)` is `2`, since reductions are always the last dimension. This became cumbersome now that we have to support multiple reduction dimensions, so I refactored tiling into a dict like `{"x": 5, "r0_": 2, "r1_": 4}`. This required quite a few code changes, but I don't think it makes the underlying logic much more complex. This will also make it easier to eventually support simultaneous pointwise and reduction tiling, like `{"x": 5, "y": 5, "r0_": 2, "r1_": 4}`. (This is not supported today, but we might want to do it eventually.)
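For illustration:
```py
# Old representation: meaning is positional; the reduction dim must come last
tiling = (5, 2)
# New representation: dims are named, so multiple reduction dims are unambiguous
tiling = {"x": 5, "r0_": 2, "r1_": 4}
```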
The existing tiling algorithm generalized naturally to support reductions. For pointwise kernels, we tile the pointwise dimensions (`"x"`, `"y"`) as is. For reduction kernels, we never tile the `"x"` dimension, and only tile the reduction dimensions (`"r0_"`, `"r1_"`). Thus we only ever tile pointwise OR reduction dimensions, but not both. In principle it seems possible to support both, but it would likely require changes to the kernel fusion and autotuning logic. I thought it best to keep this PR as minimal as possible since it already touched a lot of different files.
Unfortunately, these changes weren't enough to get block pointers in some seemingly simple test cases. In some tests for `argmax` and `var_mean`, we already collapse reduction dimensions into 1D and generate modular indexing expressions, prior to tiling. So it's not trivial to figure out how to expand the collapsed reduction dimension back to a shape that would simplify the indexing.
To address these cases, this PR adds a new feature under the `config.prefer_nd_tiling` option, which analyzes reads and writes in the kernel using the same mod-div pattern-matching logic that generates block pointers later on. By matching this pattern, we can solve for the tiling splits which *would* simplify the indexing expression, and then use that tiling to eliminate the modular indexing and emit a block pointer. This tiling mode is still off by default, but it's important for certain applications where we need to get as many block pointers as possible.
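As a hypothetical illustration of what the analysis solves for:
```py
# Hypothetical illustration (sizes made up): a reduction collapsed to 1D over
# 43 * 129 elements whose load uses modular indexing, e.g.
#   in_ptr0[129 * (r // 129) + (r % 129)]
# matches the mod-div pattern with divisor 129. Solving for the splits gives
#   tiling = {"x": xnumel, "r0_": 43, "r1_": 129}
# which removes the // and % from the index and lets codegen emit a block pointer.
```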
# Test plan
This touches pretty much everything that uses the Triton and Halide backends, so the existing CI provides good coverage. However, 2D reductions are gated behind feature flags like `config.prefer_nd_tiling` and `config.triton.tile_reductions`, so the existing CI really only checks that the PR doesn't break 1D reductions.
In addition to existing CI tests, this PR also adds some new tests that specifically stress 2D reductions:
- `test_2d_reduction_odd_shapes`: test 2D reductions with a variety of ops and sizes. This covers the typical persistent and looped reductions.
- `test_2d_reduce_no_x_dim`: test 2D reductions with no x dimension.
- `test_2d_welford_reduction`: test 2D welford reductions with block pointers.
- `test_welford_non_block_pointer`: test a 2D welford reduction when block pointer analysis fails.
- `test_reduction_multiple_discontiguous_dims`: test reducing over more than one discontiguous dimension. We won't get a block pointer for this case, since that would require 3D tiling, but we're currently limited to 2D.
- `test_2d_reduction_multi_kernel`: test multi kernel autotuning on a 2D softmax kernel.
- `test_enable_tiled_reductions`: test that `config.triton.tile_reductions` enables/disables this feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137243
Approved by: https://github.com/jansel
Co-authored-by: Yueming Hao <yhao@meta.com>
Co-authored-by: Jason Ansel <jansel@meta.com>
Fixes #143406
After this PR the error for missing Triton is:
```py
Traceback (most recent call last):
File "/home/jansel/pytorch/repro.py", line 51, in <module>
fp32_compiled = optimized_model(low_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3624, in create_backend
raise TritonMissing(inspect.currentframe())
torch._dynamo.exc.TritonMissing: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at: https://github.com/triton-lang/triton
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
Setting `TORCHDYNAMO_VERBOSE=1` yields something like the old error:
```py
Traceback (most recent call last):
File "/home/jansel/pytorch/repro.py", line 51, in <module>
fp32_compiled = optimized_model(low_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 576, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 1383, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 1167, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 548, in __call__
return _compile(
^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 988, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 716, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 751, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 232, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 663, in transform
tracer.run()
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 2870, in run
super().run()
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 1053, in run
while self.step():
^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 963, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 3050, in RETURN_VALUE
self._return(inst)
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 3035, in _return
self.output.compile_subgraph(
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1102, in compile_subgraph
self.compile_and_call_fx_graph(
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1383, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1433, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1463, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/__init__.py", line 2314, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1880, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1145, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 754, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 676, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 489, in __call__
return self.compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1758, in fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 572, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 686, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1044, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1975, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1981, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1916, in codegen
self.scheduler.codegen()
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3667, in codegen
return self._codegen()
^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3761, in _codegen
if device is not None and self.get_backend(device).ready_to_flush():
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3631, in get_backend
self.backends[device] = self.create_backend(device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3624, in create_backend
raise TritonMissing(inspect.currentframe())
torch._dynamo.exc.TritonMissing: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at: https://github.com/triton-lang/triton
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
This PR also strips dynamo stack frames from other types of backend compile errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143552
Approved by: https://github.com/yanboliang
Summary:
**Re-land of the PR.** The previous one was reverted because of a test failure on SM89; the fix is just removing `xfailIfSM89`.
```
_____________________ LoopOrderingTest.test_fp8_pattern_2 ______________________
Unexpected success
```
------
(Since I am trying the other solution for https://github.com/pytorch/pytorch/pull/141082, I moved the test case fixes out of that PR into a separate PR to land first.)
-----
Testing the float8 dynamic scaling case with `TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1` didn't make any difference.
The test case for fp8 (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425) is also failing: https://www.internalfb.com/intern/test/844425111960859?ref_report_id=0
-------
The main change here is to modify the condition of calling `loop_reordering` from `shared_data_score == 0` to `shared_data_score < config.score_fusion_memory_threshold`.
Before the change:
`shared_data_score > 0 -> won't loop_reorder -> can't be fused because shared_data_score < config.score_fusion_memory_threshold`
After the change:
`shared_data_score > 0 -> loop_reorder (when shared_data_score < config.score_fusion_memory_threshold) -> get a larger shared_data_score -> fused`
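In code, the gating change is roughly the following (an illustrative sketch, not the exact diff):
```py
# Illustrative sketch (not the exact source): when loop_reordering is tried.
# Before: only when the two nodes share no data at all
if shared_data_score == 0:
    shared_data_score = loop_reordering(node1, node2)
# After: whenever the score is below the fusion threshold
if shared_data_score < config.score_fusion_memory_threshold:
    shared_data_score = loop_reordering(node1, node2)
```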
----
It's the same issue as the one fixed in https://github.com/pytorch/pytorch/pull/136782, but the condition for calling loop_reorder was changed later, causing the test case to fail again.
Test Plan:
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```
-----
Ran a float8 dynamic scaling training script to verify it e2e
Differential Revision: D67012816
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142474
Approved by: https://github.com/eellison, https://github.com/sijiac, https://github.com/shunting314
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning when all of the following hold (checked by the `use_triton_tma_template` helper):
1. The minimum hardware and Triton version requirements for TMA support are met.
2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).
3. The `config.triton.enable_persistent_tma_matmul` is set to `True`.
Additional notes:
1. As added in this PR, TMA use is not compatible with prologue / epilogue fusion. To this end, the new Triton template currently supports TMA-based loads of A/B, but no prologue fusion; and epilogue fusion, but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.
2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Due to this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous.
3. The transposed layouts of A and/or B are supported by passing constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on kernel perf, as it is resolved at Triton compilation time).
4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in https://github.com/triton-lang/triton/pull/5290) in the new Triton template, which should allow lifting 2 and 3 above.
5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment the config in a follow-up.
6. This PR is rebased onto and unifies with two related PRs landed previously: https://github.com/pytorch/pytorch/pull/142045 (some infra unification with the persistent+TMA template for _scaled_mm) and https://github.com/pytorch/pytorch/pull/134532 (adds the possibility to disable prologue fusion for selected choices).
7. The current Triton TMA API only supports 1D and 2D descriptors (even after https://github.com/triton-lang/triton/pull/5290, see [here](9829ce87cc/python/triton/language/core.py (L1957))). For now, this blocks adding a persistent+TMA template for `torch.bmm`.
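For illustration, a minimal way to exercise the new template (assuming hardware and Triton versions that pass the checks above; `enable_persistent_tma_matmul` is the config named in this PR, the rest are standard torch APIs):
```py
import torch
import torch._inductor.config as inductor_config

inductor_config.triton.enable_persistent_tma_matmul = True

a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.float16)

# max-autotune makes Inductor benchmark template choices, which now include
# the persistent+TMA template when use_triton_tma_template allows it
compiled_mm = torch.compile(torch.mm, mode="max-autotune")
out = compiled_mm(a, b)
```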
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142101
Approved by: https://github.com/drisspg, https://github.com/eellison
This PR moves `decide_global_ordering_of_comms` to run first, before all other Inductor scheduler passes, so that downstream passes have correct dependency-tracking info. It also moves the peak memory pass and the overlap pass to the end of all passes, because they need to be the final decision makers on the node order to achieve the desired peak memory and overlap.
This PR fixes hard-to-debug peak memory pass errors caused by incorrect tracking in `.unmet_dependencies` during the enablement of SimpleFSDP on internal models.
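Schematically (pass names from the description; the list structure is a sketch, not the actual code):
```py
# Sketch of the scheduler pass ordering after this PR:
nodes = decide_global_ordering_of_comms(nodes)  # first: fixes dependency info
for scheduler_pass in other_passes:             # remaining reordering passes
    nodes = scheduler_pass(nodes)
nodes = peak_memory_pass(nodes)                 # last: final say on node order
nodes = overlap_pass(nodes)
```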
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142822
Approved by: https://github.com/eellison
This PR extends our ability to fuse pointwise nodes onto Triton templates (epilogue fusion) with the ability to fuse pointwise nodes into Triton templates - prologue fusion.
Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`
And the modification api:
```
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
out="qk"
) | indent_except_first(1) }}
```
We have:
```{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}```
Because we are now loading the input with explicit indices and a mask, I needed to rewrite the mm kernel to no longer update the [pointers by BLOCK_K](bb03ef7aca/torch/_inductor/kernel/mm.py (L110-L111)) on every iteration, and instead compute the indices from the k_idx of each loop iteration. This did not make any perf difference.
There are a couple of main use cases for prologue fusion:
- Fusing dequants into a matmul, particularly in more bandwidth-bound scenarios (see the sketch below).
- Fusing a gather into a matmul. This is useful particularly in MoE. See https://github.com/pytorch/pytorch/issues/134535 for more details.
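For instance, the dequant case looks roughly like this (an illustrative pattern, not code from this PR):
```py
import torch

w_int8 = torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda")
scale = torch.rand(4096, device="cuda", dtype=torch.float16)

@torch.compile(mode="max-autotune")
def dequant_mm(x):
    # pointwise dequant feeding the matmul: a candidate prologue fusion
    w = w_int8.to(torch.float16) * scale
    return x @ w

x = torch.randn(128, 4096, device="cuda", dtype=torch.float16)
out = dequant_mm(x)
```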
Prologue fusion is generally much less profitable than epilogue fusion, because the prologue must be applied to an element of an input on every loop iteration of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting prologue fusion: we only attempt it if it does not increase the number of memory bytes read inside the Triton template, multiplied by a small factor to allow gathers. This rules out reliably unprofitable fusions like an fp32->fp16 cast inside the kernel. In a future PR we could potentially add an API for being more aggressive when we know we are in a bandwidth-bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066
Other notes:
By default we will upcast to fp32 inside every kernel. This matches eager numerics. This is fine for epilogues because the upcast is only done once (although it is probably unnecessary for, say, a relu), but it tanks perf for prologues. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need https://github.com/pytorch/pytorch/pull/136778/ and dtype-aware codegen to upcast fp16 ops into libdevice calls.
With prologue fusion, we now have essentially separate kernels for each input and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also updated the fusion logic because the inputs will have a different group than the outputs. Maybe as part of enabling multiple outputs, this could get cleaned up a bit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134532
Approved by: https://github.com/jansel
**Summary:**
(Since I am trying the other solution for https://github.com/pytorch/pytorch/pull/141082, I moved the test case fixes out of that PR into a separate PR to land first.)
-----
Testing the float8 dynamic scaling case with `TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1` didn't make any difference.
The test case for fp8 (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425) is also failing: https://www.internalfb.com/intern/test/844425111960859?ref_report_id=0
-------
The main change here is to modify the condition of calling `loop_reordering` from `shared_data_score == 0` to `shared_data_score < config.score_fusion_memory_threshold`.
Before the change:
`shared_data_score > 0 -> won't loop_reorder -> can't be fused because shared_data_score < config.score_fusion_memory_threshold`
After the change:
`shared_data_score > 0 -> loop_reorder (when shared_data_score < config.score_fusion_memory_threshold) -> get a larger shared_data_score -> fused`
----
It's the same issue as the one fixed in https://github.com/pytorch/pytorch/pull/136782, but the condition for calling loop_reorder was changed later, causing the test case to fail again.
**Test Plan:**
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```
And ran a float8 dynamic scaling training script to verify it e2e
-----
Differential Revision: D66906175
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142273
Approved by: https://github.com/eellison
* Automatically applies ruff rule PERF401, which turns loops into equivalent list comprehensions; these are faster and do not leak the scope of the loop variables.
* List comprehensions not only often have better typing, but are also 50+% faster than for loops in interpreter overhead. They preserve length information and are easier for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
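For example, the transformation looks like:
```py
items = range(10)

# Before: building a list with an explicit loop (what the rule flags)
result = []
for item in items:
    if item % 2 == 0:
        result.append(item * item)

# After: the equivalent, faster list comprehension
result = [item * item for item in items if item % 2 == 0]
```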
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Adding some `dynamo_timed` annotations for the purpose of better understanding AOTI compilation time.
This will probably require a few more passes: a lot of time is spent in `Scheduler.__init__`, and not enough annotations are there yet.
`run_command_and_check` takes a lot of time as well, but there is probably not much we can do about it. Maybe we can add a config to tune the C++ optimization level?
traces:
<img width="1205" alt="Screenshot 2024-11-08 at 4 41 10 PM" src="https://github.com/user-attachments/assets/61645264-b3af-4d4a-804d-700b0f831c7c">
Differential Revision: D65554141
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140198
Approved by: https://github.com/desertfire
Here are the cases where Inductor does autotuning at compile time:
1. pad mm: benchmark to decide whether we should pad or not
2. template autotuning: benchmark the Triton/CUTLASS templates and the ATen kernel for matmul/conv and pick the fastest one
This PR annotates these cases with `dynamo_timed`.
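The annotation pattern is roughly the following (an illustrative sketch; the actual call sites and timing keys are in the PR):
```py
from torch._dynamo.utils import dynamo_timed

def autotune(choices):
    # "mm_template_autotuning" and benchmark_choice are hypothetical names
    with dynamo_timed("mm_template_autotuning"):
        return min(choices, key=benchmark_choice)
```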
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139431
Approved by: https://github.com/ezyang
Summary:
I think we can inplace a buffer if all of the users of said buffer are "inconsequential", defined as having been removed, being completed, or being part of the ancestors set. In particular, this allows LayerNorm to inplace its input buffer.
Implements:
https://github.com/pytorch/pytorch/issues/132826
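A simplified sketch of the check (hypothetical names; the real logic lives in Inductor's scheduler):
```py
# A user no longer blocks in-placing if it was removed, has already
# completed, or is in the current node's ancestor set.
def users_are_inconsequential(buffer, ancestors, completed, removed):
    return all(
        user in removed or user in completed or user in ancestors
        for user in buffer.users
    )
```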
Test Plan:
New unit test of a matmul followed by LayerNorm; it checks that a buffer is inplaced.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138383
Approved by: https://github.com/eellison
Partially fixing https://github.com/pytorch/pytorch/issues/138685
Add a (relatively safe?) heuristic to skip fusion if it could potentially increase peak memory.
The doc string mainly explains what this PR is doing:
```
The implementation is more like a heuristic, since we don't really know if we are at the peak
or not when trying to fuse these two nodes. The order of nodes may change later, which makes
peak memory estimation hard.
Here is how we decide the LOWER BOUND of extra memory allocation if we fuse these 2 nodes:
1. find all buffers read by each node that have a single user. These buffers are supposed to
be reused if we don't fuse these 2 nodes
2. find the intersection of these buffers for the two nodes and sum the total buffer size.
If we don't fuse these two nodes, we can at least avoid this much memory allocation.
Note that the extra memory allocation is not necessarily causing a peak memory increase.
This is just a heuristic.
We return true only if the savings from fusion cannot trade off the extra memory allocation.
```
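A condensed sketch of the heuristic, with hypothetical helper names mirroring the docstring above (the real code lives in Inductor's scheduler):
```py
def fusion_costs_too_much_memory(node1, node2, fusion_saving):
    # buffers each node reads that have a single user; these could be
    # reused (freed earlier) if the two nodes are not fused
    reusable1 = {b for b in node1.reads() if b.num_users == 1}
    reusable2 = {b for b in node2.reads() if b.num_users == 1}
    # a lower bound on the extra allocation caused by fusing
    extra_alloc = sum(b.size_bytes for b in reusable1 & reusable2)
    return fusion_saving < extra_alloc
```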
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138756
Approved by: https://github.com/jansel
ghstack dependencies: #139136