# Changes over the previous PR
This reverts commit 61a1f09 and adds `__launch_bounds__` to the kernel.
Previously I merged 114d404 that did not work on Blackwell because it consumed too many registers. It got reverted in 61a1f09. For more context see: https://github.com/pytorch/pytorch/issues/150266.
This PR reverts the revert (i.e. reapplies the original diff), with one additional line with `__launch_bounds__` added:
```
git diff HEAD^
diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index 0d63a2f979c..3ce2c24c18e 100644
--- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu
+++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -657,6 +657,7 @@ bool aligned_grid
>
__global__
void
+__launch_bounds__(block_dim_x * block_dim_y)
GammaBetaBackwardCUDAKernelTemplate(
int64_t M,
int64_t N,
```
I managed to get a Blackwell machine and verified that the fix works. The fix was verified using this repro that I got from @drisspg
<details>
<summary> Repro script that fails on Blackwell </summary>
```
import torch
from torch.nn import init
# from transformer_nuggets import init_logging
# from transformer_nuggets.utils.benchmark import profiler
# from pathlib import Path
# init_logging()
class PermuteModule(torch.nn.Module):
def __init__(self, permutation):
super(PermuteModule, self).__init__()
self.permutation = permutation
def forward(self, x:torch.Tensor) -> torch.Tensor:
assert len(x.shape) == len(self.permutation), f"Dimension mismatch! Unable to permute {len(x.shape)} dim input with a {len(self.permutation)} dim permutation!"
return x.permute(*self.permutation)
def test(n_layers:int, conv_stride:int):
_sequence = []
for _ in range(n_layers):
# Conv1d inputs are (N x C x L), LayerNorm expects (* x C). Dims must be permuted between modules.
_sequence += [
PermuteModule((0,2,1)),
torch.nn.Conv1d(in_channels=512, out_channels=512, groups=1, kernel_size=9, dilation=1, stride=conv_stride, padding=0, bias=False),
PermuteModule((0,2,1)),
torch.nn.LayerNorm(512),
torch.nn.ReLU()
]
model = torch.nn.Sequential(*_sequence).to(device="cuda")
data = torch.randn((100,2048,512), device="cuda")
out = model(data)
loss = torch.nn.functional.mse_loss(out, torch.rand_like(out))
loss.backward()
torch.autograd.set_detect_anomaly(True)
print(f"Torch version: {torch.__version__}")
# with profiler(Path("conv")):
# # print(f"layers=1, stride=1")
# # test(n_layers=1, conv_stride=1)
# # print(f"layers=2, stride=1")
# # test(n_layers=2, conv_stride=1)
# # print(f"layers=1, stride=2")
# # test(n_layers=1, conv_stride=2)
# print(f"layers=2, stride=2")
# test(n_layers=2, conv_stride=2)
print(f"layers=2, stride=2")
test(n_layers=2, conv_stride=2)
# we will not reach this print statement.
print("DONE.")
```
</details>
I also re-ran my performance benchmark and found no regressions over the previous PR.
# Full description of the old PR
Original PR: https://github.com/pytorch/pytorch/pull/148605
This PR adds a new kernel for producing gamma and beta values for the backward pass in a performant way.
To test the performance against the baseline, I measured the backward pass of layernorm while sweeping over the following variables:
1. dtype in {half, float}
2. M in `2**k, 2**k - 1, 2**k + 1 for k in range(...)`
3. N in `2**k, 2**k - 1, 2**k + 1 for k in range(...)`
4. Whether we flush the L2 cache before running the backward pass
Summary: The new code performs better than the old code, especially for powers of 2. For M >> N case, it performs very well (kernel itself can be 30x faster and the overall backward pass can be 5-10x faster).
In order to visualize results of the kernel when choosing different values of M, N and dtype, I wrote some code to generate a heatmap. The heatmap has N on the x-axis, M on the y-axis and color-coded points where green shows performance improvement and red shows regressions. For example, `m=32 n=2048 1.42x` in the heatmap would indicate the normalized shape had 32 elements. The leading dimensions' product was 2048 elements and the new kernel resulted in the *backward pass* being 1.42x faster than the old *backward pass*.
Important note: This heatmap shows the total backward pass time as seen by the user. The kernel time difference can be sometimes very large while the total backward pass time is not that high. For example, for dtype=torch.half, M=32 N=2048, flush_l2_cache=True case, the heatmap shows a speedup of 1.42x, while ncu tells me the new kernel is 2.5x faster than the old:
M=32 N=2048 dtype=half flush_l2=True Old Kernel NCU summary:
```
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 1.59
SM Frequency Ghz 1.35
Elapsed Cycles cycle 27,526
Memory Throughput % 2.21
DRAM Throughput % 0.54
Duration us 20.42
L1/TEX Cache Throughput % 4.31
L2 Cache Throughput % 2.62
SM Active Cycles cycle 1,475.02
Compute (SM) Throughput % 0.29
----------------------- ----------- ------------
```
M=32 N=2048 dtype=half flush_l2=True New Kernel NCU summary:
```
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 1.59
SM Frequency Ghz 1.34
Elapsed Cycles cycle 10,920
Memory Throughput % 5.64
DRAM Throughput % 1.35
Duration us 8.13
L1/TEX Cache Throughput % 1.92
L2 Cache Throughput % 6.89
SM Active Cycles cycle 3,554.41
Compute (SM) Throughput % 0.67
----------------------- ----------- ------------
```
Let's look at some rows from the heatmap. For dtype=float16 flush_l2_cache=True and when input shapes are powers of 2, we get the following:
<img width="1508" alt="image" src="https://github.com/user-attachments/assets/06179599-b2f0-4a45-8664-247a1067950b" />
There are 3 columns -- the first shows all data points, the second shows speedups only and the 3rd column shows regressions only. We can see that there are dramatic speedups for M >> N cases and the regressions are not that high (less than 1%, which could just be measurement noise). Here is a small guide I made:

