Compare commits

7 Commits

Author SHA1 Message Date
5811a8d7da [cuDNN][SDPA][Convolution] Expose cuDNN runtime version in CUDA hooks (#167327)
[cuDNN][SDPA][Convolution] Expose cuDNN runtime version in CUDA hooks (#167111)

cuDNN dispatching heuristics rely on version checks, but currently only the compile-time version is exposed. If we want to allow users to resolve https://github.com/pytorch/pytorch/issues/166643 on their end by updating their cuDNN version locally, we need to check the runtime version rather than the compile-time version.
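For reference, here is a minimal standalone sketch of the compile-time vs. runtime distinction this change exposes (assuming the cuDNN headers and library are available; this is illustrative only, not the hook implementation in the diff below):

```cpp
// Illustrative sketch: compile-time vs. runtime cuDNN version.
#include <cudnn.h>
#include <cstdio>

int main() {
  long compiled = CUDNN_VERSION;                         // baked in by the headers at build time
  long runtime  = static_cast<long>(cudnnGetVersion());  // reported by the cuDNN library loaded at runtime
  // The new versionRuntimeCuDNN() hook surfaces the latter, so dispatch heuristics
  // can react to a locally upgraded cuDNN without rebuilding PyTorch.
  std::printf("compiled against %ld, running with %ld\n", compiled, runtime);
  return 0;
}
```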

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167111
Approved by: https://github.com/Skylion007

(cherry picked from commit e678450a69f6bf3b6f3ea7657d444ce9bba19940)

Co-authored-by: Eddie Yan <eddiey@nvidia.com>
2025-11-07 17:04:27 -05:00
f36c764ca4 [dynamo][ez] Initialize tracer_output to None by default. (#167366)
[dynamo][ez] Initialize tracer_output to None by default. (#163169)

Summary:
In edge cases, `tracer_output` can be left unset if a double exception is raised, which causes the following issue:
```
UnboundLocalError: local variable 'tracer_output' referenced before assignment
```

Default-initialize this variable so that it is always present.

Test Plan:
CI

Rollback Plan:

Differential Revision: D82652815

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163169
Approved by: https://github.com/tugsbayasgalan

(cherry picked from commit 6189a5f7315ac5affdaeafdbea0a85d14925506d)

Co-authored-by: Zhengxu Chen <zhxchen17@meta.com>
2025-11-07 17:03:33 -05:00
6877288115 Change forkserver test to only run below 3.13.8 (#167361)
Change forkserver test to only run below 3.13.8 (#165667)

A multiprocessing bug is fixed in 3.13.8, see [https://docs.python.org/3.13/whatsnew/changelog.html](https://docs.python.org/3.13/whatsnew/changelog.html), [gh-126631](https://github.com/python/cpython/issues/126631)

So this test will fail when we update to Python 3.13.8.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165667
Approved by: https://github.com/malfet

(cherry picked from commit d4a713cd9c8ea1dc13917d3311d73c13914306a6)

Co-authored-by: Shangdi Yu <shangdiy@meta.com>
2025-11-07 16:58:33 -05:00
9976b77abb Cherry-pick LibTorch Stable ABI documentation (#167112 #166661 #163899) (#167323)
* [BE] Refresh documentation for stable ABI / API (#163899)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163899
Approved by: https://github.com/janeyx99

* Document LibTorch ABI more, add README to headeronly (#166661)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166661
Approved by: https://github.com/mikaylagawarecki, https://github.com/albanD

* Add guidance on how to migrate kernels to the libtorch stable ABI (#167112)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167112
Approved by: https://github.com/janeyx99

---------

Co-authored-by: Jane Xu <janeyx@meta.com>
2025-11-07 11:35:34 -05:00
e6bcbbe17c [Inductor] No longer throw error in bmm out_dtype lowering due to tem… (#166922)
[Inductor] No longer throw error in bmm out_dtype lowering due to template heuristics (#166457)

Fixes https://github.com/pytorch/pytorch/issues/165892

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166457
Approved by: https://github.com/coconutruben

(cherry picked from commit c2e3cc7aedb2e7d89443225c7cccd08a0f8a3587)

Co-authored-by: PaulZhang12 <paulzhan@fb.com>
2025-11-07 11:30:59 -05:00
8f658d7599 don't produce invalid grid configs (#166973) (#167158)
Proper fix for #164048, fixes gather too, reverts #164049
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166974
Approved by: https://github.com/eqy
2025-11-06 12:36:55 -05:00
3d27d955fd [GraphPartition] cache get_free_symbol_uses (#166338) (#166994)
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs.
ee7434be82/torch/_inductor/scheduler.py (L4869-L4885)

I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried an aten fallback for torchtitan, which results in 10k+ aten nodes. When processing the 600th node, a single call to `get_free_symbol_uses()` takes seconds.

Why? Because `get_free_symbol_uses()` may recursively call another node's `get_free_symbol_uses()`, so the same computation can run many times.
ee7434be82/torch/_inductor/ir.py (L4541-L4543)

This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166338
Approved by: https://github.com/eellison

(cherry picked from commit dfebdcab86acbaa0eaa996b47595e5f27a66492e)

Co-authored-by: Boyuan Feng <boyuan@meta.com>
2025-11-06 11:39:46 -05:00
18 changed files with 459 additions and 34 deletions

View File

@ -155,6 +155,12 @@ class TORCH_API Context {
static long versionCuDNN() {
return detail::getCUDAHooks().versionCuDNN();
}
static long versionRuntimeCuDNN() {
return detail::getCUDAHooks().versionRuntimeCuDNN();
}
static long versionCuDNNFrontend() {
return detail::getCUDAHooks().versionCuDNNFrontend();
}
static bool hasCuSOLVER() {
return detail::getCUDAHooks().hasCuSOLVER();
}

View File

@ -21,6 +21,7 @@
#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
#include <cudnn_frontend.h>
#endif
#if AT_MAGMA_ENABLED()
@ -325,6 +326,26 @@ long CUDAHooks::versionCuDNN() const {
#endif
}
long CUDAHooks::versionRuntimeCuDNN() const {
#if AT_CUDNN_ENABLED()
#ifndef USE_STATIC_CUDNN
return cudnnGetVersion();
#else
return CUDNN_VERSION;
#endif
#else
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionCuDNNFrontend() const {
#if AT_CUDNN_ENABLED()
return CUDNN_FRONTEND_VERSION;
#else
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionMIOpen() const {
#if AT_ROCM_ENABLED()
return MIOPEN_VERSION_MAJOR * 10000 +

View File

@ -48,6 +48,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCUDART() const override;
long versionCUDART() const override;
long versionCuDNN() const override;
long versionRuntimeCuDNN() const override;
long versionCuDNNFrontend() const override;
long versionMIOpen() const override;
std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override;

View File

@ -170,6 +170,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionRuntimeCuDNN() const {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionCuDNNFrontend() const {
TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionMIOpen() const {
TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
}

View File

@ -413,7 +413,7 @@ struct ConvParams {
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
return false;
}
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
// broken on cuDNN 9.8 - 9.14
if (cudnn_version >= 90800 && cudnn_version < 91500) {
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
@ -457,7 +457,7 @@ struct ConvParams {
}
// native kernel doesn't support 64-bit non-splittable case
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
if (cudnn_version < 0 || cudnn_version > 91000) {

View File

@ -73,7 +73,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
if (is_gather_like && num_indices==1) {
const size_t element_size = iter.element_size(0);
constexpr size_t alignment = 16;
@ -83,11 +82,10 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
auto ind_dim_size = index_size[0];
auto inp_stride_bytes = index_stride[0];
auto out_stride_bytes = iter.strides(0)[1];
if (iter.numel() == 0) return;
at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
return;
}
}
}
auto sizes = std::array<int64_t, MAX_DIMS>{};

View File

@ -14,10 +14,11 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
ind = (ind < 0) ? ind + ind_dim_size : ind;
}
CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
if (off >= slice_size) return;
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
// off is guaranteed to be within int32 limits
for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) {
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
}
}
@ -30,7 +31,9 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int
auto num_threads = at::round_up(
at::ceil_div(slice_size_in_bytes, Alignment),
static_cast<int64_t>(C10_WARP_SIZE));
dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
grid_y = std::min(static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y);
dim3 grid = {static_cast<uint32_t>(num_ind), grid_y, 1};
auto block = std::min(max_num_threads, num_threads);
vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);

View File

@ -437,7 +437,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
const auto s_k = params.key.sym_size(2);
const auto d_qk = params.query.sym_size(3);
const auto d_v = params.value.sym_size(3);
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
long cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
if (cudnn_version < 8903) {
if (debug) {
TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");
@ -668,7 +668,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
return false;
#endif
#if defined(CUDNN_VERSION)
static auto cudnn_version = cudnnGetVersion();
static auto cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
if (params.dropout > 0.0 && cudnn_version > 91100 && cudnn_version < 91400) {
if (debug) {
TORCH_WARN(CUDNN_VERSION, " cuDNN version does not support droppout in SDPA (9.11 - 9.13).");

View File

@ -1,6 +1,173 @@
# LibTorch Stable ABI
This note will eventually contain more details on how to use the APIs in torch/csrc/stable. For the moment, it contains a table of internal representations:
## Overview
The LibTorch Stable ABI (Application Binary Interface) provides a limited interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases. This limited set of APIs is not intended to replace existing LibTorch, but rather to provide a stable foundation for a majority of custom extension use cases. If there is any API you would like to see added to the stable ABI, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues).
The limited stable ABI consists of three main components:
1. **Stable C headers** - Low-level C API implemented by libtorch (primarily `torch/csrc/inductor/aoti_torch/c/shim.h`)
2. **Header-only C++ library** - Standalone utilities implemented in only headers such that there is no dependence on libtorch (`torch/headeronly/*`)
3. **Stable C++ wrappers** - High-level C++ convenience wrappers (`torch/csrc/stable/*`)
We discuss each of these in detail below.
### `torch/headeronly`
The inlined C++ headers living in [`torch/headeronly`](https://github.com/pytorch/pytorch/tree/main/torch/headeronly) are completely decoupled from LibTorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`, and a libtorch-independent version of `TORCH_CHECK` is available as `STD_TORCH_CHECK`. You can trust all APIs in the `torch::headeronly` namespace to not depend on `libtorch.so`. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt).
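As a minimal sketch of what this buys you (the header paths are the ones used in the migration example later in this note; illustrative only), the following compiles without linking against `libtorch.so`:

```cpp
// Illustrative sketch: libtorch-free use of torch/headeronly utilities.
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

void check_is_float(torch::headeronly::ScalarType dtype) {
  // STD_TORCH_CHECK throws std::runtime_error instead of c10::Error,
  // so it introduces no dependence on libtorch.so.
  STD_TORCH_CHECK(dtype == torch::headeronly::kFloat, "expected float32 dtype");
}
```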
### `torch/csrc/stable`
This is a set of inlined C++ headers that provide wrappers around the C API that handle the rough edges
discussed below.
It consists of
- torch/csrc/stable/library.h: Provides a stable version of TORCH_LIBRARY and similar macros.
- torch/csrc/stable/tensor_struct.h: Provides torch::stable::Tensor, a stable version of at::Tensor.
- torch/csrc/stable/ops.h: Provides a stable interface for calling ATen ops from `native_functions.yaml`.
- torch/csrc/stable/accelerator.h: Provides a stable interface for device-generic objects and APIs
(e.g. `getCurrentStream`, `DeviceGuard`).
We are continuing to improve coverage in our `torch/csrc/stable` APIs. Please file an issue if you'd like to see support for particular APIs in your custom extension.
### Stable C headers
The stable C headers started by AOTInductor form the foundation of the stable ABI. Presently, the available C headers include:
- [torch/csrc/inductor/aoti_torch/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/c/shim.h): Includes C-style shim APIs for commonly used functionality regarding Tensors, dtypes, CUDA, and the like.
- [torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h): Includes C-style shim APIs for ATen ops from `native_functions.yaml` (e.g. `aoti_torch_aten_new_empty`).
- [torch/csrc/inductor/aoti_torch/generated/c_shim_*.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated): Includes C-style shim APIs for specific backend kernels dispatched from `native_functions.yaml` (e.g. `aoti_torch_cuda_pad`). These APIs should only be used for the specific backend they are named after (e.g. `aoti_torch_cuda_pad` should only be used within CUDA kernels), as they opt out of the dispatcher.
- [torch/csrc/stable/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/stable/c/shim.h): We are building out more ABIs to logically live in `torch/csrc/stable/c` instead of continuing the AOTI naming that no longer makes sense for our general use case.
These headers are promised to be ABI stable across releases and adhere to a stronger backwards compatibility policy than LibTorch. Specifically, we promise not to modify them for at least 2 years after they are released. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs. Further, the stack-based APIs discussed below, which allow the user to call into the PyTorch dispatcher, do not provide strong guarantees on the forward and backward compatibility of the underlying op that is called.
Unless using the C headers directly is absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable`, which handles the rough edges of the C API for the user.
## Migrating your kernel to the LibTorch stable ABI
If you'd like your kernel to be ABI stable with LibTorch, meaning you'd like the ability to build against one version and run on another, your kernel must only use the limited stable ABI. The following section goes through the steps of migrating an existing kernel and the APIs we imagine you would need to swap over.
Firstly, instead of registering kernels through `TORCH_LIBRARY`, LibTorch ABI-stable kernels must be registered via `STABLE_TORCH_LIBRARY`. Note that, for the time being, implementations registered via `STABLE_TORCH_LIBRARY` must be boxed, unlike `TORCH_LIBRARY`. See the simple example below or our docs on [Stack-based APIs](stack-based-apis) for more details. For kernels that are currently registered via `pybind`, it would be useful to first migrate them to `TORCH_LIBRARY` before adopting the stable ABI.
While previously your kernels might have included APIs from `<torch/*.h>` (for example, `<torch/all.h>`), they are now limited to including headers from the 3 categories mentioned above (`torch/csrc/stable/*.h`, `torch/headeronly/*.h` and the stable C headers). This means that your extension should no longer use any utilities from the `at::` or `c10::` namespaces but instead use their replacements in `torch::stable` and `torch::headeronly`. To provide a couple of examples of the necessary migrations:
- all uses of `at::Tensor` must be replaced with `torch::stable::Tensor`
- all uses of `TORCH_CHECK` must be replaced with `STD_TORCH_CHECK`
- all uses of `at::kCUDA` must be replaced with `torch::headeronly::kCUDA` etc.
- native functions such as `at::pad` must be replaced with `torch::stable::pad`
- native functions that are called as Tensor methods (e.g., `Tensor.pad`) must be replaced with the ATen variant through `torch::stable::pad`.
As mentioned above, the LibTorch stable ABI is still under development. If there is any API or feature you would like to see added to the stable ABI/`torch::headeronly`/`torch::stable`, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues).
Below is a simple example of migrating an existing kernel that uses `TORCH_LIBRARY` to the stable ABI (`STABLE_TORCH_LIBRARY`). For a larger end-to-end example, you can take a look at the FA3 repository; specifically, the diff between [`flash_api.cpp`](https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api.cpp#L1) and the stable variant [`flash_api_stable.cpp`](https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api_stable.cpp#L1).
### Original Version with `TORCH_LIBRARY`
```cpp
// original_kernel.cpp - Using TORCH_LIBRARY (not stable ABI)
#include <torch/torch.h>
#include <ATen/ATen.h>
namespace myops {
// Simple kernel that adds a scalar value to each element of a tensor
at::Tensor add_scalar(const at::Tensor& input, double scalar) {
TORCH_CHECK(input.scalar_type() == at::kFloat, "Input must be float32");
return input.add(scalar);
}
// Register the operator
TORCH_LIBRARY(myops, m) {
m.def("add_scalar(Tensor input, float scalar) -> Tensor", &add_scalar);
}
// Register the implementation
TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
m.impl("add_scalar", &add_scalar);
}
} // namespace myops
```
### Migrated Version with `STABLE_TORCH_LIBRARY`
```cpp
// stable_kernel.cpp - Using STABLE_TORCH_LIBRARY (stable ABI)
// (1) Don't include <torch/torch.h> <ATen/ATen.h>
// only include APIs from torch/csrc/stable, torch/headeronly and C-shims
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor_struct.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
namespace myops {
// Simple kernel that adds a scalar value to each element of a tensor
torch::stable::Tensor add_scalar(const torch::stable::Tensor& input, double scalar) {
// (2) use STD_TORCH_CHECK instead of TORCH_CHECK
STD_TORCH_CHECK(
// (3) use torch::headeronly::kFloat instead of at::kFloat
input.scalar_type() == torch::headeronly::kFloat,
"Input must be float32");
// (4) Use stable ops namespace instead of input.add
return torch::stable::add(input, scalar);
}
// (5) Add Boxed wrapper required for STABLE_TORCH_LIBRARY
void boxed_add_scalar(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
// Extract arguments from stack using `to<T>`
auto input = to<torch::stable::Tensor>(stack[0]);
auto scalar = to<double>(stack[1]);
// Call the actual kernel
auto result = add_scalar(input, scalar);
// Put result back on stack using `from()`
// Stack slot 0 now holds the return value
stack[0] = from(result);
}
// (6) Register the operator using STABLE_TORCH_LIBRARY
STABLE_TORCH_LIBRARY(myops, m) {
m.def("add_scalar(Tensor input, float scalar) -> Tensor", &boxed_add_scalar);
}
// (7) Register the implementation using STABLE_TORCH_LIBRARY_IMPL
STABLE_TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
m.impl("add_scalar", &boxed_add_scalar);
}
} // namespace myops
```
## How are objects passed across the ABI boundary when interacting with the dispatcher?
When interacting with the dispatcher via the stable APIs (``STABLE_TORCH_LIBRARY`` etc.) we use a boxed convention. Arguments and returns are represented as a stack of ``StableIValue`` which correlates with a `torch::jit::stack` of IValues. We discuss the following below:
1. StableIValue Conversions
2. StableIValue stack Conventions
3. Stable APIs that interact with the dispatcher
### StableIValue Conversions
We provide utilities for users to convert objects to and from StableIValues with the synonymous
`to` and `from` APIs in `torch/csrc/stable/stableivalue_conversions.h`. We document the stable custom extension representation, the libtorch representation, and the StableIValue
representation below. Our confidently supported types are the ones in the table that have completed
rows. You can rely on this subset for proper ABI stability, meaning that you can call `to<T_custom_ext>(arg/ret)` or `from(T)` on these types.
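A minimal sketch of these conversions for confidently supported types (the headers and the `from`/`to` helpers are the ones used in the migration example above; illustrative only):

```cpp
// Illustrative sketch: round-tripping values through StableIValue with from/to.
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/tensor_struct.h>

void pack_and_unpack(const torch::stable::Tensor& t, double scalar, StableIValue* stack) {
  // custom-extension type -> ABI-stable StableIValue
  stack[0] = from(t);
  stack[1] = from(scalar);
  // StableIValue -> custom-extension type on the other side of the boundary
  auto t_back = to<torch::stable::Tensor>(stack[0]);
  auto s_back = to<double>(stack[1]);
  (void)t_back;
  (void)s_back;
}
```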
For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only.
You can always work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions by not introspecting into the StableIValue. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`.
1. type in custom extension: type used within the end user custom library.
2. StableIValue representation: a stable conversion of the type used to liaise between the user's custom extension and libtorch.so in an ABI-stable manner.
3. type in libtorch: type used within libtorch.so (or any code binary locked with libtorch).
@ -31,16 +198,10 @@ This note will eventually contain more details on how to use the APIs in torch/c
| ? | ? | c10::SymBool | SymBool |
| ? | ? | at::QScheme | QScheme |
Our confidently supported types are the ones in the table that have completed rows. You can rely on this subset for proper ABI stability.
For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only.
### Stack Conventions
You can always work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions by not introspecting into the StableIValue. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`.
## How to use stack-based APIs
`aoti_torch_call_dispatcher` is what we consider a stack-based API because it takes as input a stack of StableIValues, which correlates with a `torch::jit::stack` of IValues. Working with the dispatcher will likely bring you into proximity with stack-based APIs, so we are documenting some invariants:
There are two invariants for the stack:
1. The stack is populated left to right.
a. For example, a stack representing arguments `arg0`, `arg1`, and `arg2` will have `arg0` at index 0, `arg1` at index 1, and `arg2` at index 2.
@ -49,3 +210,33 @@ You can always work with StableIValue abstractions in your custom kernel for typ
2. The stack always has ownership of the objects it holds.
a. When calling a stack-based API, you must give owning references to the calling stack and steal references from the returned stack.
b. When registering your function to be called with a stack, you must steal references from your argument stack and push onto the stack new references.
(stack-based-apis)=
### Stack-based APIs
The above is relevant in two places:
1. `STABLE_TORCH_LIBRARY`
Unlike `TORCH_LIBRARY`, the dispatcher expects kernels registered via `STABLE_TORCH_LIBRARY` to be boxed. This means they must have the signature `(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) -> void`. We plan to eventually abstract away the need for manual boxing, but, for the time being, please use `from` and `to`.
```cpp
Tensor my_amax_vec(Tensor t) {
std::vector<int64_t> v = {0,1};
return amax(t, v, false);
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax_vec(to<Tensor>(stack[0]));
stack[0] = from(res);
}
```
2. `aoti_torch_call_dispatcher`
This API allows you to call the PyTorch dispatcher from C/C++ code. It has the following signature:
```cpp
aoti_torch_call_dispatcher(const char* opName, const char* overloadName, StableIValue* stack);
```
`aoti_torch_call_dispatcher` will call the op overload defined by a given `opName`, `overloadName`, and a stack of
StableIValues. This call will populate any return values of the op into the stack in their StableIValue form,
with `ret0` at index 0, `ret1` at index 1, and so on.
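As a minimal sketch of this convention (assuming `aten::abs`, which takes one Tensor argument and returns one Tensor; the empty overload name and the omitted error handling are simplifications):

```cpp
// Illustrative sketch: calling aten::abs through the dispatcher with a StableIValue stack.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/tensor_struct.h>

torch::stable::Tensor stable_abs(const torch::stable::Tensor& input) {
  StableIValue stack[1];
  stack[0] = from(input);  // arguments populated left to right; the stack owns this reference
  aoti_torch_call_dispatcher("aten::abs", "", stack);  // "" assumed to select the default overload
  return to<torch::stable::Tensor>(stack[0]);  // ret0 is written back at index 0; we steal it here
}
```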

View File

@ -1479,6 +1479,29 @@ class TestMaxAutotune(TestCase):
# Check that contiguous transform was used
FileCheck().check("contiguous_mm").run(code[0])
@unittest.skipIf(config.cpp_wrapper, "out_dtype override not supported for AOTI")
@unittest.skipIf(TEST_WITH_ROCM, "out_dtype override only available on NVIDIA")
def test_bmm_out_dtype(self):
def f(a, b):
return torch.bmm(a, b, out_dtype=torch.float32)
a = torch.randn(2, 3, 4, device=GPU_TYPE, dtype=torch.float16)
b = torch.randn(2, 4, 5, device=GPU_TYPE, dtype=torch.float16)
with config.patch(
max_autotune=True,
max_autotune_gemm_backends="TRITON",
):
compiled_f = torch.compile(f)
with self.assertRaisesRegex(
torch._inductor.exc.InductorError,
r"LoweringException: NoValidChoicesError: No choices to select",
):
out, code = run_and_get_code(compiled_f, a, b)
compiled_f = torch.compile(f)
out, code = run_and_get_code(compiled_f, a, b)
FileCheck().check("extern_kernels.bmm_dtype").run(code[0])
def test_triton_template_generated_code_cache_key(self):
generate_and_load_args = len(
inspect.signature(

View File

@ -265,6 +265,12 @@ class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
)
class ParallelForkServerPerfTest(TestCase):
@unittest.skipIf(
sys.version_info >= (3, 13, 8),
"Python 3.13.8+ changed forkserver module caching behavior",
# https://docs.python.org/3.13/whatsnew/changelog.html
# gh-126631
)
def test_forkserver_perf(self):
start_method = 'forkserver'

View File

@ -6,7 +6,7 @@ import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import \
(parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM)
(parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM, serialTest)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA,
toleranceOverride, tol,)
@ -65,10 +65,12 @@ class TestScatterGather(TestCase):
actual = torch.gather(src, 2, idx)
self.assertEqual(actual, expected, atol=0, rtol=0)
@serialTest()
@dtypes(torch.int8, torch.bfloat16)
def test_gather_large(self, device, dtype):
# test larger shapes to check vectorized implementation
for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100)):
for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100), (4, 4, 16384 * 8192)):
torch.cuda.empty_cache()
src = make_tensor((m, k), device=device, dtype=dtype)
alloc0 = torch.empty(src.nelement() * 2, device=device, dtype=dtype)
discontig = alloc0.view(m, 2 * k)[:, ::2].copy_(src)
@ -111,6 +113,8 @@ class TestScatterGather(TestCase):
self.assertEqual(res_ind, ref, atol=0, rtol=0)
res_gather = torch.gather(misaligned1, dim=dim, index=ind)
self.assertEqual(res_gather, ref, atol=0, rtol=0)
del src, alloc0, alloc1, alloc2
del discontig, misaligned, misaligned1
# test gather along 1st dim that can accidentally trigger fast path
# because due to index dimension in the gather dim being 1
# an unexpected squashing in tensorIterator happens

View File

@ -1429,6 +1429,7 @@ def _compile(
fail_user_frame_lineno: Optional[int] = None
torch._dynamo.utils.ReinplaceCounters.clear()
guarded_code = None
tracer_output = None
try:
guarded_code, tracer_output = compile_inner(code, one_graph, hooks)

View File

@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
free_symbols,
free_unbacked_symbols,
IterateExprs,
rebind_unbacked,
resolve_unbacked_bindings,
ShapeEnv,
@ -97,6 +98,7 @@ from .utils import (
argsort,
argsort_sym,
cache_on_self,
cache_on_self_and_args,
ceildiv,
convert_shape_to_inductor,
convert_shape_to_symint,
@ -933,6 +935,7 @@ class Loops(IRNode):
inner_fn: Callable[..., Any]
ranges: Sequence[_IntLike]
@cache_on_self_and_args("Loops")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -1222,6 +1225,7 @@ class Reduction(Loops):
__repr__ = __str__
@cache_on_self_and_args("Reduction")
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
*(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
@ -2311,6 +2315,7 @@ class Scan(Loops):
# HACK we mimic reduction
@cache_on_self_and_args("Scan")
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
# TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
# need to explicitly represent the closure so we can pull out unbacked
@ -2520,6 +2525,7 @@ class Sort(Loops):
# HACK we mimic reduction
@cache_on_self_and_args("Sort")
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
return (
super().get_free_symbol_uses(unbacked_only)
@ -2768,6 +2774,7 @@ def is_unaligned(node: IRNode) -> bool:
class BaseView(IRNode):
data: IRNode
@cache_on_self_and_args("BaseView")
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
return self.data.get_free_symbol_uses(unbacked_only)
@ -3334,6 +3341,7 @@ class ReinterpretView(BaseView):
def freeze_layout(self) -> None:
pass
@cache_on_self_and_args("ReinterpretView")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -3617,13 +3625,37 @@ class Layout(OutputSpec):
self.dtype = dtype
assert len(size) == len(stride), f"size={size}, stride={stride}"
assert all(isinstance(s, (Expr, int)) for s in size)
self.size = size
self.stride = stride
self.offset = offset
self._size = size
self._stride = stride
self._offset = offset
self.is_pinned = is_pinned
# is_pinned implies cpu
assert (not self.is_pinned) or (self.device.type == "cpu")
@property
def size(self) -> Sequence[Expr]:
return self._size
@size.setter
def size(self, value: Sequence[Expr]) -> None:
self._size = value
@property
def stride(self) -> Sequence[Expr]:
return self._stride
@stride.setter
def stride(self, value: Sequence[Expr]) -> None:
self._stride = value
@property
def offset(self) -> Expr:
return self._offset
@offset.setter
def offset(self, value: Expr) -> None:
self._offset = value
def __str__(self) -> str:
offset = ""
if self.offset != 0:
@ -3833,6 +3865,7 @@ class Layout(OutputSpec):
def storage_size(self) -> Expr:
return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
@cache_on_self_and_args("Layout")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -3852,7 +3885,11 @@ class FixedLayout(Layout):
class FlexibleLayout(Layout):
"""A Tensor layout that we are allowed to change"""
"""
A Tensor layout that we are allowed to change
Assumption: layout change should NOT add or remove free symbols
"""
allow_indexing = False
@ -3937,6 +3974,33 @@ class FlexibleLayout(Layout):
fill_order = sorted(range(len(stride)), key=stride.__getitem__)
return FlexibleLayout.fill_ordered(sizes, fill_order)
@property
def size(self) -> Sequence[Expr]:
return self._size
@size.setter
def size(self, value: Sequence[Expr]) -> None:
self.assert_free_symbol_uses_unchanged("size", value)
self._size = value
@property
def stride(self) -> Sequence[Expr]:
return self._stride
@stride.setter
def stride(self, value: Sequence[Expr]) -> None:
self.assert_free_symbol_uses_unchanged("stride", value)
self._stride = value
@property
def offset(self) -> Expr:
return self._offset
@offset.setter
def offset(self, value: Expr) -> None:
self.assert_free_symbol_uses_unchanged("offset", value)
self._offset = value
def as_stride_order(
self, order: Sequence[int], allow_padding: bool = False
) -> FixedLayout:
@ -3995,6 +4059,25 @@ class FlexibleLayout(Layout):
self.is_pinned,
)
def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], sympy.Symbol]:
initial_free_symbols = {}
for name in ["size", "stride", "offset"]:
for unbacked_only in [True, False]:
key = (name, unbacked_only)
initial_free_symbols[key] = OrderedSet(
get_free_symbols(getattr(self, name), unbacked_only)
)
return initial_free_symbols
def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
for unbacked_only in [True, False]:
old_free_symbols = self.initial_free_symbols[(name, unbacked_only)]
new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
assert new_free_symbols == old_free_symbols, (
f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
)
def __init__(
self,
device: torch.device,
@ -4009,6 +4092,10 @@ class FlexibleLayout(Layout):
strides = FlexibleLayout.contiguous_strides(size)
super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
# record the initial free symbols to check that we do not add new free symbols
# later when modifying sizes, strides, and offsets.
self.initial_free_symbols = self.get_initial_free_symbol_uses()
class NonOwningLayout(Layout):
"""Is a view into the storage of another tensor"""
@ -4034,6 +4121,7 @@ class NonOwningLayout(Layout):
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
@cache_on_self_and_args("NonOwningLayout")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -4322,6 +4410,7 @@ class Buffer(IRNode, CodegenSymbol):
def get_read_names(self) -> OrderedSet[str]:
return OrderedSet([self.get_name()])
@cache_on_self_and_args("Buffer")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -4394,6 +4483,7 @@ class NoneAsConstantBuffer(IRNode):
def get_reads(self) -> OrderedSet[Dep]:
return OrderedSet()
@cache_on_self_and_args("NoneAsConstantBuffer")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -4413,6 +4503,7 @@ class NoneAsConstantBuffer(IRNode):
class ShapeAsConstantBuffer(IRNode):
expr: Expr
@cache_on_self_and_args("ShapeAsConstantBuffer")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -4485,6 +4576,7 @@ class ComputedBuffer(OperationBuffer):
self.data.get_size(),
)
@cache_on_self_and_args("ComputedBuffer")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -4912,6 +5004,7 @@ class TritonTemplateBuffer(TemplateBuffer):
self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
@cache_on_self_and_args("TritonTemplateBuffer")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -5264,6 +5357,7 @@ class InputsKernel(OperationBuffer):
def num_reads(self) -> int:
return 1
@cache_on_self_and_args("InputsKernel")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -5438,6 +5532,7 @@ class ConcatKernel(NopKernel):
and not isinstance(src.data, ExternKernelAlloc)
)
@cache_on_self_and_args("ConcatKernel")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -6337,6 +6432,7 @@ class ExternKernel(InputsKernel):
index = sympy_subs(sympy.expand(index), replacement)
return index, tuple(new_sizes)
@cache_on_self_and_args("ExternKernel")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -6797,6 +6893,7 @@ class UserDefinedTritonKernel(ExternKernel):
original_fxnode_name=self.fx_node.name,
)
@cache_on_self_and_args("UserDefinedTritonKernel")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -7265,6 +7362,7 @@ class DynamicSelectStorageOffset(ExternKernel):
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet([self.unbacked_offset_symbol])
@cache_on_self_and_args("DynamicSelectStorageOffset")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -7327,6 +7425,7 @@ class AssertScalar(ExternKernel):
def has_side_effects(self) -> bool:
return True
@cache_on_self_and_args("AssertScalar")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -7999,6 +8098,7 @@ class MultiOutput(ExternKernel):
self.indices = indices
self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
@cache_on_self_and_args("MultiOutput")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -8121,6 +8221,7 @@ class MutableBox(IRNode):
def realize(self) -> Optional[str]:
return self.data.realize()
@cache_on_self_and_args("MutableBox")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
@ -8919,6 +9020,7 @@ class EffectfulKernel(FallbackKernel):
class NonTensorObj(IRNode):
@cache_on_self_and_args("NonTensorObj")
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:

View File

@ -208,9 +208,10 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
)
)
if use_triton_template(layout, check_max_autotune=False):
if use_triton_template(layout, check_max_autotune=False) and (
out_dtype is None or out_dtype == mat1.get_dtype()
):
# TODO: add out_dtype support for Triton Template
assert out_dtype is None, "out_dtype is not supported for Triton"
choices.extend(
V.choices.get_mm_configs(kernel_inputs, layout, [bmm_template], name)

View File

@ -626,6 +626,7 @@ def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
P = ParamSpec("P")
RV = TypeVar("RV", covariant=True)
FN_TYPE = Callable[Concatenate[Any, P], RV]
class CachedMethod(Protocol, Generic[P, RV]):
@ -665,6 +666,60 @@ def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
return wrapper # type: ignore[return-value]
def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
"""
Variant of cache_on_self for properties. The only difference is the type signature.
"""
# pyrefly: ignore [bad-argument-type]
return cache_on_self(fn)
def cache_on_self_and_args(
class_name: str,
) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
# include both class_name and fn_name in the key to support `super().fn(self, **args, **kwargs)` calls.
def wrapper(
fn: FN_TYPE[P, RV],
) -> FN_TYPE[P, RV]:
key = f"__{class_name}_{fn.__name__}_cache"
# wrapper is likely on the hot path, compile a specialized version of it
ctx = {"fn": fn}
exec(
f"""\
def inner(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV:
args_kwargs = (args, tuple(sorted(kwargs.items())))
if not hasattr(self, "{key}"):
object.__setattr__(self, "{key}", {{}})
cache = self.{key}
try:
return cache[args_kwargs]
except KeyError:
pass
rv = fn(self, *args, **kwargs)
cache[args_kwargs] = rv
return rv
""".lstrip(),
ctx,
)
inner = functools.wraps(fn)(ctx["inner"])
def clear_cache(self: Any) -> None:
if hasattr(self, key):
delattr(self, key)
inner.clear_cache = clear_cache # type: ignore[attr-defined]
return inner
return wrapper
def aggregate_origins(
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
) -> OrderedSet[Node]:

View File

@ -2,6 +2,7 @@
// This file should only be compiled if this condition holds, so it should be
// safe.
#if defined(USE_CUDNN) || defined(USE_ROCM)
#include <ATen/detail/CUDAHooksInterface.h>
#include <torch/csrc/utils/pybind.h>
#include <tuple>
@ -32,11 +33,7 @@ version_tuple getRuntimeVersion() {
}
size_t getVersionInt() {
#ifndef USE_STATIC_CUDNN
return cudnnGetVersion();
#else
return CUDNN_VERSION;
#endif
return at::detail::getCUDAHooks().versionRuntimeCuDNN();
}
} // namespace

View File

@ -0,0 +1,7 @@
## torch/headeronly
The inlined C++ headers in the `torch::headeronly` namespace living in this subdirectory are completely decoupled from LibTorch. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt).
There are two types of LibTorch independent header-only headers:
1. OG header-only. Originally header-only APIs, such as `ScalarType`, `Half`, `BFloat16`, have always been implemented in headers only. For them to move into torch/headeronly only required a code migration, a copy-pasta, if you will.
2. Made to be header-only. There are also APIs that were NOT header-only that we made to be header-only. One example of such an API is `STD_TORCH_CHECK`, which was derived from `TORCH_CHECK`. `STD_TORCH_CHECK` calls into `std::runtime_error` instead of relying on `c10::Error`, which relies on libtorch.so. As a result, `STD_TORCH_CHECK` does not have the full `TORCH_CHECK` functionality that displays a fanciful traceback when the check is not met. We intentionally maintain the design that functions that do different things should be explicitly named differently.