Restart the work from PR https://github.com/pytorch/pytorch/pull/100331 in this new PR since it's hard to rebase. It would be expected that some code is copy/pasted from the previous PR and main idea is the same.
Previously we see relatively large compilation time increase due to too many loop orders being considered. This PR tries to continue the work by doing pruning and only considering loop orders that we know for sure are relevant (i.e. do it on demand).
Some manually created cases that loop ordering matters are added as unit tests. The PR can make sure inductor does not miss fusion opportunities for them.
This PR should solve the not-able to fusion problem in https://github.com/pytorch/pytorch/issues/130015
Right now there is still significant increase of compilation time. I'll disable the feature by default. Later on after the compilation time issue is resolved, I'll enable it by default.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126254
Approved by: https://github.com/jansel
Inductor currently materialize a large sparse matrix in the backward pass for CrossEntropyLoss and load that to compute gradients of Softmax input. If we could fuse the sparse matrix computation to the consumer sides, we gonna have both perf and memory usage wins.
The Fx graph snippets that construct this aforementioned sparse matrix looks like:
```
full_default_3: "bf16[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
scatter: "bf16[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0); full_default_3 = where_2 = None
```
Leveraging the following observations:
- the scatter is applied upon a all zero (or more generally a const tensor)
- the index tensor for the scatter has a single element on the scatter dimension. In this case it's the label tensor
allow us to lower this 'scatter_upon_const_tensor' pattern to a pointwise kernel that can be easily fused with downstream kernels:
```
def inner_fn(idx):
selector_idx = list(idx)
selector_idx[dim] = 0 # can do this since the index tensor has a single element on the scatter dimension
selector = selector_loader(selector_idx)
return ops.where(
selector == ops.index_expr(idx[dim], torch.int64),
ops.constant(val, dtype),
ops.constant(background_val, dtype),
)
```
## Test result on microbenchmark
For the microbenchmark added as `test_cross_entropy_loss`, we improve latency from 47.340ms to 42.768ms, memory footprint from 10.524GB to 7.227GB on A100. (on H100, we improve latency from 27.54ms to 23.51ms, memory footprint from 10.574GB to 7.354GB).
The saving matches the back-of-envelope calculation. We avoid storing a BF16 tensor with shape [30K, 50K] which is about 3GB in size. On A100, avoid loading and storing such a tensor can roughly save 3GB x 2 / 1.5TBGS = 4ms
## Test result on llm.c
We also test this on llm.c and the saving is much larger especially for memory footprint. The reason is due to autotuning that allocates extra memory for benchmarking. (Check https://github.com/pytorch/pytorch/issues/129258 and https://github.com/pytorch/pytorch/pull/129399 for more details).
For llm.c PyTorch implementation on A100, we improve from
171K tokens/s , 33.6G peak memory usage to
180K tokens/s, 18.6G peak memory usage. (A **45%** saving of peak memory)
## Test on PyTorch 2.0 Dashboard
The optimization is quite general especially for transformers. We tested this on PyTorch2.0 dashboard. Here is the [result](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2017%20Jun%202024%2018%3A07%3A51%20GMT&stopTime=Mon%2C%2024%20Jun%202024%2018%3A07%3A51%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=gh/shunting314/158/head&lCommit=c62c55e29c65497d495217b6574bb36b0c4da7d4&rBranch=main&rCommit=0d25f096c1beaf8749932a3d6083ad653405ed71).
TLDR, for Huggingface benchmark suite, we get **6%** geomean perf improvement and **10%** geomean memory footprint improvement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129043
Approved by: https://github.com/jansel, https://github.com/Chillee
The scheduler searches for fusion opportunities by looking for common memory access. Two memory access are considered common not only when the buffer name match, but it also requires more things
- index formula matches
- var_ranges matches
In this PR, I want to log all the fusion failures due to mismatch index formula or var_ranges. I also want to further categories the failures. Right now I found the following failure categories
- rand_seed: the index for rand seed access is an integer and different access uses different integer offset
- different numel: this happens for cat operation
- broadcast: e.g. kernel A write a buffer which is broadcasted and read by kernel B
- different loop orders: the major category we want inductor to be able to fuse
- different offset: happens when use a concatenated linear layer to project Q/K/V and then split the result. Each split will point to the same buffer with different offset.
- unknown
My hope is to make sure for the models I tested, there is no fusion failure falling in the unknown category so all the failures are well understood and categories. Right now it's true for BertForMaskedLM ( https://gist.github.com/shunting314/6dc2c903629d342fa63ba731a171adc2 ), DistillGPT2 ( https://gist.github.com/shunting314/145176f2e850103c7fad4ad72f0e200e ) and llm.c ( https://gist.github.com/shunting314/cfc64a326312a889ba55f79bd47b2082 )
For BertForMaskedLM, we found 82 instances of fusion failures and majority of them are due to different loop orders! Studying the log a bit more can help us figure out where all these loop order mismatch comes from in real models.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124986
Approved by: https://github.com/eellison, https://github.com/jansel
This PR adds the ability to pad tensor strides during lowering. The goal is to make sure (if possible) tensors with bad shape can have aligned strides so GPU can access the memory more efficiently.
By testing BlenderbotSmallForConditionalGeneration I already see 2.5ms speedup.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120758
Approved by: https://github.com/jansel
Summary: This is in preparation to enable FX graph caching by default. First fix some bugs uncovered by running all unit tests under `test/inductor/`. I'll enable in a separate diff in case we need to revert. Summary of changes:
* Turn off caching for tests that require a compilation, e.g., when checking that a relevant counter was incremented
* Bypass caching when we see mkldnn tensors as constants (they currently don't serialize, so we can't save to disk)
* Include various global settings that could affect compilation in the cache key calculation.
* Handle a few config settings that break key calculation.
* Handle code paths where no ShapeEnv is available (the cache impl requires a shape env as part of handling guards)
* Skip caching when freezing is enabled (Freezing can embed constants that wouldn't be static across runs).
* Fix the clear() method to not throw when the cache /tmp dir doesn't exist
Test Plan: Ran all tests under `test/inductor/` twice with TORCHINDUCTOR_FX_GRAPH_CACHE=1 to exercise any test that might be affected by caching.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117888
Approved by: https://github.com/eellison
Log a few more fields
- num_atomic_add: perf of kernels using atomic_add are usually data dependent. Our benchmarking code generate all indices to be 0 which will result in worse perf than reality.
- kernel_args_num_gb: estimate the amount of read/writes for kernel args. In-place args will be double counted. If we have a good estimation, this should be the lower bound of memory access that the GPU performs. Sometimes GPU will do more memory access since a single buffer may be access multiple times (e.g. for softmax when input tensor is quite large. cache only help a bit here). With this logged, and if we augment the metadata with amount of memory the GPU actually accessed, then it would be nice to dig into kernels that GPU access more memory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120274
Approved by: https://github.com/jansel
ghstack dependencies: #120266
I want to log metadata for inductor generated triton kernels for a couple of purposes
1. with these metadata, it should be convenient to find unaligned reduction kernels and try the idea here https://github.com/pytorch/pytorch/issues/119929 . I think it's nice to try on kernels that are used in real models
2. I'm thinking that based on the collected kernel metadata, I can build a simple offline tool by benchmarking each kernel with ncu and augment each kernel metadata with: latency, theoretical membw (estimated memory access / latency), and actually achieved membw. Hopefully this can point us to some good optimization opportunities.
Command:
```
TORCHINDUCTOR_CACHE_DIR=`realpath ~/inductor-caches/kernel-metadata-log` TORCHINDUCTOR_ENABLED_METRIC_TABLES=kernel_metadata TORCHINDUCTOR_BENCHMARK_KERNEL=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 time python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --training
```
The best practice here is to point inductor cache to a folder outside of /tmp so that one can always run the kernel again based on the path stored in kernel metadata. (folders under /tmp may get removed by the system)
Here is first 1000 rows of collected metadata for huggingface: https://gist.github.com/shunting314/cf4ebdaaaa7e852efcaa93524c868e5f
And here is the total 10K kernels collected for huggingface. The gist can not be rendered as a csv since it's too large: https://gist.github.com/shunting314/7f841528e2debdc2ae05dece4ac591be .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120048
Approved by: https://github.com/jansel
For a persistent reduction, we generate 2 flavor of 'equivalant' kernels at the same time
- persistent reduction
- regular reduction
A MultiKernel wraps these 2 kernels and pick the one with better performance at runtime.
Here I talk more about implementation details:
- Inductor maintains states for generating kernels. E.g. the wrapper code. After we generate code for one kernel, we need restore the inductor state before we can generate the counterpart.
***There is one thing I need some comments from others***:
There is one tricky thing about kernel arguments. In general, inductor removes a buffer from the argument list if it's only used inside the kernel. But somehow a buffer removed by persistent reduction kernel may still be kept by the regular (non-persistent) reduction kernel because of some CSE invalidation rule. My current implementation avoid removing buffers if multi_kernel is enabled. This makes sure both flavors of reduction has consistent argument list. Another idea I have is, we generate the multi-kernel definition with the union of arguments from both sub-kernels. Let each sub-kernel pick the subset of arguments it wants. But this will make the code-gen or multi-kernel much complex.
I'm not sure if there is some easy and clean way to resolve this.
Testing command:
```
TORCHINDUCTOR_MULTI_KERNEL=1 TORCH_LOGS=+torch._inductor.graph TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --only BertForMaskedLM --training
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103469
Approved by: https://github.com/jansel
In dynamo/inductor, sometimes it helps to gather metrics/statistics for each model in different levels like model level, graph level, kernel level or pair of fusion nodes level. This kind of thing will be very easy to do with Scuba, but we only have scuba in fbcode. This PR build metric tables to solve part of the problem.
Q: why not log to stdout/err direclty
A: sometimes we need more structured data. E.g., it would be helpful to gather all the stats in a CSV and then do post-processing (like calculating a geomean etc.). Also metric table will tag each row with the model name which is helpful.
Q: what's the difference with speedup_indcutor.csv
A: speedup_indcutor.csv is a special case that gather statistics on model level: i.e., we have one row for each model. But recording statistics on finer grain level like graph etc. is also helpful.
Example use cases:
- As a followup on the bechmark fusion PR, I want to gather all the 'slow' fusion and analyze them. With the metric table, I can easily log slow fusion for each model into a csv file. Here is the log gathered for huggingface:
https://gist.github.com/shunting314/964e73cc98368b301414ec7b7ad4c702 .
- To help understand the effect of 'loop ordering after fusion' PR, it would be helpful to gather stats like how many fusions happens for each graph. Previously we log the metric to stderr directly. But logging these metrics in a structural way is useful.
- gather number of registers, register spills, shared memory usage for each kernel in each model with runnable kernel code logged.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109245
Approved by: https://github.com/jansel, https://github.com/mlazos
Working as starter task with @Chillee
This PR adds a method under BaseSchedulerNode to estimate the node's runtime in seconds.
We use a heuristic based approach, first by considering whether the operation is memory bandwidth bounded or compute bounded:
- memory bandwidth bounded: we compute the number of bytes that are read/written to
- compute bounded: we compute the FLOPS required by the operation
One use case could be to be used as a cost model for scheduling: https://github.com/pytorch/pytorch/pull/100762
```
(pytorch-3.10) [14:08:02] ~/local/pytorch (xmfan/estimate_snode_runtime) > python3 test/inductor/test_perf.py -k EstimateSnodeRuntimeTests
[(ExternKernelSchedulerNode(name='buf0'), 400)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000), (SchedulerNode(name='buf1'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26), (SchedulerNode(name='buf1'), 7.187055238190188e-09)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26)]
.[(ExternKernelSchedulerNode(name='buf0'), 34600)]
[(ExternKernelSchedulerNode(name='buf0'), 3.22687496698039e-24)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 7776176)]
[(ExternKernelSchedulerNode(name='buf0'), 4.63240241413653e-21)]
.[(FusedSchedulerNode(nodes=buf0_buf1), 210)]
[(FusedSchedulerNode(nodes=buf0_buf1), 5.030938666733132e-10)]
.[(ExternKernelSchedulerNode(name='buf0'), 300)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(SchedulerNode(name='buf0'), 20)]
[(SchedulerNode(name='buf0'), 4.7913701587934585e-11)]
.
----------------------------------------------------------------------
Ran 10 tests in 14.311s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106426
Approved by: https://github.com/Chillee
Currently there is `test_vertical_fusion1` which fuses entirely during
the lowering stage and no buffers are realized. This adds
`test_scheduler_vertical_fusion1` which is the same test but with
several intermediate calculations realized so the scheduler is left
to do the fusion.
To support the test, this PR also adds:
- `metrics.ir_nodes_pre_fusion` which when compared with
`generated_kernel_count` tells us how many nodes were fused.
- `torch._test_inductor_realize` which is an identity operator in
eager, but under inductor also forces the input to be realized.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90014
Approved by: https://github.com/jansel
In this PR, we replace OMP SIMD with `aten::vec` to optimize TorchInductor vectorization performance. Take `res=torch.exp(torch.add(x, y))` as the example. The generated code is as follows if `config.cpp.simdlen` is 8.
```C++
extern "C" void kernel(const float* __restrict__ in_ptr0,
const float* __restrict__ in_ptr1,
float* __restrict__ out_ptr0,
const long ks0,
const long ks1)
{
#pragma omp parallel num_threads(48)
{
#pragma omp for
for(long i0=0; i0<((ks0*ks1) / 8); ++i0)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 8*i0);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + 8*i0);
auto tmp2 = tmp0 + tmp1;
auto tmp3 = tmp2.exp();
tmp3.store(out_ptr0 + 8*i0);
}
#pragma omp for simd simdlen(4)
for(long i0=8*(((ks0*ks1) / 8)); i0<ks0*ks1; ++i0)
{
auto tmp0 = in_ptr0[i0];
auto tmp1 = in_ptr1[i0];
auto tmp2 = tmp0 + tmp1;
auto tmp3 = std::exp(tmp2);
out_ptr0[i0] = tmp3;
}
}
}
```
The major pipeline is as follows.
- Check whether the loop body could be vectorized by `aten::vec`. The checker consists of two parts. [One ](bf66991fc4/torch/_inductor/codegen/cpp.py (L702))is to check whether all the `ops` have been supported. The [other one](355326faa3/torch/_inductor/codegen/cpp.py (L672)) is to check whether the data access could be vectorized.
- [`CppSimdVecKernelChecker`](355326faa3/torch/_inductor/codegen/cpp.py (L655))
- Create the `aten::vec` kernel and original omp simd kernel. Regarding the original omp simd kernel, it serves for the tail loop when the loop is vectorized.
- [`CppSimdVecKernel`](355326faa3/torch/_inductor/codegen/cpp.py (L601))
- [`CppSimdVecOverrides`](355326faa3/torch/_inductor/codegen/cpp.py (L159)): The ops that we have supported on the top of `aten::vec`
- Create kernel
- [`aten::vec` kernel](355326faa3/torch/_inductor/codegen/cpp.py (L924))
- [`Original CPP kernel - OMP SIMD`](355326faa3/torch/_inductor/codegen/cpp.py (L929))
- Generate code
- [`CppKernelProxy`](355326faa3/torch/_inductor/codegen/cpp.py (L753)) is used to combine the `aten::vec` kernel and original cpp kernel
- [Vectorize the most inner loop](355326faa3/torch/_inductor/codegen/cpp.py (L753))
- [Generate code](355326faa3/torch/_inductor/codegen/cpp.py (L821))
Next steps:
- [x] Support reduction
- [x] Vectorize the tail loop with `aten::vec`
- [ ] Support BF16
- [ ] Optimize the loop condition and loop index calculation by replacing `div` with `add`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87068
Approved by: https://github.com/jgong5, https://github.com/jansel