[relanding again after fixing internal build]
Summary:
This might cause some new DDEs at call sites that do not use is_contiguous_or_false() or sym_is_contiguous(),
but we want to find those call sites so we can handle them properly by explicitly calling is_contiguous_or_false() instead of is_contiguous() when appropriate.
I had to fix one issue after removing the implicit size-oblivious reasoning. Here is the context:
In https://github.com/pytorch/pytorch/pull/157472 we defined sym_is_contiguous to be the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression representing contiguity and is guaranteed not to throw a DDE.
When people call is_contiguous(), we do sym_is_contiguous().guard_bool().
When people call is_contiguous_or_false(), we do sym_is_contiguous().guard_or_false().
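In other words, both boolean entry points funnel through the symbolic computation and differ only in how they guard. A minimal sketch of that difference (the helper names below are hypothetical, not the actual TensorImpl methods):
```cpp
#include <c10/core/SymBool.h>

// Given the symbolic contiguity result, guard_bool() may raise a
// data-dependent error (DDE) on unbacked symbols, while guard_or_false()
// falls back to false instead of guarding.
bool contiguity_guard_bool(const c10::SymBool& sym_contig) {
  return sym_contig.guard_bool(__FILE__, __LINE__);
}

bool contiguity_guard_or_false(const c10::SymBool& sym_contig) {
  return sym_contig.guard_or_false(__FILE__, __LINE__);
}
```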
One path that was not handled well was this one:
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }
  return sym_is_contiguous_default(memory_format);
}
```
Namely, if we call sym_is_contiguous_custom and matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, we used to call is_contiguous(this, memory_format).
This went through load_pyobj_interpreter and ended up calling the Python is_contiguous, which used implicit size-oblivious reasoning.
Once we removed that implicit size-oblivious reasoning, the right thing to do is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get a DDE even when the caller is using sym_is_contiguous.
So I had to define sym_is_contiguous for the PyInterpreter, and then override it for nested tensors.
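Based on the description above, the fixed path roughly looks like the following sketch (not the verbatim patch):
```cpp
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    // Stay in SymBool land so callers using guard_or_false() never hit a DDE.
    return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(
        this, memory_format);
  }
  return sym_is_contiguous_default(memory_format);
}
```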
Approved by: https://github.com/ezyang
Test Plan:
contbuild & OSS CI, see e444cd24d4
Rollback Plan:
Differential Revision: D80435179
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160869
Approved by: https://github.com/ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159197
Approved by: https://github.com/ezyang
This PR implements a post_grad pass that fuses the activation(add + mm) pattern into a single _addmm_activation call.
Something similar was previously done in #106912 but was reverted for performance reasons; it was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets Inductor handle the fusion.
However, the cuBLAS team has since made significant performance improvements here. I will update this post with more benchmarks, but preliminary benchmarks show good results.
Perf dashboard:
<img width="3371" height="1240" alt="Screenshot from 2025-08-07 13-41-35" src="https://github.com/user-attachments/assets/d44d6205-b33a-4a20-9f0f-d9db176b3738" />
ReLU works with both training and inference, but GELU only works in inference mode due to a fundamental limitation: GELU's derivative depends on its input while ReLU's does not. I don't think this is fixable with the current addmm_activation API.
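For reference, a minimal sketch of the kind of user code that produces this pattern (illustrative only, not the PR's test code; requires a CUDA device):
```py
import torch

# Bias add + matmul followed by an activation is the pattern the post_grad pass
# targets; under torch.compile this lowers to addmm + relu in the post-grad graph.
def addmm_relu(bias, inp, weight):
    return torch.relu(torch.addmm(bias, inp, weight))

compiled = torch.compile(addmm_relu, dynamic=False)
out = compiled(
    torch.randn(64, device="cuda"),        # bias:   (N,)
    torch.randn(32, 128, device="cuda"),   # input:  (M, K)
    torch.randn(128, 64, device="cuda"),   # weight: (K, N)
)
```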
Graph modules before and after this pass:
ReLU (addmm):
```
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (relu, primals_2, le, permute_1)

graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (_addmm_activation_default, primals_2, le, permute_1)
```
GELU (addmm):
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {})
    %tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {})
    %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {})
    return (mul_5,)

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True})
    return (_addmm_activation_default,)
```
Benchmark setup:
NGC PyTorch 25.06 container
cuBLAS version: 12.9.1.4
torch.compile run with dynamic=False and max_autotune
H100
```
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0107 ms
Average Time per Iteration (torch compile): 0.0296 ms
============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0262 ms
Average Time per Iteration (torch compile): 0.0327 ms
============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.1763 ms
Average Time per Iteration (torch compile): 0.2457 ms
============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 1.5280 ms
Average Time per Iteration (torch compile): 1.9437 ms
```
A100
```
############################################################
Testing with dtype: float16
############################################################
============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.0313 ms
Average Time per Iteration (torch compile): 0.0643 ms
============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.1149 ms
Average Time per Iteration (torch compile): 0.1255 ms
============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.6297 ms
Average Time per Iteration (torch compile): 0.7547 ms
============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas): 4.3821 ms
Average Time per Iteration (torch compile): 5.0740 ms
```
Script
```py
import torch

torch.manual_seed(0)

warmup, numrun = 10, 100
sizes = [1024, 2048, 4096, 8192]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
device = torch.device("cuda")

for dtype in dtypes:
    dtype_name = str(dtype).split('.')[-1]
    print(f"\n{'#'*60}")
    print(f"Testing with dtype: {dtype_name}")
    print(f"{'#'*60}")

    for size in sizes:
        M, N, K = size, size, size
        print(f"\n{'='*60}")
        print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}")
        print(f"{'='*60}")

        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)
        C = torch.randn(M, device=device, dtype=dtype)

        def func1():
            return torch._addmm_activation(C, A, B, use_gelu=True)

        def func2():
            return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh")

        func2_compiled = torch.compile(
            func2,
            dynamic=False,
            options={
                "force_disable_caches": True,
                "max_autotune": True,
                "max_autotune_gemm": True,
                "max_autotune_gemm_backends": "TRITON",
                "autotune_fallback_to_aten": False,
            }
        )

        for _ in range(warmup): func1()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0

        start_event.record()
        for _ in range(numrun): func1()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)

        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms")

        for _ in range(warmup): func2_compiled()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0

        start_event.record()
        for _ in range(numrun): func2_compiled()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)

        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158137
Approved by: https://github.com/eellison
Automatically replaces split with rsplit when relevant, and only performs the split up to the first (or last) value. This lets the split return early and improves efficiency.
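For illustration (the string and names below are hypothetical, not code from the PR):
```py
# When only the last segment is needed, rsplit with maxsplit=1 avoids splitting
# (and allocating) every segment the way an unbounded split does.
qualified_name = "torch.ops.aten.addmm.default"

suffix_slow = qualified_name.split(".")[-1]      # splits on every '.'
suffix_fast = qualified_name.rsplit(".", 1)[-1]  # stops after the first '.' from the right

assert suffix_slow == suffix_fast == "default"
```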
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160107
Approved by: https://github.com/albanD
Adds `c_shim_aten.{h/cpp}` and uses this for `fill_`.
This is the generated `c_shim_aten.cpp`, for reference:
```cpp
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See 7e86a7c015/torchgen/gen.py (L2424-L2436) for details
// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
#include <ATen/ops/fill.h>
#endif // AT_PER_OPERATOR_HEADERS

using namespace torch::aot_inductor;

AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::fill_(
        *tensor_handle_to_tensor_pointer(self), value
    );
  });
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158974
Approved by: https://github.com/albanD, https://github.com/janeyx99
Most of the added ops are backward ops, which had not been well-tested previously (which is why they were missed). The necessary ops were identified by manually examining the return values in torch/_meta_registrations.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158073
Approved by: https://github.com/desertfire
This word appears often in class descriptions and is not consistently spelled. Update comments and some function names to use the correct spelling consistently. Facilitates searching the codebase.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944
Approved by: https://github.com/Skylion007
The op test__weight_int4pack_mm_with_scales_and_zeros is for Intel GPU. It is functionally equivalent to the CUDA/CPU op test__weight_int4pack_mm (with the constraint that oneDNN only supports integer zero points, which is why we need this API). Since test__weight_int4pack_mm is already included in AOTI's fallback list, this PR adds support for XPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155780
Approved by: https://github.com/jansel
## Background
Task: [T222738229](https://www.internalfb.com/intern/tasks/?t=222738229)
It's the first starter task of the project **_Enabling TorchNative Standalone on Whisper_**. We are using the C shim to create a layer of abstraction between _**libtorch**_ and **_AOTInductor-generated artifacts_**.
So we need to add an entry in the C shim for every API surface in libtorch that we care about, and we only care about operators that AOTInductor does not handle. For this task, we only wanted to add entries for the following ops.
## What I've done
4 new fallback ops that show up in the Whisper model are added (torchgen/aoti/fallback_ops.py); an illustrative sketch follows the list.
- aten.permute (default)
- aten.squeeze (dim)
- aten.abs (default)
- aten.hann_window (default)
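A minimal sketch of a module whose exported graph would hit these ops (the module, shapes, and export call below are hypothetical, not the PR's test code):
```py
import torch

class WhisperLikePreprocess(torch.nn.Module):
    def forward(self, x):
        # hann_window, abs, squeeze, and permute correspond to the fallbacks above.
        window = torch.hann_window(x.shape[-1], device=x.device, dtype=x.dtype)
        y = (x * window).abs()
        return y.squeeze(0).permute(1, 0)

ep = torch.export.export(WhisperLikePreprocess(), (torch.randn(1, 4, 8),))
print(ep.graph_module.graph)
```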
Then I ran the command below to generate the new C shim header files, as described [here](7e86a7c015/torchgen/gen.py (L2424-L2436)):
`python torchgen/gen.py --update-aoti-c-shim`
Then I ran `python setup.py develop` to rebuild PyTorch.
## Testing
4 new tests have also been added in test/inductor/test_aot_inductor.py:
- test_proxy_executor_permute
- test_proxy_executor_abs
- test_proxy_executor_squeeze
- test_proxy_executor_hann
I ran these commands to test it (inside the local pytorch root folder):
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_permute`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_abs`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_squeeze`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_hann`
## NOTE:
I didn't see any particular ordering of the tests inside _test/inductor/test_aot_inductor.py_, so I added the new tests just after the test given in the example.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154251
Approved by: https://github.com/angelayi
Summary:
# Context
See the first PR https://github.com/pytorch/pytorch/pull/153670
# This PR
1. Migrate 3 clamp ops from out-of-tree to in-tree (I had to migrate all 3 ops together, because clamp.out calls all 3 stubs, which are also called by the other 2 ops):
- clamp.out
- clamp_min.out
- clamp_max.out
2. Also enabled structured kernel codegen for MTIA, which is needed by clamp.
3. Also introduced the `--mtia` flag to torchgen to prevent OSS builds from generating MTIA code. (Otherwise we got link errors such as `lib/libtorch_cpu.so: undefined reference to at::detail::empty_mtia`.)
Differential Revision: D74674418
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154015
Approved by: https://github.com/albanD, https://github.com/nautsimon
Added AOTIModelContainerRunnerMps and a shim for MPS fallback ops.
I also added an MPS-specific shim containing one operator, which will be used to set arguments passed to the Metal kernel:
```
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg(
    AOTIMetalKernelFunctionHandle func,
    unsigned idx,
    AtenTensorHandle tensor);
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153964
Approved by: https://github.com/malfet, https://github.com/desertfire
Summary:
# Context
The MTIA New Aten Backend work is essentially to move MTIA operators from PyTorch out-of-tree to in-tree, with the following benefits:
1. Avoid duplicating code copied from PyTorch, e.g. view op implementations and util functions.
2. Utilize TensorIterator and structured kernel codegen, avoiding manual implementation of broadcasting, dtype casting, asserting, etc.
3. Eliminate MTIA's own codegen flow, which is unnecessary complexity.
4. Overall, make MTIA's ATen backend more PyTorch-native.
Differential Revision: D74672464
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153670
Approved by: https://github.com/albanD, https://github.com/nautsimon
https://github.com/pytorch/pytorch/pull/129001#discussion_r1645126801 is the motivation for the whole stack of PRs. In `torch/__init__.py`, `torch._C.Type` shadows `from typing import Type`, and there is no type stub for `torch._C.Type` in `torch/_C/__init__.pyi`. So we need to use `from typing import Type as _Type`. After enabling [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585) in the `.pyi` type stub files, we can use `type` instead of `typing.Type` or `from typing import Type as _Type`.
------
- [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585): e.g. `typing.List[T] -> list[T]`, `typing.Dict[KT, VT] -> dict[KT, VT]`, `typing.Type[T] -> type[T]`.
- [Union Type (PEP 604)](https://peps.python.org/pep-0604): e.g. `Union[X, Y] -> X | Y`, `Optional[X] -> X | None`, `Optional[Union[X, Y]] -> X | Y | None` (see the stub sketch after this list).
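A minimal stub-style sketch of these rewrites (the function signature is hypothetical, not from the PR; at runtime the `X | Y` syntax needs Python 3.10+, but it is always accepted in `.pyi` files by type checkers):
```py
from typing import TypeVar

T = TypeVar("T")

# Before: def find(key: Type[T], items: Optional[Dict[str, T]] = None) -> Optional[T]: ...
# After (PEP 585 + PEP 604):
def find(key: type[T], items: dict[str, T] | None = None) -> T | None: ...
```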
Note that in `.pyi` stub files, we do not need `from __future__ import annotations`. So this PR does not violate issue #117449:
- #117449
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150727
Approved by: https://github.com/aorenste
ghstack dependencies: #150726