This commit was landed internally and merged as a PR in an inconsistent state. The resulting merge conflicts required reverting in both places, normalizing the internal commit stack, and then re-landing properly.
Original commit: #88384 (011452a2a1c745d4b12f83f89eca039f482d134b)
Inconsistent revert: #90018 (8566aa7c0b4bdca50bf85ca14705b4304de030b3)
Revert of the inconsistent revert to restore a healthy state (i.e., re-land of the original commit): cf3c3f22804be6909e54fc09e07f891ab0886774
Landing the correct, internally consistent revert of the original commit: this PR, #90055 (TBD)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90055
Approved by: https://github.com/DanilBaibak, https://github.com/malfet
This PR aims to automatically enable vectorization optimization for TorchInductor. It refines the semantics of `config.cpp.simdlen`.
Originally, `None` meant vectorization was disabled, while a specific value meant the number of elements to be vectorized at a time. But the right value depends on the data type: for a 256-bit SVE/SIMD ISA on ARM or x86, `simdlen` should be 8 for Float but 16 for BFloat16. Hence, this PR redefines `simdlen` as the SIMD bit width. The detailed semantics are as follows.
- **_simdlen = None_**: Automatically determine the SIMD bit width: detect the HW capabilities and pick the appropriate vectorization ISA. On x86, AVX512 takes priority over AVX2.
- **_simdlen <= 1_**: Explicitly disable SIMD.
- **_simdlen > 1_**: Explicitly specify the SIMD bit width. If the bit width does not match a supported ISA width, vectorization is disabled (see the usage sketch below).
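For illustration, a minimal usage sketch under the new semantics (the `torch.compile` call and tensor shapes are illustrative, not from this PR):

```python
import torch
import torch._inductor.config as inductor_config

# simdlen = None (default): auto-detect the widest supported ISA
# (e.g. prefer AVX512 over AVX2 on x86).
inductor_config.cpp.simdlen = None

# simdlen <= 1: explicitly disable SIMD code generation.
# inductor_config.cpp.simdlen = 1

# simdlen > 1: request a specific SIMD bit width; if it does not match a
# supported ISA width, vectorization is effectively disabled.
# inductor_config.cpp.simdlen = 256

fn = torch.compile(lambda x, y: torch.exp(x + y))
out = fn(torch.randn(1024), torch.randn(1024))
```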
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89263
Approved by: https://github.com/jgong5, https://github.com/jansel
This PR aims to automatically enable vectorization optimization for TorchInductor. It refines the semantics of `config.cpp.simdlen`.
Originally, `None` meant vectorization was disabled, while a specific value meant the number of elements to be vectorized at a time. But the right value depends on the data type: for a 256-bit SVE/SIMD ISA on ARM or x86, `simdlen` should be 8 for Float but 16 for BFloat16. Hence, this PR redefines `simdlen` as the SIMD bit width (a short lane-count sketch follows the list below). The detailed semantics are as follows.
- **_simdlen = None_**: Automatically determine the SIMD bit width: detect the HW capabilities and pick the appropriate vectorization ISA. On x86, AVX512 takes priority over AVX2.
- **_simdlen <= 1_**: Explicitly disable SIMD.
- **_simdlen > 1_**: Explicitly specify the SIMD bit width. If the bit width does not match a supported ISA width, vectorization is disabled.
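As a quick check of the lane-count arithmetic above (plain Python, not code from this PR):

```python
def lanes(simd_bits: int, dtype_bits: int) -> int:
    # Number of elements processed per vector under a given SIMD bit width.
    return simd_bits // dtype_bits

print(lanes(256, 32))  # float32 on a 256-bit ISA    -> 8
print(lanes(256, 16))  # bfloat16 on a 256-bit ISA   -> 16
print(lanes(512, 32))  # float32 on AVX512 (512-bit) -> 16
```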
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88482
Approved by: https://github.com/jgong5, https://github.com/jansel
In this PR, we replace OMP SIMD with `aten::vec` to optimize TorchInductor vectorization performance. Take `res = torch.exp(torch.add(x, y))` as an example: the generated code below is what you get when `config.cpp.simdlen` is 8.
```C++
extern "C" void kernel(const float* __restrict__ in_ptr0,
const float* __restrict__ in_ptr1,
float* __restrict__ out_ptr0,
const long ks0,
const long ks1)
{
#pragma omp parallel num_threads(48)
{
#pragma omp for
for(long i0=0; i0<((ks0*ks1) / 8); ++i0)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 8*i0);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + 8*i0);
auto tmp2 = tmp0 + tmp1;
auto tmp3 = tmp2.exp();
tmp3.store(out_ptr0 + 8*i0);
}
#pragma omp for simd simdlen(4)
for(long i0=8*(((ks0*ks1) / 8)); i0<ks0*ks1; ++i0)
{
auto tmp0 = in_ptr0[i0];
auto tmp1 = in_ptr1[i0];
auto tmp2 = tmp0 + tmp1;
auto tmp3 = std::exp(tmp2);
out_ptr0[i0] = tmp3;
}
}
}
```
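For context, a minimal sketch of the Python side that could produce a kernel like the one above (the `torch.compile` entry point and shapes are assumptions for illustration, not taken from this PR):

```python
import torch
import torch._inductor.config as inductor_config

# Under this PR's element-count semantics, 8 means 8 float32 lanes (256-bit).
inductor_config.cpp.simdlen = 8

def f(x, y):
    return torch.exp(torch.add(x, y))

compiled = torch.compile(f, backend="inductor")
x = torch.randn(127, 253)  # total element count is not a multiple of 8, so the tail loop is exercised
y = torch.randn(127, 253)
res = compiled(x, y)
```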
The major pipeline is as follows.
- Check whether the loop body can be vectorized by `aten::vec`. The checker consists of two parts: [one](bf66991fc4/torch/_inductor/codegen/cpp.py (L702)) checks whether all the `ops` are supported, and the [other](355326faa3/torch/_inductor/codegen/cpp.py (L672)) checks whether the data access can be vectorized.
  - [`CppSimdVecKernelChecker`](355326faa3/torch/_inductor/codegen/cpp.py (L655))
- Create the `aten::vec` kernel and the original OMP SIMD kernel. The original OMP SIMD kernel serves as the tail loop when the main loop is vectorized (see the index sketch after this list).
  - [`CppSimdVecKernel`](355326faa3/torch/_inductor/codegen/cpp.py (L601))
    - [`CppSimdVecOverrides`](355326faa3/torch/_inductor/codegen/cpp.py (L159)): the ops we support on top of `aten::vec`
  - Create kernel
    - [`aten::vec` kernel](355326faa3/torch/_inductor/codegen/cpp.py (L924))
    - [`Original CPP kernel - OMP SIMD`](355326faa3/torch/_inductor/codegen/cpp.py (L929))
- Generate code
  - [`CppKernelProxy`](355326faa3/torch/_inductor/codegen/cpp.py (L753)) is used to combine the `aten::vec` kernel and the original CPP kernel
    - [Vectorize the most inner loop](355326faa3/torch/_inductor/codegen/cpp.py (L753))
    - [Generate code](355326faa3/torch/_inductor/codegen/cpp.py (L821))
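To make the main-loop/tail-loop combination concrete, here is a small schematic of the index math mirrored from the generated kernel above (plain Python, not Inductor's actual codegen):

```python
def split_for_vectorization(total: int, lanes: int):
    # Mirrors the loop bounds in the generated kernel: the main loop runs
    # `total // lanes` vector iterations, the scalar tail loop covers the rest.
    main_iters = total // lanes
    tail_start = lanes * main_iters
    return main_iters, tail_start

main_iters, tail_start = split_for_vectorization(1003, 8)
# main_iters == 125  -> 125 iterations of the aten::vec main loop (8 floats each)
# tail_start == 1000 -> indices 1000..1002 are handled by the OMP SIMD tail loop
```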
Next steps:
- [x] Support reduction
- [x] Vectorize the tail loop with `aten::vec`
- [ ] Support BF16
- [ ] Optimize the loop condition and loop index calculation by replacing `div` with `add`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87068
Approved by: https://github.com/jgong5, https://github.com/jansel
If Python was launched with 'spawn', it will not use the standard shutdown methods that concurrent.futures requires, so we register a shutdown with the method it does use. Without this, shutdown hangs since the workers will not exit.
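A hedged sketch of the general idea (the exact registration hook used in the PR may differ; `atexit` is an illustrative choice here):

```python
import atexit
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

pool = ProcessPoolExecutor(max_workers=4)

# If the interpreter is using the 'spawn' start method, explicitly register a
# shutdown so the worker processes are told to exit and do not hang the program.
if multiprocessing.get_start_method(allow_none=True) == "spawn":
    atexit.register(pool.shutdown, wait=True)
```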
cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87725
Approved by: https://github.com/wconstab
To cooperate with other multiprocessing usage, this forces the process pool to use 'fork' even if others have set the start method differently. We require fork because otherwise an `if __name__ == "__main__"` guard would need to be in place, which we do not control as a library.
Furthermore, this adds code to clean up worker processes if the parent exits abnormally (e.g. a segfault). Previously we would leave live but inactive workers around.
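A rough sketch of both pieces (the watchdog and names are illustrative assumptions, not the PR's exact implementation):

```python
import multiprocessing
import os
import threading
import time
from concurrent.futures import ProcessPoolExecutor

def _exit_when_orphaned(parent_pid: int) -> None:
    # Worker-side watchdog: if the parent dies (e.g. a segfault), the worker is
    # re-parented, so getppid() changes and we exit instead of lingering.
    def watch() -> None:
        while True:
            if os.getppid() != parent_pid:
                os._exit(0)
            time.sleep(1.0)
    threading.Thread(target=watch, daemon=True).start()

# Force the 'fork' start method even if the application configured another one,
# so callers do not need an `if __name__ == "__main__"` guard.
pool = ProcessPoolExecutor(
    max_workers=os.cpu_count(),
    mp_context=multiprocessing.get_context("fork"),
    initializer=_exit_when_orphaned,
    initargs=(os.getpid(),),
)
```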
cc @jansel @lezcano @fdrocha
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87411
Approved by: https://github.com/soumith, https://github.com/anijain2305
Fixes https://github.com/pytorch/pytorch/pull/87048 by saving the needed CUDA device properties before fork.
Actually attempting to get CUDA to load in the workers is probably not desired: CUDA initialization takes on the order of seconds, and having multiple processes use the same device would slow things down.
This just moves the needed properties from the main trainer process to the workers.
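A minimal sketch of the idea, assuming a CUDA device is available and the 'fork' start method (the worker function and the exact properties passed along are illustrative):

```python
import torch
from concurrent.futures import ProcessPoolExecutor

# Query the device properties once in the parent, before any workers are forked,
# so the workers never need to initialize CUDA themselves.
props = torch.cuda.get_device_properties(0)
cached_props = {"name": props.name, "major": props.major, "minor": props.minor}

def compile_task(source: str, device_props: dict) -> str:
    # Hypothetical worker: consumes the cached properties instead of touching CUDA.
    return f"compiled {len(source)} bytes for sm_{device_props['major']}{device_props['minor']}"

pool = ProcessPoolExecutor(max_workers=4)
print(pool.submit(compile_task, "kernel source ...", cached_props).result())
pool.shutdown()
```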
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87101
Approved by: https://github.com/soumith
This patch significantly improves the parallel compilation performance for compiling Triton kernels
by using ProcessPoolExecutor to create a persistent pool of compilation workers.
Previously, os.fork overhead and GIL contention limited the achieved parallelism. This patch replaces the worker threads with a pool of processes that do the raw compilation, and does the serial work on the main thread for everything else. That other work could not be parallelized anyway since it is mostly in Python.
In cold-start situations, the time to get the worker processes started can be a significant portion of the total time. This patch starts the workers earlier so they are ready to perform compilation (see code comments) by the time dynamo gets to that point.
This has only been tested on one example benchmark (tf_efficientnet_b0), but the results are significant, almost eliminating the difference between warm and cold compilation.
```
39.613s - warm
41.290s - cold, this patch
2m53.197s - cold, single threaded
1m7.092s - cold, old setup n = 8 (its best config)
```
(cold compilation is done after running `rm -rf /tmp/torchinductor_$USER`).
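A sketch of the warm-start idea (function names are illustrative, not the PR's; assumes the 'fork' start method):

```python
import os
from concurrent.futures import ProcessPoolExecutor

# Create the pool as early as possible so the worker processes are already
# forked and idle by the time compilation requests start arriving.
pool = ProcessPoolExecutor(max_workers=os.cpu_count())

def _warm_up() -> None:
    # Trivial task submitted once per worker to force the processes to start.
    pass

for _ in range(os.cpu_count()):
    pool.submit(_warm_up)

def compile_kernel_source(src: str) -> str:
    # Hypothetical stand-in for the raw Triton compilation done in a worker;
    # the remaining (mostly Python) work stays serial on the main thread.
    return f"compiled({len(src)} bytes)"

futures = [pool.submit(compile_kernel_source, f"kernel {i}") for i in range(16)]
results = [f.result() for f in futures]
pool.shutdown()
```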
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87032
Approved by: https://github.com/soumith, https://github.com/jansel