This does two things,
1) Short circuit `welford_reduce` on the first iteration to ignore the accumulator (big win for small `rnumel`)
2) Replace division with multiplication by reciprocal
Currently this is not enough to match two pass reduction with bfloat16 but it is still a significant improvement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120330
Approved by: https://github.com/lezcano
I want to log metadata for inductor generated triton kernels for a couple of purposes
1. with these metadata, it should be convenient to find unaligned reduction kernels and try the idea here https://github.com/pytorch/pytorch/issues/119929 . I think it's nice to try on kernels that are used in real models
2. I'm thinking that based on the collected kernel metadata, I can build a simple offline tool by benchmarking each kernel with ncu and augment each kernel metadata with: latency, theoretical membw (estimated memory access / latency), and actually achieved membw. Hopefully this can point us to some good optimization opportunities.
Command:
```
TORCHINDUCTOR_CACHE_DIR=`realpath ~/inductor-caches/kernel-metadata-log` TORCHINDUCTOR_ENABLED_METRIC_TABLES=kernel_metadata TORCHINDUCTOR_BENCHMARK_KERNEL=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 time python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --training
```
The best practice here is to point inductor cache to a folder outside of /tmp so that one can always run the kernel again based on the path stored in kernel metadata. (folders under /tmp may get removed by the system)
Here is first 1000 rows of collected metadata for huggingface: https://gist.github.com/shunting314/cf4ebdaaaa7e852efcaa93524c868e5f
And here is the total 10K kernels collected for huggingface. The gist can not be rendered as a csv since it's too large: https://gist.github.com/shunting314/7f841528e2debdc2ae05dece4ac591be .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120048
Approved by: https://github.com/jansel
Summary: Currently, when a custom (user-written) Triton kernel has a ReinterpretView argument in IR, we're always skipping the alignment checking for this argument when preparing the `signature_of` for the AOT compilation of the Triton kernel (via setting `TensorArg.check_alignment` to `False`). This is problematic for user-written kernels where, albeit reinterpreted, the argument of the Triton kernel (the data pointer) can still be aligned to 16. When we skip alignment checking, the performance of the AOT-compiled internal Triton kernels can degrade 2x--3x.
In this PR, we replace `TensorArg.check_alignment` by `TensorArg.offset`, in which we specify the offset of the `ReinterpretView.layout` relative to the underlying `ir.Buffer` (corresponding to the data pointer before reinterpretation). As the size and stride of the layout don't change the alignment properties, those can be skipped. Importantly, for `ReinterpretView` arguments of custom Triton kernels, we use `arg.data.get_name()` as the buffer name. That, together with the offset, is used to check the alignment.
Bonus: the namedtuples in `codegen/common.py` are refactored as `dataclass`es, with nicer type hints and default values (for the newly added `TensorArg.offset`).
Test Plan:
```
$ python test/inductor/test_aot_inductor.py -k test_triton_kernel_reinterpret_view
...
----------------------------------------------------------------------
Ran 6 tests in 27.952s
OK (skipped=4)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119649
Approved by: https://github.com/oulgen
This PR adds a new type of triton kernel in which data is persistent but the
reduction dimension is split over multiple blocks (up to the entire kernel).
though this is called a reduction dimension, in actuality we only support scans.
because of this limitation, i have to be able to block fusions of split scan
operations with reductions so chose to add a new `ir.SplitScan` node which
is identical but allows for differentiation in the scheduler.
The split scan kernel is also the first to require an additional workspace buffer
which is used to communicate between cuda blocks. this is slightly tricky as we
the exact scratch space requirement isn't known until the grid size is calculated.
here i workaround the issue by setting a minimum rblock size and always allocating
to the maximum possible grid size for a given input tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117992
Approved by: https://github.com/jansel
ghstack dependencies: #117991
Currently the dimension handling in triton kernels has various special cases e.g.
- handling "r" for non-reduction vs persistent reduction vs non-persistent reduction.
- handling "x" when `no_x_dim` is set
This adds three new properties to the range tree objects which capture the
same information in a more generic way:
- `is_loop`: true for the "r" dimension of a non-persistent reduction
- `tensor_dim`: Optional index of the triton tensor dimension
- `grid_dim`: Optional index of the triton grid dimension
The motivation here is I want to add a new split scan kernel type which is:
- not a persistent reduction, yet has `is_loop=False` for the "r" dimension
- Has a `grid_dim` for the "r" dimension
These flags now only need to be set once in `initialize_range_trees`, instead of having
to infer them throughout the code based on the tree prefix and various other kernel flags.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117991
Approved by: https://github.com/lezcano
I was just playing around with improving the typing of symbolic_shapes. The PR is not "complete" but I in particular wanted to get feedback on whether or not people liked making ValueRanges Generic; it seems that distinguishing if you have an Expr ValueRange or a SympyBoolean ValueRange is a lot of trouble for downstream. Using TypeGuard, we can perform refinements on the generic parameter inside methods, although we still have to cast back to ValueRange[T] due to https://github.com/python/mypy/issues/14425#issuecomment-1914852707
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118529
Approved by: https://github.com/Skylion007
Our current throughput calculations for kernel benchmarks have some issues,
particularly when we slice inputs in the kernel. In such cases, we count
the original inputs as part of the memory traffic passed across the kernel.
This is incorrect because it may result in a much larger throughput
calculation, which can even exceed the theoretical bandwidth.
Instead, we should only count the size of the "slices" that contribute to
the actual memory traffic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118858
Approved by: https://github.com/jansel
Special values (`NaN`/`+/-Inf`) are not correctly during codegen for `ir.Scan` nodes. This
is a fairly minor bugfix that has not come up since the only two scan
ops with lowerings use "normal" values.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118788
Approved by: https://github.com/peterbell10
I was just playing around with improving the typing of symbolic_shapes. The PR is not "complete" but I in particular wanted to get feedback on whether or not people liked making ValueRanges Generic; it seems that distinguishing if you have an Expr ValueRange or a SympyBoolean ValueRange is a lot of trouble for downstream. Using TypeGuard, we can perform refinements on the generic parameter inside methods, although we still have to cast back to ValueRange[T] due to https://github.com/python/mypy/issues/14425#issuecomment-1914852707
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118529
Approved by: https://github.com/Skylion007
dmypy silently ignores follow_imports = skip, so to get parity between
dmypy and mypy we have to suck it up and type: ignore all of the sympy
typing problems.
The suppressions were added automatically with the following script generated by GPT-4:
```
import re
# Read the error file
with open("error_file.txt", "r") as f:
errors = f.readlines()
# Parse the lines with errors and error types
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
# Insert ignore comments in the source files
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118469
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432, #118467, #118468
This diff introduce the following changes:
1. Fix sympy_subs to preserve integer and non-negative properties of replaced symbol when replacement is string
why is this needed?
I was compiling an expression:
x*abs(y) where y =-2
what happens is that this expression is passed as ``s1*abs(s0)`` then s0 is replaced to ks0 with a call to sympy_subs.
but sympy_subs used to replace s0 (integer=false, nonegative=false) with ks0(inetegr=true, nonegative = true)
resulting in ``x*abs(ks0) = x*ks0`` which is wrong
2. rename sympy_symbol to sympy_index_symbol to make it explicit.
3. add assertion that replaced expression is not passed as string but always a sympy expression.
Fixes https://github.com/pytorch/pytorch/issues/117757
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118150
Approved by: https://github.com/ezyang
Previously, we generated the grid argument with tree.numel for
a benchmark TritonKernel. This was not correct, because it
didn't match the launch config used for profiling and running.
This PR fixed the issue by emitting the grid value computed
by the kernel's grid_fn, which is used by the profiler and
the kernel's runner.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118202
Approved by: https://github.com/shunting314, https://github.com/jansel
For a persistent reduction, we generate 2 flavor of 'equivalant' kernels at the same time
- persistent reduction
- regular reduction
A MultiKernel wraps these 2 kernels and pick the one with better performance at runtime.
Here I talk more about implementation details:
- Inductor maintains states for generating kernels. E.g. the wrapper code. After we generate code for one kernel, we need restore the inductor state before we can generate the counterpart.
***There is one thing I need some comments from others***:
There is one tricky thing about kernel arguments. In general, inductor removes a buffer from the argument list if it's only used inside the kernel. But somehow a buffer removed by persistent reduction kernel may still be kept by the regular (non-persistent) reduction kernel because of some CSE invalidation rule. My current implementation avoid removing buffers if multi_kernel is enabled. This makes sure both flavors of reduction has consistent argument list. Another idea I have is, we generate the multi-kernel definition with the union of arguments from both sub-kernels. Let each sub-kernel pick the subset of arguments it wants. But this will make the code-gen or multi-kernel much complex.
I'm not sure if there is some easy and clean way to resolve this.
Testing command:
```
TORCHINDUCTOR_MULTI_KERNEL=1 TORCH_LOGS=+torch._inductor.graph TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --only BertForMaskedLM --training
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103469
Approved by: https://github.com/jansel
As the [RFC](https://github.com/pytorch/pytorch/issues/114856) mentions, this is the step 1 to add Intel GPU backend as an alternative inductor backend.
### Design
Typically, in order to integrate Intel GPU backend into Inductor, we need to inherit from `WrapperCodegen` and `TritonScheduling` and implement the corresponding subclasses respectively. However, since `WrapperCodegen` and `TritonScheduling` have some device-bias code generation **scattered** in their methods, overriding them in subclasses would introduce a lot of duplicated parent class code.
For example:
2a44034895/torch/_inductor/codegen/wrapper.py (L487)2a44034895/torch/_inductor/codegen/triton.py (L1996)
So we abstract the device-bias code scattered in WrapperCodegen and TritonScheduling and provide a unified interface "DeviceOpOverrides". This way, when integrating a new backend, we can maximize the reuse of `WrapperCodegen` and `TritonScheduling` code by inherit and implement this interface for device flexibility.
Currently the `DeviceOpOverrides` only cover Python wrapper code generation. We can futher extend it to cover Cpp wrapper code generation on demand.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116020
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
Recent 2 triton PRs (https://github.com/openai/triton/pull/2701, https://github.com/openai/triton/pull/2756) change the interface for triton.compile, this PR added the necessary change on inductor side to work with both old and new compile API.
Also there is some simplification between compilation call in subprocess and the one in main process
- previously we pass warm_cache_only=True if the compilation happens in subprocess. But triton never use that argument in the currently used pin. So I removed that
- previously we only pass compute_capability if compilation happens in subprocess. The PR change that to always passing compute_capability to triton.compile no matter if the compilation happens in main or sub process.
Updated:
There are more interface change from triton side. E.g.
- tl.math.{min, max} now requires a propagate_nan argument
- JITFunction.run now requires a warmup argument. This affect the benchmarking phase of matmul max-autotune; on the other hand, JITFunction.run forbids stream argument now. Simply removing passing this in when benchmarking matmul triton kernel will work for both old and new version of triton.
- triton Autotuner change attribute name from 'warmup' to 'num_warmup' and from 'rep' to 'num_rep'. This cause dynamo failed to handle triton Autotuner object since dynamo TritonKernelVariable makes assumption about attribute names. It's used in some test cases that a model call triton Autotuner directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115878
Approved by: https://github.com/jansel
Recent 2 triton PRs (https://github.com/openai/triton/pull/2701, https://github.com/openai/triton/pull/2756) change the interface for triton.compile, this PR added the necessary change on inductor side to work with both old and new compile API.
Also there is some simplification between compilation call in subprocess and the one in main process
- previously we pass warm_cache_only=True if the compilation happens in subprocess. But triton never use that argument in the currently used pin. So I removed that
- previously we only pass compute_capability if compilation happens in subprocess. The PR change that to always passing compute_capability to triton.compile no matter if the compilation happens in main or sub process.
Updated:
There are more interface change from triton side. E.g.
- tl.math.{min, max} now requires a propagate_nan argument
- JITFunction.run now requires a warmup argument. This affect the benchmarking phase of matmul max-autotune; on the other hand, JITFunction.run forbids stream argument now. Simply removing passing this in when benchmarking matmul triton kernel will work for both old and new version of triton.
- triton Autotuner change attribute name from 'warmup' to 'num_warmup' and from 'rep' to 'num_rep'. This cause dynamo failed to handle triton Autotuner object since dynamo TritonKernelVariable makes assumption about attribute names. It's used in some test cases that a model call triton Autotuner directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115878
Approved by: https://github.com/jansel
Fixes#114310 and supersedes #114748.
There are two reasons why we have quite a few special cases for `round`:
1. `round` is actually two ops. With `ndigits=None` (default), `round` always returns an integer. When `ndigits` is an integer, the returned type is a float.
2. Although `round` takes two arguments, it is a unary function with a parameter rather than a binary one.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115259
Approved by: https://github.com/peterbell10, https://github.com/lezcano
Currently the inductor code for `x.any(-1)` does a this strange dance:
```python
tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask)
tmp1 = tmp0.to(tl.int64)
tmp2 = (tmp1 != 0)
```
This happens because `register_lowering` is doing type promotion with the
dimension argument, and so promotes to `int64` which we then cast back to bool.
A better fix would be to fix `register_lowering` but for now I just remove
the unnecessary type promotion from `aten.any`.
In the current code we also see:
```python
tmp5 = tl.where(rmask & xmask, tmp3, 0)
```
which promotes the boolean value to int since `0` is an int32 in triton.
This fixes it to generate a boolean constant instead.
Finally there is also a triton bug where the `tl.load` itself upcasts to
`tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final
kernel code looks like:
```python
tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.full([1, 1], 0, tl.int1)
tmp4 = tl.where(rmask & xmask, tmp1, tmp3)
tmp5 = triton_helpers.any(tmp4, 1)[:, None]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109913
Approved by: https://github.com/lezcano
Description:
- Added non-integer expr support for floordiv in triton codegen
- Added a test
- cpp test is skipped as failing and https://github.com/pytorch/pytorch/pull/115647 may fix it
This PR is fixing compilation error with the following code:
```python
import torch
def func(x, a):
n = (a * 1.234) // 8.234
y = x + n
return y
cfunc = torch.compile(func, dynamic=True, fullgraph=True)
device = "cuda"
x = torch.tensor(0, dtype=torch.float32, device=device)
a = 33
out = cfunc(x, a)
expected = func(x, a)
torch.testing.assert_close(out, expected)
```
Error message on Nightly:
```
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fx_wrapper' raised:
CompilationError: at 7:38:def triton_(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = ((1.23400000000000*ks0) // 8.23400000000000)
^
AssertionError()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115751
Approved by: https://github.com/peterbell10
This PR fixed#104791
bitcast requires the source and target have the bitwidth.
Because the input tensor's dtype could be promoted, e.g. from float16 to
float, we have to cast the tensor to its original source dtype before
invoking bitcast in such cases. After that, we also need to convert
the bit-casted tensor back to float to make sure we keep using higher
precision values for the rest of the computation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115619
Approved by: https://github.com/jansel, https://github.com/eellison
Previously if two calls to cumsum were generated in the same triton kernel
we would generate identical helper functions with different names. Now this
recognizes identical functions and only defines it once. To do this I defer
choosing the name until after codegen.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115546
Approved by: https://github.com/lezcano
ghstack dependencies: #109132