For dtype=float32, we get a similar chart:
<img width="1499" alt="image" src="https://github.com/user-attachments/assets/c4d31a76-03b0-426c-9114-e1bfad29b530" />
The new code performs especially well for m >> n cases, and also where m and n are small. The m >> n case is special because we run 2 reduction kernels back to back and parallelize in the "M" dimension (the older kernel only parallelized in the "N" dimension).
The new code can sometimes have regressions for non-powers of 2. That is because the old code was using block sizes of {16, 32} while we have `threads.x = 32`. For example when N=33, the old code would have 3 blocks and we will have 2 blocks. I wrote some code to specialize for this case, but I think it will add complexity and @ngimel mentioned that non-powers of 2 are rare enough.
I am including the regressions here for completeness' sake:
<img width="1500" alt="image" src="https://github.com/user-attachments/assets/31c17cfb-ed9b-4106-b9c8-5c359751f530" />
To see this better:
1. Click the image
2. Right click the expanded image and open in a new tab
3. Go to that tab and left click once to zoom in
If you want to see the full data, here it is:

I also measured binary size and compile time since those are important for developers:
Binary size comparison

```
# Original
-rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so
# This PR
-rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so
```
The diff in bytes is 302kB which is about a 0.1% increase.
Compile time difference:
```
# Original
real 0m10.931s
user 0m9.676s
sys 0m1.004s
# this PR
real 0m16.720s
user 0m15.514s
sys 0m1.066s
# Command I ran
time /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFLASHATTENTION_DISABLE_SOFTCAP -DFLASH_NAMESPACE=pytorch_flash -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUNFUSE_FMA -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_CUFILE -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/third_party/flash-attention/csrc/flash_attn/src -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. -I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o
```
So the new PR is 6 seconds longer compile time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150625
Approved by: https://github.com/ngimel, https://github.com/atalman
# Changes over the previous PR
This reverts commit 61a1f09 and adds `__launch_bounds__` to the kernel.
Previously I merged 114d404 that did not work on Blackwell because it consumed too many registers. It got reverted in 61a1f09. For more context see: https://github.com/pytorch/pytorch/issues/150266.
This PR reverts the revert (i.e. reapplies the original diff), with one additional line with `__launch_bounds__` added:
```
git diff HEAD^
diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index 0d63a2f979c..3ce2c24c18e 100644
--- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu
+++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -657,6 +657,7 @@ bool aligned_grid
>
__global__
void
+__launch_bounds__(block_dim_x * block_dim_y)
GammaBetaBackwardCUDAKernelTemplate(
int64_t M,
int64_t N,
```
I managed to get a Blackwell machine and verified that the fix works. The fix was verified using this repro that I got from @drisspg
<details>
<summary> Repro script that fails on Blackwell </summary>
```
import torch
from torch.nn import init
# from transformer_nuggets import init_logging
# from transformer_nuggets.utils.benchmark import profiler
# from pathlib import Path
# init_logging()
class PermuteModule(torch.nn.Module):
def __init__(self, permutation):
super(PermuteModule, self).__init__()
self.permutation = permutation
def forward(self, x:torch.Tensor) -> torch.Tensor:
assert len(x.shape) == len(self.permutation), f"Dimension mismatch! Unable to permute {len(x.shape)} dim input with a {len(self.permutation)} dim permutation!"
return x.permute(*self.permutation)
def test(n_layers:int, conv_stride:int):
_sequence = []
for _ in range(n_layers):
# Conv1d inputs are (N x C x L), LayerNorm expects (* x C). Dims must be permuted between modules.
_sequence += [
PermuteModule((0,2,1)),
torch.nn.Conv1d(in_channels=512, out_channels=512, groups=1, kernel_size=9, dilation=1, stride=conv_stride, padding=0, bias=False),
PermuteModule((0,2,1)),
torch.nn.LayerNorm(512),
torch.nn.ReLU()
]
model = torch.nn.Sequential(*_sequence).to(device="cuda")
data = torch.randn((100,2048,512), device="cuda")
out = model(data)
loss = torch.nn.functional.mse_loss(out, torch.rand_like(out))
loss.backward()
torch.autograd.set_detect_anomaly(True)
print(f"Torch version: {torch.__version__}")
# with profiler(Path("conv")):
# # print(f"layers=1, stride=1")
# # test(n_layers=1, conv_stride=1)
# # print(f"layers=2, stride=1")
# # test(n_layers=2, conv_stride=1)
# # print(f"layers=1, stride=2")
# # test(n_layers=1, conv_stride=2)
# print(f"layers=2, stride=2")
# test(n_layers=2, conv_stride=2)
print(f"layers=2, stride=2")
test(n_layers=2, conv_stride=2)
# we will not reach this print statement.
print("DONE.")
```
</details>
I also re-ran my performance benchmark and found no regressions over the previous PR.
# Full description of the old PR
Original PR: https://github.com/pytorch/pytorch/pull/148605
This PR adds a new kernel for producing gamma and beta values for the backward pass in a performant way.
To test the performance against the baseline, I measured the backward pass of layernorm while sweeping over the following variables:
1. dtype in {half, float}
2. M in `2**k, 2**k - 1, 2**k + 1 for k in range(...)`
3. N in `2**k, 2**k - 1, 2**k + 1 for k in range(...)`
4. Whether we flush the L2 cache before running the backward pass
Summary: The new code performs better than the old code, especially for powers of 2. For M >> N case, it performs very well (kernel itself can be 30x faster and the overall backward pass can be 5-10x faster).
In order to visualize results of the kernel when choosing different values of M, N and dtype, I wrote some code to generate a heatmap. The heatmap has N on the x-axis, M on the y-axis and color-coded points where green shows performance improvement and red shows regressions. For example, `m=32 n=2048 1.42x` in the heatmap would indicate the normalized shape had 32 elements. The leading dimensions' product was 2048 elements and the new kernel resulted in the *backward pass* being 1.42x faster than the old *backward pass*.
Important note: This heatmap shows the total backward pass time as seen by the user. The kernel time difference can be sometimes very large while the total backward pass time is not that high. For example, for dtype=torch.half, M=32 N=2048, flush_l2_cache=True case, the heatmap shows a speedup of 1.42x, while ncu tells me the new kernel is 2.5x faster than the old:
M=32 N=2048 dtype=half flush_l2=True Old Kernel NCU summary:
```
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 1.59
SM Frequency Ghz 1.35
Elapsed Cycles cycle 27,526
Memory Throughput % 2.21
DRAM Throughput % 0.54
Duration us 20.42
L1/TEX Cache Throughput % 4.31
L2 Cache Throughput % 2.62
SM Active Cycles cycle 1,475.02
Compute (SM) Throughput % 0.29
----------------------- ----------- ------------
```
M=32 N=2048 dtype=half flush_l2=True New Kernel NCU summary:
```
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 1.59
SM Frequency Ghz 1.34
Elapsed Cycles cycle 10,920
Memory Throughput % 5.64
DRAM Throughput % 1.35
Duration us 8.13
L1/TEX Cache Throughput % 1.92
L2 Cache Throughput % 6.89
SM Active Cycles cycle 3,554.41
Compute (SM) Throughput % 0.67
----------------------- ----------- ------------
```
Let's look at some rows from the heatmap. For dtype=float16 flush_l2_cache=True and when input shapes are powers of 2, we get the following:
<img width="1508" alt="image" src="https://github.com/user-attachments/assets/06179599-b2f0-4a45-8664-247a1067950b" />
There are 3 columns -- the first shows all data points, the second shows speedups only and the 3rd column shows regressions only. We can see that there are dramatic speedups for M >> N cases and the regressions are not that high (less than 1%, which could just be measurement noise). Here is a small guide I made:

For dtype=float32, we get a similar chart:
<img width="1499" alt="image" src="https://github.com/user-attachments/assets/c4d31a76-03b0-426c-9114-e1bfad29b530" />
The new code performs especially well for m >> n cases, and also where m and n are small. The m >> n case is special because we run 2 reduction kernels back to back and parallelize in the "M" dimension (the older kernel only parallelized in the "N" dimension).
The new code can sometimes have regressions for non-powers of 2. That is because the old code was using block sizes of {16, 32} while we have `threads.x = 32`. For example when N=33, the old code would have 3 blocks and we will have 2 blocks. I wrote some code to specialize for this case, but I think it will add complexity and @ngimel mentioned that non-powers of 2 are rare enough.
I am including the regressions here for completeness' sake:
<img width="1500" alt="image" src="https://github.com/user-attachments/assets/31c17cfb-ed9b-4106-b9c8-5c359751f530" />
To see this better:
1. Click the image
2. Right click the expanded image and open in a new tab
3. Go to that tab and left click once to zoom in
If you want to see the full data, here it is:

I also measured binary size and compile time since those are important for developers:
Binary size comparison

```
# Original
-rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so
# This PR
-rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so
```
The diff in bytes is 302kB which is about a 0.1% increase.
Compile time difference:
```
# Original
real 0m10.931s
user 0m9.676s
sys 0m1.004s
# this PR
real 0m16.720s
user 0m15.514s
sys 0m1.066s
# Command I ran
time /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFLASHATTENTION_DISABLE_SOFTCAP -DFLASH_NAMESPACE=pytorch_flash -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUNFUSE_FMA -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_CUFILE -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/third_party/flash-attention/csrc/flash_attn/src -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. -I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o
```
So the new PR is 6 seconds longer compile time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150625
Approved by: https://github.com/ngimel
This PR adds a new kernel for producing gamma and beta values for the backward pass in a performant way.
To test the performance against the baseline, I measured the backward pass of layernorm while sweeping over the following variables:
1. dtype in {half, float}
2. M in `2**k, 2**k - 1, 2**k + 1 for k in range(...)`
3. N in `2**k, 2**k - 1, 2**k + 1 for k in range(...)`
4. Whether we flush the L2 cache before running the backward pass
Summary: The new code performs better than the old code, especially for powers of 2. For M >> N case, it performs very well (kernel itself can be 30x faster and the overall backward pass can be 5-10x faster).
In order to visualize results of the kernel when choosing different values of M, N and dtype, I wrote some code to generate a heatmap. The heatmap has N on the x-axis, M on the y-axis and color-coded points where green shows performance improvement and red shows regressions. For example, `m=32 n=2048 1.42x` in the heatmap would indicate the normalized shape had 32 elements. The leading dimensions' product was 2048 elements and the new kernel resulted in the *backward pass* being 1.42x faster than the old *backward pass*.
Important note: This heatmap shows the total backward pass time as seen by the user. The kernel time difference can be sometimes very large while the total backward pass time is not that high. For example, for dtype=torch.half, M=32 N=2048, flush_l2_cache=True case, the heatmap shows a speedup of 1.42x, while ncu tells me the new kernel is 2.5x faster than the old:
M=32 N=2048 dtype=half flush_l2=True Old Kernel NCU summary:
```
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 1.59
SM Frequency Ghz 1.35
Elapsed Cycles cycle 27,526
Memory Throughput % 2.21
DRAM Throughput % 0.54
Duration us 20.42
L1/TEX Cache Throughput % 4.31
L2 Cache Throughput % 2.62
SM Active Cycles cycle 1,475.02
Compute (SM) Throughput % 0.29
----------------------- ----------- ------------
```
M=32 N=2048 dtype=half flush_l2=True New Kernel NCU summary:
```
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 1.59
SM Frequency Ghz 1.34
Elapsed Cycles cycle 10,920
Memory Throughput % 5.64
DRAM Throughput % 1.35
Duration us 8.13
L1/TEX Cache Throughput % 1.92
L2 Cache Throughput % 6.89
SM Active Cycles cycle 3,554.41
Compute (SM) Throughput % 0.67
----------------------- ----------- ------------
```
Let's look at some rows from the heatmap. For dtype=float16 flush_l2_cache=True and when input shapes are powers of 2, we get the following:
<img width="1508" alt="image" src="https://github.com/user-attachments/assets/06179599-b2f0-4a45-8664-247a1067950b" />
There are 3 columns -- the first shows all data points, the second shows speedups only and the 3rd column shows regressions only. We can see that there are dramatic speedups for M >> N cases and the regressions are not that high (less than 1%, which could just be measurement noise). Here is a small guide I made:

For dtype=float32, we get a similar chart:
<img width="1499" alt="image" src="https://github.com/user-attachments/assets/c4d31a76-03b0-426c-9114-e1bfad29b530" />
The new code performs especially well for m >> n cases, and also where m and n are small. The m >> n case is special because we run 2 reduction kernels back to back and parallelize in the "M" dimension (the older kernel only parallelized in the "N" dimension).
The new code can sometimes have regressions for non-powers of 2. That is because the old code was using block sizes of {16, 32} while we have `threads.x = 32`. For example when N=33, the old code would have 3 blocks and we will have 2 blocks. I wrote some code to specialize for this case, but I think it will add complexity and @ngimel mentioned that non-powers of 2 are rare enough.
I am including the regressions here for completeness' sake:
<img width="1500" alt="image" src="https://github.com/user-attachments/assets/31c17cfb-ed9b-4106-b9c8-5c359751f530" />
To see this better:
1. Click the image
2. Right click the expanded image and open in a new tab
3. Go to that tab and left click once to zoom in
If you want to see the full data, here it is:

I also measured binary size and compile time since those are important for developers:
Binary size comparison

```
# Original
-rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so
# This PR
-rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so
```
The diff in bytes is 302kB which is about a 0.1% increase.
Compile time difference:
```
# Original
real 0m10.931s
user 0m9.676s
sys 0m1.004s
# this PR
real 0m16.720s
user 0m15.514s
sys 0m1.066s
# Command I ran
time /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFLASHATTENTION_DISABLE_SOFTCAP -DFLASH_NAMESPACE=pytorch_flash -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUNFUSE_FMA -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_CUFILE -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/third_party/flash-attention/csrc/flash_attn/src -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. -I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o
```
So the new PR is 6 seconds longer compile time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148605
Approved by: https://github.com/ngimel
Fixes#103425
## Changes
- Add doc description size value `must be > 0`
- Add validation for `in1_features` param
Currently, only `in1_features` will cause runtime error, if add checks for `in2_features` and `out_features` as well, might be kind of BC breaking.
```python
import torch
from torch import nn
class lenet(nn.Module):
def __init__(self):
super(lenet, self).__init__()
self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
# Error, `in1_features=1, in2_features=0, out_features=0` no error
self.linear = nn.Bilinear(in1_features=0, in2_features=0, out_features=0)
def forward(self, x):
# 1st block
x = self.conv(x)
x = self.linear(x)
return x
if __name__ == '__main__':
net = lenet()
```
## Test Result
```bash
pytest test/test_nn.py -k test_bilinear -vv
```


Pull Request resolved: https://github.com/pytorch/pytorch/pull/149018
Approved by: https://github.com/mikaylagawarecki
Fixes#134106. This PR moves the `upcasted_result` down-casting after all computation is done.
Since the multiplication with the weight_opt input is not done in half precision, the current code path is doing the following: fp16 -> fp32 -> fp16 -> fp32 -> fp16. What we want tho is to avoid down-casting and this PR proposes: fp16 -> fp32 -> fp16. This results in better accuracy as it avoids truncating.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147203
Approved by: https://github.com/eqy
This PR is similar to https://github.com/pytorch/pytorch/pull/122970, but works on the softmax backward pass.
Specifically, it uses shared memory to cache the gradOutput when it can fit in shared memory. Before this PR we were reading gradOutput twice.
On my H100 this seems to improve the softmax backward pass performance by about 5% for problem sizes that fit within shared memory. (Note that this is not the only kernel that runs when you call softmax backward pass -- there is an elementwise kernel that runs before this; optimizing that can be a separate PR).
**Important Note**: Currently the softmax backward pass consists of an [element-wise multiply operator](7f65a20884/aten/src/ATen/native/cuda/SoftMax.cu (L1216)), followed by [this function](7f65a20884/aten/src/ATen/native/cuda/SoftMax.cu (L1062)) which calls the `cunn_SoftMaxBackward` kernel. With my change the kernel time reduces by about 12% (see screenshot below), while the total time (including the elementwise) reduces by about 5%.
```
Baseline This PR
N size FP32 bandwidth FP16 bandwidth N size FP32 bandwidth FP16 bandwidth fp32 diff fp16 diff
0 256 134.340966 70.042039 0 256 133.70146 70.342753 -0.48% 0.43%
1 512 233.501185 129.945803 1 512 234.057145 132.933066 0.24% 2.30%
2 1024 340.667966 229.280464 2 1024 338.833265 226.441699 -0.54% -1.24%
3 2048 379.643726 337.452058 3 2048 399.559017 338.432284 5.25% 0.29%
4 4096 416.597537 383.625364 4 4096 428.252403 396.137506 2.80% 3.26%
5 6000 431.198241 384.384384 5 6000 457.744577 406.06275 6.16% 5.64%
6 8192 462.811252 427.292573 6 8192 474.791032 428.281563 2.59% 0.23%
7 10000 464.258731 429.050294 7 10000 483.7643 446.849381 4.20% 4.15%
8 10013 465.199701 429.824179 8 10013 464.904407 428.72184 -0.06% -0.26%
9 10240 477.07359 428.853737 9 10240 485.317024 444.902586 1.73% 3.74%
10 11000 473.038785 430.778663 10 11000 488.161438 453.462162 3.20% 5.27%
11 12000 474.342475 432.594814 11 12000 490.532418 458.427653 3.41% 5.97%
12 16384 487.468854 473.611576 12 16384 488.154406 476.264631 0.14% 0.56%
13 20000 482.029793 465.666186 13 20000 482.147092 483.886193 0.02% 3.91%
14 24000 478.368093 474.159464 14 24000 478.364948 491.447921 0.00% 3.65%
15 32000 476.523796 473.18868 15 32000 476.523796 474.398962 0.00% 0.26%
16 32768 476.104723 477.493634 16 32768 476.704463 477.330606 0.13% -0.03%
17 36864 477.900663 475.472787 17 36864 477.973279 475.728454 0.02% 0.05%
18 40960 477.707561 475.559064 18 40960 478.445017 476.088067 0.15% 0.11%
19 45056 479.169812 475.865134 19 45056 479.143266 475.878202 -0.01% 0.00%
20 49152 477.804907 475.382982 20 49152 477.868404 475.976377 0.01% 0.12%
21 65536 481.274125 478.171806 21 65536 481.537733 478.703926 0.05% 0.11%
22 66000 481.64652 480.095457 22 66000 481.856013 480.466388 0.04% 0.08%
23 68608 481.745774 479.034704 23 68608 481.917596 478.856209 0.04% -0.04%
24 80000 483.409361 480.356529 24 80000 483.330481 480.375277 -0.02% 0.00%
25 98304 480.736301 481.396882 25 98304 480.789858 481.320143 0.01% -0.02%
```
NCU profiler shows lower DRAM fetches with the new kernel:

NCU reports about 12% elapsed time reduction in this kernel alone compared to baseline (and because of other kernels that are run, the overall backward pass time as seen by the user gets reduced by 5%).
I compared the binary size increase by running `python setup.py develop` before and after and diffing the .so files:

libtorch_cuda.so goes from 274,752,224 bytes to 274,787,072 bytes. The increase in size is 34kB which is about 0.01%.
I measured the compilation time for incremental development:
```
touch ./aten/src/ATen/native/cuda/SoftMax.cu
time python setup.py develop
real 0m10.083s
user 0m8.197s
sys 0m3.149s
```
Note that this uses `ccache` and does a bunch of copies and is not just measuring the `nvcc` time. I measured the `nvcc` time separately by capturing the `nvcc` command shown in [1] below and running it on the baseline and modified kernels:
```
# baseline nvcc time for SoftMax.cu
real 0m35.341s
user 0m33.801s
sys 0m1.289s
# this PR's nvcc time for SoftMax.cu
real 0m36.513s
user 0m34.722s
sys 0m1.408s
```
So the `nvcc` time increases by about 1 second, or ~3% of the baseline.
[1] `nvcc` command is here:
```
# This is the nvcc command
/usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. -I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/torch/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/torch/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/SoftMax.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/SoftMax.cu.o
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145866
Approved by: https://github.com/ngimel
Adds feature for #98925
Tests pass for both existing reflectionpad2d and the new one I inserted.
**Summary of the work:**
Simple conditional check for deterministic mode that will dispatch to a different kernel. This kernel does not use any atomic operations, and will lead to deterministic results as instead of going from the output to input(1:1) relationship, I am doing the opposite. I am going from input -> all outputs, which is 1 to many. These operations are done in the same order every execution as I simply traverse the data set with a grid stride loop and use simple linearized indexing into the input tensor.
So each thread will compute the 4 conditionals, which are then used to see if the input has an output in the 8 regions. These 8 regions are top left, top, top right, left, right, bottom left, bottom, bottom right`.
I did not focus on performance for this PR as that would expand the scope heavily. If there are any performance questions though i can answer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136241
Approved by: https://github.com/eqy, https://github.com/albanD
Triton 2.2 and greater have a bug where allowing TF32 generation for a GPU that does not support TF32 will cause code generation errors. Patch around this problem by:
1. Adding a function to `torch.cuda` that determines whether CUDA hardware is capable of using the TF32 format.
2. Using that function to explicitly disable TF32 generation when calling Triton, where needed.
To demonstrate that this fix works, try running `test/inductor/test_max_autotune.py` on a GPU with CUDA compute capability < 8 (e.g. any NVIDIA consumer GPU) without this fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145684
Approved by: https://github.com/eqy
Fixes#136291
This PR is to fix the `invalid configuration argument` problem happened on ROCm when input is a large tensor when calling `torch.layer_norm`.
```
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/functional.py", line 2573, in layer_norm
return torch.layer_norm
RuntimeError: HIP error: invalid configuration argument
```
After investigation, I found that the reason why this error happened is: The amd compute language runtime checks whether `gridDim.x * blockDim.x` is greater than `std::numeric_limits<uint32_t>::max()` or not. If yes, it will error out with the "invalid configuration argument" message.
The fix is to split the whole task to several chunks so that each chunk will not trigger the failure condition. This will ensure the correctness and completeness given the current kernel implementation logic of `vectorized_layer_norm_kernel`.
Also added a largeTensor layer_norm unit test `test_layer_norm_large_tensor` with the same shape `[16, 3000, 3000, 16]` as the one used by the pytorch issue #136291 so that the unit test can check the expected output value to ensure correctness.
The future work may include performance optimization of layer_norm and CK layer_norm integration.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144007
Approved by: https://github.com/eqy
Fixes#111824
Currently it is the case that if the user specifies their group normalization to be of NHWC format, pytorch will default to NCHW tensors and convert. This conversion is not immediately obvious to the user unless they check the format themselves which is not intuitive. This PR adds suppor for NHWC for cuda by adding necessary kernels.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126635
Approved by: https://github.com/eqy, https://github.com/mikaylagawarecki
#### Summary
This pull request introduces new weighted loss functions to the PyTorch library: `weighted_huber_loss`, `wmse_loss`, and `wmae_loss`. These functions allow for precise control over the influence of each sample during training, important for imbalanced data or when certain samples are more significant than others.
#### Changes
- **`weighted_huber_loss`**: Huber loss modified to incorporate weights, providing a balance between L1 and L2 loss based on the `delta` parameter.
- **`wmse_loss`** (Weighted Mean Squared Error): Applies weights to the standard MSE loss, useful for emphasizing certain samples in regression tasks.
- **`wmae_loss`** (Weighted Mean Absolute Error): Adjusts MAE loss calculation by including weights, ideal for datasets with outliers.
#### Code Details
- **Input Validation**: Ensures `input`, `target`, and `weights` tensors match in size to prevent broadcasting errors.
- **Reduction Options**: Supports `none`, `mean`, and `sum` reductions to suit various computational needs.
- **Backward Compatibility**: Maintains support for deprecated arguments `size_average` and `reduce`, while encouraging use of the `reduction` argument.
#### Usage Example
```python
import torch
input = torch.tensor([0.5, 2.5, 2.0], dtype=torch.float32)
target = torch.tensor([0.0, 2.0, 1.5], dtype=torch.float32)
weights = torch.tensor([1.0, 0.5, 1.5], dtype=torch.float32)
loss = weighted_huber_loss(input, target, weights, delta=1.0)
print(loss)
```
---
Feedback on these implementations is welcome; please let me know if further modifications are required.
Resolves#132465
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132049
Approved by: https://github.com/mikaylagawarecki
Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
More or less literal copy-n-paste of c33b0580e6/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu (L24)
and
c33b0580e6/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu (L99)
Missing `uint8` implementation mimics CUDA behavior
Initial version coded live in https://www.youtube.com/watch?v=shi6Kb5xxvk
Later refinements:
- Switch from 2D dispatch to 1D one (to match CUDA behavior)
- Added batch + channel loops
- Fixed scale computation to match align corners behavior
- Added backward implementation
Backward implementation again, mimics CUDA, so it has issues precision issue for `torch.half` as well as a somewhat slow simulation of atomic adds using atomic compare and exchange of the pair of adjacent values, i.e.
```metal
emplate <typename T>
static inline void atomic_add_helper(
device atomic<int>* data,
long offset,
float value) {
auto ptr = data + (offset >> 1);
auto old = atomic_load_explicit(ptr, memory_order_relaxed);
union {
int i;
T t[2];
} val;
do {
val.i = old;
val.t[offset & 1] += static_cast<T>(value);
} while (!atomic_compare_exchange_weak_explicit(
ptr, &old, val.i, memory_order_relaxed, memory_order_relaxed));
}
```
Bump basic Metal language version to 3.0, as it's supported on MacOS13 and that's the first version that has `atomic_float`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136123
Approved by: https://github.com/albanD
Notable changes:
1. Enable CudaGraph related tests
2. Fix UT problems
3. EXPERIMENTAL Navi31 support. User should enable Navi31 support with Env Var `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
Know Problem:
1. `test/test_transformers.py` will massive failures and/or NaN outputs with `--use-pytest`
+ Update: Confirmed skip `class TestSDPAPrivateUse1Only` can fix the problem with `--use-pytest`
Note:
AOTriton 0.7b adds support to nestedtenosrs+SDPA but need more work (and consequently a separate PR) to enable it.
Fixes#133540
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet