This is the last directory to opt in for the regular mypy.ini file. Will put up a diff to remove unused ignores before making sure we're also type checking all the files in the mypy strict configurations
Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check
step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
after:
INFO 0 errors (6,884 ignored)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165692
Approved by: https://github.com/oulgen
### Implementation of #151705
This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly—without relying on predefined templates.
To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:
1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)
### Summary of This PR
This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.
### Notable Changes
1. Adds a new config flag: `config.triton.enable_native_matmul`
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tililng suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I’ve copied the configuration from the existing matmul templates. However, this may not be optimal—it currently takes a long time to tune, and I think there must be a better way to tackle this.
@eellison @jansel @PaulZhang12 @shunting314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
# Feature
Currently, `torch._inductor.compile_aot` always uses the `WrapperFxCodegen` class. In contrast, Python and C++ codegen allow users to register custom backends. This PR brings that feature to FX codegen.
# Test plan
Added a CI test registering a custom FX backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162317
Approved by: https://github.com/jansel
# why
- to have a central registry of templates/externkernelchoice
to match them to heuristics etc, they need unique names
- mm is both the triton template name and the aten_mm name
# what
- add a uid() to KernelTemplate/ExternKernelChoice that returns name
- override in ExternKernel to prepend "aten::"
- override in TritonTemplate to prepend "triton::"
This id is just use to find template heuristics, so it has no other
impact
# testing
```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```
Differential Revision: [D81520579](https://our.internmc.facebook.com/intern/diff/D81520579)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161341
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #162075, #161340
Using the existing WrapperFxCodegen backend, this PR prototypes an AOT version of it which will directly return a graph module.
How to use:
```python
exported_gm = torch.export.export(model, inp, dynamic_shapes=dynamic_shapes).module()
compiled_gm = torch._inductor.aot_compile(
exported_gm, inp, options={"fx_wrapper": True, "compile_threads": 1}
)
assert torch.allclose(model(*inp), compiled_gm(*inp))
```
The motivation behind this is that backends like ExecuTorch/MTIA would like to use inductor's optimization technologies, but might have their own graph lowering pipelines so they might not want to use AOTI (which generates an so).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160765
Approved by: https://github.com/jansel
An out of tree backend can have its own configuration options that the user can enable to control inductor compilation. These config options need to be taken into account when calculating the key that is used to determine cache miss / hits. This PR allows out of tree backends to specify a custom config module that has the same type as `torch._inductor.config` that can be used to control codegen (in addition to the default config), and will be used when creating the cache key.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158254
Approved by: https://github.com/eellison
Tests: `python test/inductor/test_aot_inductor.py -vvv -k device_tma`
Device-side TMA in Triton allows the kernel author to construct the TMA descriptor on the device (which composes with things like autotuning much better). However, it also requires a scratch space to be provided into which the TMA descriptor will be constructed. In the new TMA API (tl.make_tensor_descriptor), this is implemented using a "global scratch space" - a tensor which is allocated beforehand and then passed in as an argument for the kernel.
To support this in AOTI, this PR:
* records the global scratch space needed (triton_heuristics.py), so that it can be used during AOTI codegen
* allocates global scratch, if needed (cuda/device_op_overrides.py)
* plumbs `device_idx_` into the triton caller function, so that global scratch can be allocated on the right device)
* updates tests to verify this works for dynamically shaped inputs
This PR should support both inductor-generated device-side TMA (e.g. persistent TMA mm) and user-defined triton kernels that contain device-side TMA (which is the test I ran to verify this works)
Note: this overrides any user-provided allocator function (typically with eager triton code, the user must provide their own custom allocator function that is used to allocate scratch space).
For Meta reviewers, here is a tlparse from running `python test/inductor/test_aot_inductor.py -vvv -k test_triton_kernel_on_device_tma_dynamic_True_tma_version_new_cuda` https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpFg13g1/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
Differential Revision: [D77352139](https://our.internmc.facebook.com/intern/diff/D77352139)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155896
Approved by: https://github.com/desertfire
This PR adds JIT inductor support for user-defined triton kernels using the new host-side TMA api.
* handle TensorDescriptor.from_tensor in ir.py
* codegen TensorDescriptor.from_tensor in wrapper.py
* generate the right signature for functions that take TensorDescriptor arguments (i.e. in the @triton_heuristics.user_autotune decorator)
AOTI support is not implemented yet.
Tests: ran test_triton_kernels.py w/ both Triton 3.3 and 3.4 and there were no failures.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155814
Approved by: https://github.com/aakhundov
ghstack dependencies: #155777
# Sub-PRs
These PRs contain refactors from the main one. They should be reviewed and merged first.
- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587
# Feature
The goals of this PR are twofold.
## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.
In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.
This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.
## Goal 2: Convert Wrapper IR into FX IR.
One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.
It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.
The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.
Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.
# Current status
Things that seem to work:
- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
- Handled the following cases, in addition to what was already in the Memory Planning IR:
- Comments
- Triton kernels
- Extern/fallback kernels
- Freeing tensors (`del buf0`)
- MultiOutput
- Graph outputs
- ReinterpretView / StorageBox, for both call args and outputs.
- FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
- Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
- Calling wrapped Triton kernels.
- Calling extern kernels and certain types of fallback kernels.
- Support both `extern_kernels.*` and `aten.*`.
- Support multi-output kernels like `torch.topk`.
- Graphs with multiple inputs/outputs.
- Training i.e. calling `Tensor.backward()` in a compiled function.
- Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.
Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
- Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
- Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
- CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.
# Out-of-tree compilers
With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.
```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen
class MyCustomBackend(WrapperFxCodegen):
def compile_graph(self, gm):
# Add 1 to the graph's outputs
def compiled_fn(*args):
return [x + 1 for x in gm.graph.forward(*args)]
return compiled_fn
```
# Example FX graphs
This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.
Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
return (buf0,)
```
Here's a more complicated graph that calls a `torch.addmm` extern kernel.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
%buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
%addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
%delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
return (buf2,)
```
Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
%buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
%buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
%delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
%triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
return (buf1, buf2)
```
Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
%buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
return (buf0_view_buf0_0,)
```
Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
%s6 : [num_users=0] = placeholder[target=s6]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
return buf0
```
Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
%s10 : [num_users=0] = placeholder[target=s10]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
return buf0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
# Sub-PRs
These PRs contain refactors from the main one. They should be reviewed and merged first.
- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587
# Feature
The goals of this PR are twofold.
## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.
In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.
This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.
## Goal 2: Convert Wrapper IR into FX IR.
One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.
It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.
The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.
Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.
# Current status
Things that seem to work:
- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
- Handled the following cases, in addition to what was already in the Memory Planning IR:
- Comments
- Triton kernels
- Extern/fallback kernels
- Freeing tensors (`del buf0`)
- MultiOutput
- Graph outputs
- ReinterpretView / StorageBox, for both call args and outputs.
- FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
- Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
- Calling wrapped Triton kernels.
- Calling extern kernels and certain types of fallback kernels.
- Support both `extern_kernels.*` and `aten.*`.
- Support multi-output kernels like `torch.topk`.
- Graphs with multiple inputs/outputs.
- Training i.e. calling `Tensor.backward()` in a compiled function.
- Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.
Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
- Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
- Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
- CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.
# Out-of-tree compilers
With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.
```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen
class MyCustomBackend(WrapperFxCodegen):
def compile_graph(self, gm):
# Add 1 to the graph's outputs
def compiled_fn(*args):
return [x + 1 for x in gm.graph.forward(*args)]
return compiled_fn
```
# Example FX graphs
This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.
Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
return (buf0,)
```
Here's a more complicated graph that calls a `torch.addmm` extern kernel.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
%buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
%addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
%delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
return (buf2,)
```
Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
%buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
%buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
%delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
%triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
return (buf1, buf2)
```
Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
%buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
return (buf0_view_buf0_0,)
```
Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
%s6 : [num_users=0] = placeholder[target=s6]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
return buf0
```
Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
%s10 : [num_users=0] = placeholder[target=s10]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
return buf0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel