This new assertion helper bundles a printf call with the assertion. The goal is to make changes to instrument asserts with device-side information more intuitive and less error-prone. (See the printf call in ATen/native/cuda/Repeat.cu.) Parametrized error messages are a substantial improvement in debuggability because they show the mismatched device-side values. This lets us avoid a whole cycle of rebuilding + re-running failing training workflows.
We include file, line number, function, and failing condition in the printf (along with the message provided by the user). The format matches the format of the message output by `__assert_fail`. There's also an easy-to-grep-for keyword `CUDA_KERNEL_ASSERT` in the message.
I'm following the existing patterns of arch-specific macros - e.g., on ROCm, this is just a call to abort(), just like the other `CUDA_KERNEL_ASSERT*` variations. I'd appreciate any thoughts on architecture-specific testing (most likely on the OSS side).
# Alternatives
* We could just update `CUDA_KERNEL_ASSERT_MSG`. That would mean introducing `printf` calls from the kernel where there weren't any before, though. This seems like a bad idea because of the performance sensitivity.
* If we want to move more slowly here, I could instrument more `CUDA_KERNEL_ASSERT` callsites without a macro, similar to https://github.com/pytorch/pytorch/pull/157996. But the main downside here is the performance hit, so let's have an organized way of doing it first.
# Risks/Problems
* We're shoving a lot of stuff into this printf. If a filename (at compile-time) contains `%s`, we will end up dereferencing whatever value was pushed in. On a CPU this can cause a segfault. I don't know how it behaves on a GPU.
* Adding printf calls can have a performance impact because of increased register and stack usage. I did not see this play out in practice (see "benchmarks" below). However, there are changes to the generated PTX that could result in performance problems later (see "changes in generated PTX" below).
# Benchmarks
* I ran the following benchmarks a several times on a host with an A100: https://gist.github.com/mjkatmeta/e5494d949204a2afe2d43c452b99424f
* Results are here -- I couldn't find a significant difference before or after https://gist.github.com/mjkatmeta/0f99ec27bb91214fb2cc7f612938d431
# Change in generated PTX
This is the easiest way I found to run nvcc over just Repeat.cu (this is a buck2 target that includes just a copy of Repeat.cu):
```
buck2 build --show-output scripts/mjk/ai_training/cuda_benchmarks:repeat_cuda
# then use the printed .so file like this:
~/fbsource/third-party/cuda/cuda_12.8.0/x64-linux/bin/cuobjdump -ptx ../buck-out/v2/gen/fbcode/028bde1acfaba823/scripts/mjk/ai_training/cuda_benchmarks/__repeat_cuda__/libscripts_mjk_ai_training_cuda_benchmarks_repeat_cuda.so
```
## with printf
This is the version of the code that appears in this diff:
https://gist.github.com/mjkatmeta/5d18d48282d46b2240d946b335052b9a
## without printf
I recompiled, replacing `CUDA_KERNEL_ASSERT_PRINTF(...)` in Repeat.cu with:
```
CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]);
```
https://gist.github.com/mjkatmeta/480df4b3a122e7b326554dd15ebb7c9d
(Both of these are annotated with `// CHAR ARRAY:` comments to make the string constants easier to read.)
Test Plan:
Running this minimal test case:
```
import torch
def main():
x = torch.ones(10, dtype=torch.int64, device="cuda:0")
torch.repeat_interleave(x, x, output_size=0)
```
Now we see the new message (from printf) alongside the assert failure:
```
$ buck2 run fbcode//scripts/darshanr/repeat_interleave_errors:repeat_interleave_errors
[...]
[CUDA_KERNEL_ASSERT] fbcode/caffe2/aten/src/ATen/native/cuda/Repeat.cu:25: compute_cuda_kernel: block: [0,0,0], thread: [31,0,0]: Assertion failed: `result_size == cumsum_ptr[size - 1]`: Invalid input! In `repeat_interleave`, the `output_size` argument (0) must be the same as the sum of the elements in the `repeats` tensor (10).
fbcode/caffe2/aten/src/ATen/native/cuda/Repeat.cu:25: compute_cuda_kernel: block: [0,0,0], thread: [384,0,0] Assertion `result_size == cumsum_ptr[size - 1]` failed.
[...[
```
Rollback Plan:
Reviewed By: mradmila
Differential Revision: D79310684
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160129
Approved by: https://github.com/ngimel
This PR is a big copy pasta from `c10/util/Float8*` -> `torch/headeronly/util/` which is why we are breaking PR sanity :C (sorry @albanD!).
Why is it not a clean copy paste?
- For BC reasons, we have to keep the old c10 file around so that OSS devs relying on those files can still get the same APIs
- Because we reexpose APIs that are headeronly through torch::headeronly, so there is an extra chunk of code in the new torch::headeronly files to do that.
Outside of the copy paste, I:
- changed the tests to call torch::headeronly instead of c10
- updated header_only_apis.txt
- added `// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)` to pass lint (which was previously skipped for -inl.h files)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159415
Approved by: https://github.com/albanD
Essence of this copypasta:
- combine Half-inl.h and Half.h in c10/util -> torch/headeronly/util/Half.h
- Add NOLINTNEXTLINE's to the portions of Half-inl.h that were previously in the ignore list of clangtidy
- Re-expose all APIs in namespaces and through includes of the original files. Ideally, we would have the APIs in torch::headeronly and reexpose them in c10, but that runs into BC issues (see D78997465) so for now we are keeping the APIs in c10 but reexposing them in torch::headeronly.
- Change test cases in test_aoti_abi_check to test torch::headeronly::Half vs c10::Half (they're the same thing but we eventually want all the tests for headeronly APIs to only import from headeronly).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159172
Approved by: https://github.com/albanD, https://github.com/desertfire
Essence of this copypasta:
- combine Half-inl.h and Half.h in c10/util -> torch/headeronly/util/Half.h
- Add NOLINTNEXTLINE's to the portions of Half-inl.h that were previously in the ignore list of clangtidy
- Re-expose all APIs in namespaces and through includes of the original files. Ideally, we would have the APIs in torch::headeronly and reexpose them in c10, but that runs into BC issues (see D78997465) so for now we are keeping the APIs in c10 but reexposing them in torch::headeronly.
- Change test cases in test_aoti_abi_check to test torch::headeronly::Half vs c10::Half (they're the same thing but we eventually want all the tests for headeronly APIs to only import from headeronly).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159172
Approved by: https://github.com/albanD, https://github.com/desertfire
Straightup copy pasta. Keeps APIs in c10 and reexposes them to torch::headeronly.
It is arguable that we should just get rid of some of these unused dtypes but that is outside the scope of this PR, which is meant to build up to ScalarType moving to headeronly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159302
Approved by: https://github.com/malfet, https://github.com/albanD
Summary: __assert_fail is declared slightly differently in the Emscripten stdlib. This may cause errors when compiling with Emscripten.
Test Plan:
N/A
Rollback Plan:
Differential Revision: D78500790
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158580
Approved by: https://github.com/JacobSzwejbka
# Background
The `C10_WARP_SIZE`, although always be `32` on CUDA platform, varies across different AMD GPUs.
Therefore, to correctly refer this value, the host code must be a variable instead of a literal defined by macro, or a `constexpr int`.
This PR may cause more compiler errors for third party code on AMD GPU, which is intentional. Having a fixed `C10_WARP_SIZE` value on host code for AMD GPU only defers compile time error to runtime.
This PR is recommended to be included as part of Release Notes to describe an API change for whoever uses this macro.
Users are recommended to use `C10_WARP_SIZE` directly, which adapts for various scenarios, or define a macro to use `C10_WARP_SIZE`. Assignment of this macro to symbols shared by host/device code causes problems on ROCM platform. (See the fix at `aten/src/ATen/native/cuda/layer_norm_kernel.cu` for a concrete example)
# Behaviors
* If compiling with HIPCC (i.e `defined(__HIPCC__)`):
+ Define `C10_WARP_SIZE` to be non-`constexpr` `at::cuda::warp_size()` for host-compilation pass (as compared to `static constexpr int C10_WARP_SIZE = 1;` set in 04bd7e6850e8efec77994963ffee87549555b9c3)
+ Define `C10_WARP_SIZE` to be a function returning `constexpr int` `64` for `__GFX9__`, and `32` otherwise, for device-compilation pass
- `__GFX8__` is also 64 but we do not support any GFX8 GPU.
* If not compiling with HIPCC:
+ Define `C10_WARP_SIZE` to be non-constexpr `at::cuda::warp_size()`
# `constexpr` variant for host code
For host-compilation cases where a `constexpr` value is needed for warp size (eg. launch bounds), use `C10_WARP_SIZE_STATIC`, which is defined as `64`. This macro follows the pre 04bd7e6850e8efec77994963ffee87549555b9c3 behavior of `C10_WARP_SIZE`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158271
Approved by: https://github.com/jeffdaily
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
Summary: As above, also changes a bunch of the build files to be better
Test Plan:
internal and external CI
did run buck2 build fbcode//caffe2:torch and it succeeded
Rollback Plan:
Reviewed By: swolchok
Differential Revision: D78016591
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158035
Approved by: https://github.com/swolchok