Compare commits

..

199 Commits

Author SHA1 Message Date
bc67bce2e5 Working setup with runnable PyTorch on Codex.
Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: 132668d46021090fe3ef197fb25ba762ce42667c
Pull-Request: https://github.com/pytorch/pytorch/pull/159968
2025-08-06 14:56:40 -07:00
79eca4677b [precompile] Skip serializing unnecesssary objects for guards. (#158926)
Summary:
The following type of objects don't need to be serialized for precompile:
1. PyCapsule because we don't guard on C binding objects in meaningful ways.
2. Code object because we only id matching on these but id matches will always be dropped for precompile.
3. Nested function objects since we also ban CLOSURE_MATCH.

Test Plan:
buck run mode/opt test/dynamo:test_dynamo -- -k test_skipped_objects

Rollback Plan:

Differential Revision: D78816888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158926
Approved by: https://github.com/jamesjwu
2025-08-06 15:00:28 +00:00
2855688a1d Revert "Replace C array with std::array in formatSockAddr (#159812)"
This reverts commit e7feedf6a9bb346ad205796aa4084c8dcfb18072.

Reverted https://github.com/pytorch/pytorch/pull/159812 on behalf of https://github.com/malfet due to Looks like it broke distribtued tests, see 2231c3ca3a/1 ([comment](https://github.com/pytorch/pytorch/pull/159812#issuecomment-3160513656))
2025-08-06 14:55:48 +00:00
2231c3ca3a [CI][CD] Fix install_nvshem function (#159907)
When one builds CD docker, all CUDA dependencies must be installed into `/usr/local/cuda/` folder

Test plan: Looks at the binary build logs, for example [here](https://github.com/pytorch/pytorch/actions/runs/16768141521/job/47477380147?pr=159907):
```
2025-08-06T05:58:00.7347471Z -- NVSHMEM_HOME set to:  ''
2025-08-06T05:58:00.7348378Z -- NVSHMEM wheel installed at:  ''
2025-08-06T05:58:00.7392528Z -- NVSHMEM_HOST_LIB:  '/usr/local/cuda/lib64/libnvshmem_host.so'
2025-08-06T05:58:00.7393251Z -- NVSHMEM_DEVICE_LIB:  '/usr/local/cuda/lib64/libnvshmem_device.a'
2025-08-06T05:58:00.7393792Z -- NVSHMEM_INCLUDE_DIR:  '/usr/local/cuda/include'
2025-08-06T05:58:00.7394252Z -- NVSHMEM found, building with NVSHMEM support
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159907
Approved by: https://github.com/Skylion007, https://github.com/ngimel
2025-08-06 14:44:37 +00:00
c03a734ba1 [OpenReg] Disable automatic inclusion of data files (#159845)
# Background

After I built torch_openreg, I noticed that the wheel package contained the stub.c file under the csrc directory, which was not used in the runtime.

# Motivation

This PR aims to remove the stub.c file and any unused file when running torch_openreg.

**Changes:**

- Setting **include_package_data** keyword to false in the setup function

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159845
Approved by: https://github.com/albanD
2025-08-06 10:35:13 +00:00
98316e5896 [WOQ] Add CUDA kernel for _weight_int8pack_mm (#159325)
**Summary**
This issue proposes implementing a CUDA kernel for aten._weight_int8pack_mm, a weight-only quantized (WOQ) linear operation that is currently only supported on CPU. On CUDA, the fallback path uses an unfused .mul().sum() pattern in quantization.py, which is less efficient for inference. https://github.com/pytorch/pytorch/issues/158849

**Motivation**
A fused GPU kernel for aten._weight_int8pack_mm would:
- Eliminate reliance on the .mul().sum() fallback in quantization.py
- Improve performance for quantized inference on CUDA
- Extend Inductor’s GPU quantization support across more workloads

**Implementation**
- Implement a Triton kernel for:
```
out[b, n] = sum_k(x[b, k] * w[n, k]) * scale[n]

where:
x: [B, K] float32
w: [N, K] int8
scale: [N] float32
out: [B, N] float32
```
- Integrate the kernel with register_woq_mm_ops() in torch/_inductor/quantized_lowerings.py
- Route it conditionally in quantization.py where GPU currently falls back to .mul().sum()
- Add unit tests comparing results to the reference fallback path

Test Plan:
```
buck2 run 'fbcode//mode/opt' :linalg test_linalg.TestLinalgCUDA.test__int8_mm_m_64_k_64_n_64_compile_True_slice_True_cuda
```
Log: P1882799769

```
buck2 test 'fbcode//mode/opt' caffe2/test:linalg
```
https://www.internalfb.com/intern/testinfra/testconsole/testrun/6755399722424741/

Benchmark Results:
```
**[Shape B=256, K=1024, N=512]**
CPU and CUDA outputs match
Max abs diff: 2.59e-04, max rel diff: 0.75
CPU: 144.14 ms, CUDA: 303.67 µs
Speedup: ×474.6

**[Shape B=512, K=2048, N=1024]**
CPU and CUDA outputs match
Max abs diff: 5.49e-04, max rel diff: 0.15
CPU: 1173.27 ms, CUDA: 2.40 ms
Speedup: ×488.5
```
Rollback Plan:

Differential Revision: D79042656

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159325
Approved by: https://github.com/danielvegamyhre, https://github.com/jerryzh168
2025-08-06 10:28:08 +00:00
23cf241039 [aoti][mps] Initialize mps kernels first (#159753)
In some cases we have mps kernels which are reused across higher-order-op subgraphs and the toplevel code. However, currently we initialize the variable for the mps kernel the first time we use it, which runs into an issue if we run into the mps kernel within a subgraph since the kernel will only be initialized within the subgraph scope. For instance:
```
if ...
    auto mps_lib_0_func = ...
    mps_lib_0_func->run()

// since we already used mps_lib_0 once, we don't re-initialize it
mps_lib_0_func->run()  // error, mps_lib_0_func not initialized
```

So the solution we took here is to initialize all the kernels at the beginning:
```
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
...
if ...
    get_mps_lib_0()->run()

get_mps_lib_0()->run()  // success
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159753
Approved by: https://github.com/malfet
ghstack dependencies: #159456, #159695
2025-08-06 07:54:29 +00:00
e7feedf6a9 Replace C array with std::array in formatSockAddr (#159812)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159812
Approved by: https://github.com/Skylion007
2025-08-06 07:44:29 +00:00
dad2a05bec [DTensor] Set up DTensorContinuousTestBase (#159885)
Also migrate `test_common_rules.py` since it was a short file

`python test/distributed/tensor/test_common_rules.py`

Before:
Ran 10 tests in 91.516s
After:
Ran 10 tests in 5.604s

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159885
Approved by: https://github.com/ezyang
2025-08-06 07:40:31 +00:00
0495cab545 Wire in pt2_triton_builds (#159897)
Summary:
This allows us to start seeing the failure rate on these models (and
potentially alert on it).

Test Plan:
```
FORCE_LOG_TRITON_BUILDS_TO_PROD=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run @//mode/opt :compile 2>&1 | tee out
```
P1889607054

Waiting for scuba table to generate, but manual logging show it should show up at https://fburl.com/scuba/pt2_triton_builds_inc_archive/7852kt8h soon.

Rollback Plan:

Reviewed By: masnesral

Differential Revision: D79308333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159897
Approved by: https://github.com/masnesral
2025-08-06 07:39:51 +00:00
abfe403981 [AIDIR] Internal util function to insert MLHub debugging insight for dynamic shape (#159391)
Summary:
This feature is Meta internal only
Add a util function to put dynamic shape-related suggestion to MLHubDebugInsightService, which will then be surfaced to users in the MLHub .

The rollout will be controlled by JK.

Test Plan:

MAST job aps-omnifmv3_dev_baseline_test-a34fdccf21

 {F1980593060}

* If you're not able to see the insight, please add yourself to this gk 'mlhub_debugging_insights_dev_visibility'
* The URL link should route to a new Job Inspector page that will provide details and straight forward instructions of how to config the ds. The page is currently still in development so here we use the general PT2 compile JI page.
* Test fails because of the export checks. I'll export after addressing all the comments from reviewers.

Rollback Plan:

Reviewed By: pianpwk

Differential Revision: D78526522

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159391
Approved by: https://github.com/jingsh
2025-08-06 07:39:39 +00:00
1690c0c3a0 [Reland] Migrate ScalarType to headeronly (#159911)
The non ghstack version of #159416, to make sure we don't get reverted again
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159911
Approved by: https://github.com/mikaylagawarecki
2025-08-06 07:36:37 +00:00
e9d27aa8fd [CUDA 13] CMake/Dependencies: no need to call find_package(CUB) (#159854)
CUB library is the part of CCCL of the CUDA Toolkit 13. If CUDA Found, CUB is found as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159854
Approved by: https://github.com/eqy
2025-08-06 06:03:58 +00:00
2457e62c90 Revert "Set PYTHONHOME for inductor subprocesses using torch (#159382)"
This reverts commit fe8984a9f43bde10d1956abe7cb40710ed7ceed2.

Reverted https://github.com/pytorch/pytorch/pull/159382 on behalf of https://github.com/malfet due to Broke MacOS testing see d0fccbc99c/1 ([comment](https://github.com/pytorch/pytorch/pull/159382#issuecomment-3157455367))
2025-08-06 05:30:20 +00:00
d0fccbc99c [CI] Delete sm86 tests from pull (#159903)
And delete sm89+cuda12.4 builds from periodic (as sm86+legacy driver should be enough)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159903
Approved by: https://github.com/huydhn
2025-08-06 05:16:55 +00:00
3461988a4b [audio hash update] update the pinned audio hash (#159823)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159823
Approved by: https://github.com/pytorchbot
2025-08-06 05:02:35 +00:00
9764981116 Pass fw/bw compilers to aot_export_joint_with_descriptors (#159814)
Allow overriding nop compilers with real ones when using this flow.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159814
Approved by: https://github.com/fmassa
2025-08-06 04:50:56 +00:00
704594eb23 [Dynamo] make HOPs hashable (#159910)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159910
Approved by: https://github.com/yf225
2025-08-06 04:02:17 +00:00
eqy
bfc27cf468 [Distributed] Fix @parametrize on unordered iterable in distributed test (#159793)
seems to fix https://github.com/pytorch/pytorch/issues/145807

sets aren't ordered so `@parametrize` can cause two processes to spawn with different settings

originally debugged thanks to @k-artem, see https://github.com/pytorch/pytorch/issues/145807#issuecomment-2971009451

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159793
Approved by: https://github.com/Skylion007, https://github.com/wconstab
2025-08-06 03:51:42 +00:00
311f74089a remove print (#159917)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159917
Approved by: https://github.com/laithsakka
2025-08-06 03:48:23 +00:00
14c7358c64 Enable fr_trace to read local traces from multiple hosts. (#159490)
Summary: For training jobs particularly from GenAI, NCCL trace dumps are generated in the format of `<hostname>.pci3_rank_<rank>`. For multi-node training jobs, the hostname varies across traces. The current prefix matching logic can't handle this case.

Test Plan:
Create a local folder `dumps` and several empty files: `host0.pci3_rank_0`, `host0.pci3_rank_1`, `host1.pci3_rank_0`, `host1.pci3_rank_1` inside it. Then run
```
buck2 run fbcode//caffe2/fb/flight_recorder:fr_trace -- trace_dir dumps
```

Before this diff, fr_trace cannot locate any trace files, giving the following assertion error:
```
AssertionError: no files loaded from /home/tianhaoh/dumps with prefix pci3_rank_
```

After this diff, fr_trace is able to locate the trace files, resulting in the exceptions like
```
    dump = pickle.load(infile)
           ^^^^^^^^^^^^^^^^^^^
EOFError: Ran out of input
```
(since the trace files are fake and empty).

Rollback Plan:

Differential Revision: D79224727

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159490
Approved by: https://github.com/fduwjj
2025-08-06 03:15:34 +00:00
8ce81bcee1 [Torch Package] Make get names of OrderedImporters support fallback to importers (#155743)
Summary:
OrderedImporters is supposed to be an importer which tries out every single importer in self._importers. However the get_name API does not follow this behavior and only uses the get_name from the basic Importer class.
This change is to update the OrderedImporters get_name API so that it tries the get_name API of every single importers.

Differential Revision: D76463252

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155743
Approved by: https://github.com/jcwchen, https://github.com/jingsh
2025-08-06 02:26:10 +00:00
4604f0482c Add UT for torch.accelerator memory-related API (#155200)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155200
Approved by: https://github.com/albanD
ghstack dependencies: #138222, #152932
2025-08-06 02:22:18 +00:00
15f1173e5d Add unified memory APIs for torch.accelerator (#152932)
# Motivation
The following API will be put under torch.accelerator
- empty_cache
- max_memory_allocated
- max_memory_reserved
- memory_allocated
- memory_reserved
- memory_stats
- reset_accumulated_memory_stats
- reset_peak_memory_stats

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152932
Approved by: https://github.com/albanD
ghstack dependencies: #138222
2025-08-06 02:22:18 +00:00
e16c48ae97 [BE] Fix type hint in AOTIRunnerUtil (#159577)
Not sure why it was labelled as list in the first place. In test_aot_inductor.py, I scanned a few use cases and they are tuple as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159577
Approved by: https://github.com/Skylion007
2025-08-06 01:20:45 +00:00
f7a66da5f9 Add DeviceAllocator as the base device allocator (#138222)
# Motivation
In line with [RFC] [A device-agnostic Python device memory related API design for stream-based accelerators](https://github.com/pytorch/pytorch/issues/134978), some memory-related APIs are widely used in popular repositories, such as HuggingFace [so many if-else conditional code](https://github.com/search?q=repo%3Ahuggingface%2Faccelerate%20torch.cuda.empty_cache&type=code). We would like to introduce a generic API set under torch.accelerator namespace to generalize these user cases.

<div align="center">
<table>
<tr>
<td> Device-specific memory APIs torch.xxx.foo</td> <td> Device-agnostic memory APIs torch.accelerator.foo</td>
</tr>
<tr>
<td>

```python
torch.xxx.empty_cache
```

</td>
<td>

```python
torch.accelerator.empty_cache
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.reset_peak_memory_stats
```

</td>
<td>

```python
torch.accelerator.reset_peak_memory_stats
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.reset_accumulated_memory_stats
```

</td>
<td>

```python
torch.accelerator.reset_accumulated_memory_stats
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.memory_stats
```

</td>
<td>

```python
torch.accelerator.memory_stats
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.memory_allocated
```

</td>
<td>

```python
torch.accelerator.memory_allocated
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.max_memory_allocated
```

</td>
<td>

```python
torch.accelerator.max_memory_allocated
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.memory_reserved
```

</td>
<td>

```python
torch.accelerator.memory_reserved
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.max_memory_reserved
```

</td>
<td>

```python
torch.accelerator.max_memory_reserved
```

</td>
</tr>

</table>
</div>

# Solution
This design follows a similar pattern to `HostAllocator`. We're introducing a base class `DeviceAllocator`, from which `CUDAAllocator` and `XPUAllocator` will inherit. This allows us to provide a unified call path like: `torch.accelerator.empty_cache()` -> `GetDeviceAllocator(allocator)->empty_cache()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138222
Approved by: https://github.com/albanD, https://github.com/Camyll
2025-08-06 00:40:29 +00:00
3eb3da9b4b [dynamo][guards] Skip ID_MATCH guard on self.__class__.__closure__ (#159888)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159888
Approved by: https://github.com/williamwen42
2025-08-06 00:36:43 +00:00
3ddfd46bd2 Cut a version of TORCH_ERROR_CODE_CHECK in headeronly from AOTI (#159604)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159604
Approved by: https://github.com/albanD, https://github.com/desertfire
2025-08-06 00:29:56 +00:00
6a82da392e [export] Fix generated schema for C++20/23 (#159871)
Summary: Fixing the issue from https://github.com/pytorch/pytorch/issues/159838

Test Plan:
buck run caffe2/:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/

Rollback Plan:

Differential Revision: D79647167

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159871
Approved by: https://github.com/malfet
2025-08-06 00:23:05 +00:00
22bedc429f Extract some HOP utils to be importable (#159705)
Useful helper function for stage 1 export -> manual partitioner -> stage 2 compile users

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159705
Approved by: https://github.com/zou3519
ghstack dependencies: #159134
2025-08-05 23:59:47 +00:00
49abc0e3f8 [Take 2] Setup TorchBench in Docker (#159300)
Fix and reland https://github.com/pytorch/pytorch/pull/158613, I keep `checkout_install_torchbench` in `.ci/pytorch/macos-test.sh` script because it's still used there, and there is no Docker.

### Testing

MacOS perf nightly run https://github.com/pytorch/pytorch/actions/runs/16580798470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159300
Approved by: https://github.com/ZainRizvi
2025-08-05 23:47:42 +00:00
1052604acd fix logging setup issue for Windows.. (#159887)
When we setup logging config as guide: https://docs.pytorch.org/docs/stable/logging.html
Such as:
    TORCH_LOGS="+schedule,+inductor,+output_code"
On Linux, it shows as:
```cmd
declare -x SSH_TTY="/dev/pts/0"
declare -x TERM="xterm"
declare -x TORCH_LOGS="+schedule,+inductor,+output_code"
declare -x USER="xu"
```
On Windows, it shows as:
```cmd
TORCHINDUCTOR_WINDOWS_TESTS=1
TORCH_LOGS="+schedule,+inductor,+output_code"
UCRTVersion=10.0.22000.0
```
For Linux, it shows quotes by default, And Windows is not shows quotes.
Besides that, Windows would auto assemble quotes when env var processing.

On Linux, we will get variable: "+schedule,+inductor,+output_code"
On Windows, we will get variable: '"+schedule,+inductor,+output_code"'

So, we need remove the outer quotes for Windows.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159887
Approved by: https://github.com/angelayi
2025-08-05 23:44:38 +00:00
fe8984a9f4 Set PYTHONHOME for inductor subprocesses using torch (#159382)
Summary:
This is needed for subprocesses that are trying to call back into torch
functionality, i.e. anything that's also setting `PYTHONPATH`.  There are more
`sys.executable` subprocesses in torch/ but it seems like they're fine.

Test Plan: Local inference runs.

Reviewed By: aorenste

Differential Revision: D79124705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159382
Approved by: https://github.com/aorenste
2025-08-05 23:32:48 +00:00
74a754aae9 Add meta kernel for sdpa_math_for_mps (#159695)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159695
Approved by: https://github.com/malfet
ghstack dependencies: #159456
2025-08-05 22:27:06 +00:00
b1ec088113 [mps] Turn on inductor dynamic shapes tests (#159456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159456
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-08-05 22:27:06 +00:00
fb35a9ea4a [export] Improve error messages (#159881)
Originally, if the PT2 errored when loading, we would try to load using the old loader to fit BC issues. However this hides the error messages for if an up-to-date PT2 is erroring when loading due to some other reason.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159881
Approved by: https://github.com/yushangdi
2025-08-05 22:26:48 +00:00
8034b2a732 [inductor] Add TLParse artifact for logging runtime of collective and compute ops (#159730)
Summary:

- debug.py: Added log_runtime_estimates() function to dump runtime estimation data as structured tlparse artifacts in JSON format
- test_structured_trace.py: Added comprehensive test coverage with testing compute and collective ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159730
Approved by: https://github.com/yushangdi
ghstack dependencies: #159190
2025-08-05 22:06:32 +00:00
64cc6f06b1 [Inductor] Revert minimal changes to avoid internal test failures (#159809)
The diff/PR https://github.com/pytorch/pytorch/pull/159211 caused a bunch of test failures for graph compiler(T232684410). But I couldn't figure out a forward fix so far. So with this diff/PR, I'm proposing to revert the minimal changes to resolve the test failures.

I'll continue the debugging, and re-land the reverted changes once we find out a forward fix.

Differential Revision: [D79221721](https://our.internmc.facebook.com/intern/diff/D79221721/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159809
Approved by: https://github.com/blaine-rister, https://github.com/eellison
2025-08-05 22:05:26 +00:00
410812763b Revert "[Inductor][Triton] Support TMA before strict 3.4 cutoff (#159777)"
This reverts commit bbc0df1094b5a4dcd2cce83f8402127b07913231.

Reverted https://github.com/pytorch/pytorch/pull/159777 on behalf of https://github.com/izaitsevfb due to breaking inductor test on ROCm ([comment](https://github.com/pytorch/pytorch/pull/159777#issuecomment-3156770098))
2025-08-05 22:00:24 +00:00
bdb07a2bc5 [Cutlass] Allow offsets to be passed as arguments to kernel (#159761)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159761
Approved by: https://github.com/henrylhtsang
ghstack dependencies: #159760
2025-08-05 21:59:07 +00:00
8085edc8f9 [autograd] torch._C._set_view_replay_enabled state leaking into other tests (#159840)
This was causing view_fns to pop up in tests that ran after `TestAutograd.test_view_replay_enabled` where it isn't used as a context manager. It is unclear to me why we would want `_force_original_view_tracking` to mutate global state on __init__ rather than on __enter__, that could be an alternative fix.

FIXES https://github.com/pytorch/pytorch/issues/156306 https://github.com/pytorch/pytorch/issues/156289 https://github.com/pytorch/pytorch/issues/156265 https://github.com/pytorch/pytorch/issues/156209
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159840
Approved by: https://github.com/albanD
2025-08-05 21:57:49 +00:00
882d50c5bf [C10] Add Scalar::isUnsigned() method (#159877)
That returns true if Scalar hold unsigned integral value

With the implications of `Tag::HAS_u` semantic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159877
Approved by: https://github.com/Skylion007, https://github.com/ezyang
2025-08-05 21:43:21 +00:00
b52a4d0821 [ez][CI] Remove some unused docker images (#159171)
Removes unused docker images from the docker build workflow
Then removes unused definitions in build.sh

The only one I left is the vllm one because I'm pretty sure it's going to be used in the future

I assume everything not mentioned is old and we forgot to remove them
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159171
Approved by: https://github.com/yangw-dev
2025-08-05 21:31:53 +00:00
a45a840926 [CI] Disable check-labels and check_mergeability (#159900)
See https://github.com/pytorch/pytorch/issues/159825
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159900
Approved by: https://github.com/clee2000
2025-08-05 21:16:12 +00:00
9b953bb3fb [BE] Update TensorPipe pin (#159834)
No functional changes, just:
- Update C++ standard to C++17
- Update `cmake` min version to 3.18
- Update `libuv` dependency to 1.51 (to move its cmake min version to 3.10)
- Replace boost optional implementation with `std::optional` wrapper
- Make it compilable with gcc-14.x plus by including `cstddef` in few headers
-  Avoid using deprecated enums for MacOS builds

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159834
Approved by: https://github.com/Skylion007
2025-08-05 20:45:09 +00:00
eb25a95a6e Fix inductor memory estimation when a single buf has multiple mutations. Add runtime verification of mem tracking (#159569)
With fsdp, we sometimes have multiple, non-overlapping views of a single buffer which are all mutated. Previously we considered the original buffer as an allocation, and make the mutated buffer the deallocation. With multiple mutations of the same buffer, we need to consider the original buffer as deallocated only when all of its aliases die (and avoid double counting the input buffer size). See comment inline:

```
    When an operation mutates a buffer in-place, the scheduler creates a new buffer name
    to track the "before" and "after" states, even though they share the same memory.
    The mutated buffer represents a rename with zero allocation and deallocation cost.
    During dependency tracking, we transfer dependencies from the mutated name back to
    the original buffer, ensuring the original memory is only freed when all aliases
    are done.
    This handles cases where a buffer has multiple non-overlapping aliases - rather than
    trying to assign free costs to individual aliases, we forward all alias dependencies
    to the original buffer.
    Consider:
        buf0 = op0()
        buf1 = mutation_op_(buf0)
        del buf0
        ...
        op(buf1)
        del buf1
    The only memory events are the creation prior to op0, and the deletion following buf1.
```

As @IvanKobzarev 's logs in https://github.com/pytorch/pytorch/pull/158361/files#diff-e173a1d52aff49959c9f6d17ecc09946d8a616fc5909df884e62a15e1ebd1d41R1776-R1807 show, it can a bit of a pain to pinpoint which part of our memory calculation is incorrect.

This pr also adds a runtime verifier `config.test_configs.track_memory_lifecycle` which tracks buffer allocation and deallocation, and errors if their lifetime does not match our expectations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159569
Approved by: https://github.com/IvanKobzarev
2025-08-05 19:58:11 +00:00
eqy
9884d0351e [CUDA] Decrease launch bounds of CTCLoss backward for blackwell (#159522)
Otherwise we see `CUDA error: too many resources requested for launch`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159522
Approved by: https://github.com/janeyx99
2025-08-05 19:26:25 +00:00
d7c83972d5 tools: Add mode to find python automatically (#159820)
Add support for automatically finding Python interpreters in manylinux
environments to our wheel building script. Scaffolding for sequential builds

Signed-off-by: Eli Uriegas <eliuriegas@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159820
Approved by: https://github.com/malfet
2025-08-05 19:19:22 +00:00
e06b110f73 [Testing] Add MPS to NATIVE_DEVICES (#153835)
This would allow me to enable more opinfo tests against MPS device eventually and supposed to be a very simple test, but actually required minor adjustments to lots of test files, namely:
- Introduce `all_mps_types_and` that is very similar to `all_types_and`, but skips `float64`
- Decorate lots of tests with `@dtypesIfMPS(*all_mps_types())`
- Skip `test_from_dlpack_noncontinguous` as it currently crashes (need to be fixed)
- Add lots of `expectedFailureIfMPS`
- Delete all `@onlyNativeDeviceTypesAnd("mps")`

&lt;sarcasm&gt; I love how well documented this variable are &lt;/sarcasm&gt;

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153835
Approved by: https://github.com/Skylion007
2025-08-05 18:57:35 +00:00
0ba09a6d34 fix link for tutorial of inductor on windows (#159853)
fix link issue from https://docs.pytorch.org/tutorials/prototype/inductor_windows.html to https://docs.pytorch.org/tutorials/unstable/inductor_windows.html due to structure change with pr https://github.com/pytorch/tutorials/pull/3489
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159853
Approved by: https://github.com/sekyondaMeta

Co-authored-by: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com>
Co-authored-by: Zesheng Zong <zesheng.zong@outlook.com>
2025-08-05 18:37:47 +00:00
aeb5321b63 Allow controlling PG backend and options via init_device_mesh (#159371)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159371
Approved by: https://github.com/wconstab, https://github.com/fduwjj, https://github.com/wanchaol
2025-08-05 12:44:14 +00:00
625108ede2 [inductor] consolidate common GEMM triton param retrieval (#159383)
\# Why

- Make loop iteration simpler
- Have a common spot where to make modifications that affect
  all the GEMM Triton templates, avoiding missed spots

\# What

- pull out commong logic of taking the BaseConfig objects
  and turning them into kwargs to feed into maybe_append_choice
  for Triton GEMM templates

Differential Revision: [D79186962](https://our.internmc.facebook.com/intern/diff/D79186962)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159383
Approved by: https://github.com/jansel
2025-08-05 11:42:25 +00:00
09e5a93fcb Improve graph output alias with subclass error message (#159619)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159619
Approved by: https://github.com/albanD
2025-08-05 06:47:31 +00:00
908c5cc4c0 Generalize torch._C._set_allocator_settings to be generic (#156175)
# Motivation
This PR moves the implementation of `torch.cuda.memory._set_allocator_settings` to `torch._C._accelerator_setAllocatorSettings`.
Since the original API was intended as a temporary/internal utility, I am not exposing the new function as a public API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156175
Approved by: https://github.com/albanD
ghstack dependencies: #159629, #150312, #156165
2025-08-05 04:08:42 +00:00
c1145852a5 Deprecate overleap functions in CUDAAllocatorConfig, use AcceleratorAllocatorConfig instead (#156165)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156165
Approved by: https://github.com/albanD
ghstack dependencies: #159629, #150312
2025-08-05 04:08:42 +00:00
ae1a706444 Refactor CUDAAllocatorConfig to reuse AcceleratorAllocatorConfig (#150312)
# Motivation
Refactor `CUDAAllocatorConfig` to reuse `AcceleratorAllocatorConfig` and `ConfigTokenizer`. We would deprecate those option that overleap with `AcceleratorAllocatorConfig` in the following PR and keep them only for BC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150312
Approved by: https://github.com/albanD
ghstack dependencies: #159629
2025-08-05 04:08:04 +00:00
56d19a5ced Fix AllocatorConfig potential SIO issue (#159629)
# Motivation
As @ScottTodd identified in this [comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3141524874), using STL containers like `std::string` and `std::unordered_set` at static init time can cause static initialization order issues. This PR is based on and modified from his original PR: https://github.com/pytorch/pytorch/pull/159607. I’m stacking this PR here to help facilitate the landing and validation process.

Co-authored-by: @ScottTodd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159629
Approved by: https://github.com/ScottTodd, https://github.com/albanD
2025-08-05 04:07:51 +00:00
b6c53383fe [Dynamo][Better Engineering] Type annotation for torch/_dynamo/output_graph.py (#159602)
As part of better engineering effort, we would like to improve out type support to improve dev experience in dynamo

This PR adds strict typing support to `torch/_dynamo/output_graph.py`

Running
```
mypy torch/_dynamo/output_graph.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  2163 | 4792 | 45.14% | 121 | 268 | 45.15% |
| This PR | 4818 | 4818 | 100.00% | 268 | 268 | 100.00% |
| Delta    | +2655 | +26 | +54.84% | +147 | 0 | +54.85% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159602
Approved by: https://github.com/Skylion007
2025-08-05 03:50:54 +00:00
4fd5fabee9 skip XPU for dataloader CPU only unit test (#159811)
Fixes [#159802](https://github.com/pytorch/pytorch/issues/159802)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159811
Approved by: https://github.com/izaitsevfb
2025-08-05 03:44:01 +00:00
bbc0df1094 [Inductor][Triton] Support TMA before strict 3.4 cutoff (#159777)
Summary: Inductor's 3.4 Triton release is the most common used variant of Triton, but if someone is working with an alternative version of Triton this may not match. This moves the version check from 3.4 Triton to any variant that has support for the TMA APIs.

Test Plan:
Relying on CI. Should be a NFC.

Rollback Plan:

Reviewed By: davidberard98

Differential Revision: D79378792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159777
Approved by: https://github.com/davidberard98
2025-08-05 03:29:13 +00:00
33ec6e3e9a Remove pin on libuv from instructions (#159504)
This package doesn't exist at conda-forge and causes some confusion for users.
see https://anaconda.org/conda-forge/libuv/files?version=1.39.0

libuv is quite stable, so the newer versions should be fine. we build with them anyway at conda-forge.

see: https://github.com/conda-forge/libuv-feedstock/issues/80

Hopefully this can help future users.

Fixes https://github.com/conda-forge/libuv-feedstock/issues/80

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159504
Approved by: https://github.com/seemethere
2025-08-05 03:18:42 +00:00
efc4b460b3 Add cascade sum support for Inductor CPP backend (#156296)
Fixes #154703

Add cascade summation support for Inductor CPP backend to improve precision for large size summation.

Currently, Inductor CPP directly do reduction for sum. As shown in #154703, when the size of the sum is large and the number of parallel is small, direct reduction will cause an intolerable precision loss:
```
extern "C"  void kernel(float* in_out_ptr0,
                       const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    {
        {
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(3000000000L); x0+=static_cast<int64_t>(16L))
            {
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(3000000000L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        tmp_acc0_vec = tmp_acc0_vec + tmp0;
                    }
                }
            }
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float, 1>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
            out_ptr0[static_cast<int64_t>(0L)] = static_cast<float>(tmp_acc0);
        }
    }
    {
        {
            {
                auto tmp0 = out_ptr0[static_cast<int64_t>(0L)];
                auto tmp1 = static_cast<float>(3000000000.0);
                auto tmp2 = tmp0 / tmp1;
                in_out_ptr0[static_cast<int64_t>(0L)] = tmp2;
            }
        }
    }
}
```

After adding cascade sum support:

```
extern "C"  void kernel(float* in_out_ptr0,
                       const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    {
        {
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            at::vec::Vectorized<float> masked_tmp_acc0_vec = at::vec::Vectorized<float>(0);
            CascadeSumHelper<float, 65536> scalar_cascade_helper0(static_cast<int64_t>(3000000000L));
            CascadeSumHelper<at::vec::Vectorized<float>, 65536> cascade_helper0(static_cast<int64_t>(187500000L));
            CascadeSumHelper<at::vec::Vectorized<float>, 65536> masked_cascade_helper0(static_cast<int64_t>(0L));
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(3000000000L); x0+=static_cast<int64_t>(16L))
            {
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(3000000000L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        tmp_acc0_vec = cascade_sum_combine(tmp0, &cascade_helper0);
                    }
                }
            }
            tmp_acc0 = cascade_sum_final(&scalar_cascade_helper0);
            tmp_acc0_vec = cascade_sum_final(&cascade_helper0);
            masked_tmp_acc0_vec = cascade_sum_final(&masked_cascade_helper0);
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float, 1>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec + masked_tmp_acc0_vec);
            out_ptr0[static_cast<int64_t>(0L)] = static_cast<float>(tmp_acc0);
        }
    }
    {
        {
            {
                auto tmp0 = out_ptr0[static_cast<int64_t>(0L)];
                auto tmp1 = static_cast<float>(3000000000.0);
                auto tmp2 = tmp0 / tmp1;
                in_out_ptr0[static_cast<int64_t>(0L)] = tmp2;
            }
        }
    }
}
```
This will inevitably reduce performance when cascade sum is turned on.
For the case shown in #154703: performance reduced by ~3%.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156296
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
2025-08-05 02:54:32 +00:00
1ca8388442 [BE][MPS] Remove unused size12 variable (#159832)
Fixes following compilation warning
```
/Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Pooling.metal:433:8: warning: unused variable 'size12' [-Wunused-variable]
  auto size12 = input_sizes[1] * input_sizes[2];
       ^
1 warning generated.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159832
Approved by: https://github.com/dcci
2025-08-05 02:32:06 +00:00
b69497351d [nativert] force resize to zero. (#159683)
Summary:
this was quite a miserable bug. there are a few kernels that don't explicitly resize outputs to zero, which led to some weird UB.

Rollback Plan:

Differential Revision: D79476454

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159683
Approved by: https://github.com/SherlockNoMad, https://github.com/henryoier
2025-08-05 02:25:31 +00:00
482f069c41 [C10D] fix slow init due to repeated dns resolution failure (#159596)
It can be be very slow to repeatedly hit DNS resolution failure, but
its very helpful to have DNS names in logs by default. So we try to use DNS
but if we hit a transient failure we just disable it for the remainder of the
job, logging IP addresses instead.

Fixes #159007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159596
Approved by: https://github.com/d4l3k
2025-08-05 02:15:26 +00:00
85d931f29e Use uppercase OR when checking for system XNNPACK (#159527)
This PR fixes `cmake/Dependencies.cmake` to work when compiling with `USE_SYSTEM_XNNPACK=ON` by changing a lowercase `or` to an uppercase `OR`.

---

For a personal project, I was building pytorch with a customized build of XNNPACK. When trying to do so I encountered the following error:

```
CMake Error at cmake/Dependencies.cmake:566 (if):
  if given arguments:

    "NOT" "XNNPACK_LIBRARY" "or" "NOT" "microkernels-prod_LIBRARY"

  Unknown arguments specified
Call Stack (most recent call first):
  CMakeLists.txt:868 (include)
```

Upon making the change in this PR (changing `or` to `OR`), the process continued as expected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159527
Approved by: https://github.com/janeyx99
2025-08-05 02:10:53 +00:00
8a2f53c523 Recursively sync fbgemm submodules before build (#159477)
ROCm inductor benchmark builds failing fbgemm build stage https://ossci-raw-job-status.s3.amazonaws.com/log/46800456622
```
2025-07-27T08:00:32.3443858Z /var/lib/jenkins/pytorch/fbgemm/src/RowWiseSparseAdagradFused.cc:389:18: error: no matching function for call to ‘asmjit::v1_17::x86::Vec::Vec(uint32_t)’
2025-07-27T08:00:32.3444080Z   389 |         x86::Xmm partial_sum_xmm(partial_sum_vreg.id());
```

It looks like asmjit fails to build, this seems to be due to submodules of fbgemm not being updated after checking out to new commit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159477
Approved by: https://github.com/pruthvistony, https://github.com/eqy
2025-08-05 02:00:54 +00:00
b59b61a099 Add avg_pool3d backward pass for MPS (#159089)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159089
Approved by: https://github.com/malfet
2025-08-05 01:55:38 +00:00
57ab39f7e4 Update torch-xpu-ops commit pin (#159621)
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@1f7a57](1f7a57f507) includes:

- Add Template Parameter to the function `gpu_kernel` for Controlling Broadcasting Vectorization
- Add optional NaN checks to XCCL
- Fix NllLossForwardReduce2DKernelFunctor accuracy
- Extend the existing communication logging to include the reduction operation for collective calls
- [Reland] Install xpu codegen header to torch/include
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159621
Approved by: https://github.com/EikanWang
2025-08-05 01:46:15 +00:00
182975e01a [Dynamo] Enable torch function dispatch on HOPs (#159708)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159708
Approved by: https://github.com/zou3519, https://github.com/XilunWu
ghstack dependencies: #159707
2025-08-05 01:43:22 +00:00
9f8cfe7476 [Dynamo] Fix arg ordering in tf modes (#159707)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159707
Approved by: https://github.com/zou3519
2025-08-05 01:43:21 +00:00
e273ff028a Fix failing test (#159800)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159800
Approved by: https://github.com/aorenste
2025-08-05 00:28:51 +00:00
5e0fc2c9a9 [AOTI] don't allow int32 indices if {non-inf, > int32_max} upper bound is provided (#159433)
**Motivation / Context**: (what I _think_ is happening here)

In "eager"/just-in-time PT2 usage, dynamo/inductor will guard on whether indices fit in int32 or not. So it's generally safe in Inductor code to rely on the example values for symbolic ints in order to determine whether indices fit in int32, because the indices will be guarded on anyway; and if the inputs ever increase to `>int32_max`, dynamo will cause a recompilation.

But with AOTI, those int32 guards aren't respected; so if the example input is `< int32_max` but can be `> int32_max` during future execution, then the future execution might fail / IMA.

**Solution space**

Export allows users to specify which dimension are dynamic, and to provide **ranges of valid sizes**.

One solution idea is to always respect the upper bound of the dynamic shape range when doing AOTI; if the index's range includes values `>int32_max`, then don't use the hint and assume that this index doesn't fit in int32.

However, the problem with this is that many users may specify dynamism without specifying a range of values - the upper bound of the range will be set to the default of `inf`. Such use cases could potentially experience a perf regression if we implemented the idea above.

To prevent any such regressions, this implementation will rely solely on the specified range only if the upper bound of the range isn't inf. In other words, we'll ignore the hints/example values for AOTI (and rely only on the specified range) only if the upper bound of the range isn't inf - if users explicitly specify a range that extends past int32, we can be fairly sure that they actually do need values `>int32_max`.

If we continue to see correctness issues even with this implementation, we could consider more aggressively relying on the ranges.

Differential Revision: [D79220301](https://our.internmc.facebook.com/intern/diff/D79220301)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159433
Approved by: https://github.com/jingsh, https://github.com/ColinPeppler
2025-08-05 00:17:09 +00:00
bc4b04e058 DeviceCopy should have the same layout as input (#159615)
Summary: Fix https://github.com/pytorch/pytorch/issues/159612

- Fix the meta implementation of `nan_to_num`, it should preserve the stride of the input
- The DeviceCopy IR node should always preserve the input's layout, so we don't end up with a contiguous call during device copy

Test Plan:
```
buck2 run @mode/dev-nosan fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_d2h_copy
```

Rollback Plan:

Differential Revision: D79411407

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159615
Approved by: https://github.com/eellison
2025-08-04 23:56:58 +00:00
6b414f56a4 Revert "[inductor] add lowering for repeat_interleave.Tensor with output size specified (#147160) (#158462)" (#159798)
This reverts commit 305a03727672de42870f956ddf4ad9fa424443e1.

Reason: causes device-side assertion failures when running with this repro (a minimized version of a failure seen in a real model)

```
import torch
def ri(inp, repeats, output_size):
    return torch.repeat_interleave(inp, repeats, output_size=output_size)
inp = torch.arange(0, 4, device="cuda").reshape(-1, 1)
x = torch.tensor([1, 2, 3, 4], device="cuda")
ri_c = torch.compile(ri)
print(ri(inp, x, 10))
print(ri_c(inp, x, 10))
```

which leads to errors like

```
/tmp/torchinductor_dberard/3h/c3hlb22fpptebupstsuhl6kexa6z3upgbnyxln7c24gfcr5747iu.py:30: unknown: block: [0,0,0], thread: [10,0,0] Assertion `index out of bounds: 0 <= tmp5 < 4` failed.
```

Differential Revision: [D79591561](https://our.internmc.facebook.com/intern/diff/D79591561)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159798
Approved by: https://github.com/danzimm
2025-08-04 23:39:20 +00:00
fb8f32ef52 Revert "[mps] Turn on inductor dynamic shapes tests (#159456)"
This reverts commit 19f1f9960db7f29f2110a7f49f06a1a23c651ecf.

Reverted https://github.com/pytorch/pytorch/pull/159456 on behalf of https://github.com/davidberard98 due to Sorry - this causes a merge conflict with https://github.com/pytorch/pytorch/pull/159798, which I'm trying to land with co-dev to resolve a sev ([comment](https://github.com/pytorch/pytorch/pull/159456#issuecomment-3152751821))
2025-08-04 23:11:05 +00:00
7ba996bbaa [Cutlass] Fix wrapper code generation breakage (#159760)
Fixes issues introduced by https://github.com/pytorch/pytorch/pull/159355

The issue got past OSS CI because the H100 tag wasn't added, not sure how to prevent these kinds of issues in the future, perhaps we should run H100 on Inductor PRs?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159760
Approved by: https://github.com/angelayi
2025-08-04 23:03:03 +00:00
ddbdcdc710 [cutlass backend][test] Expand FP8 tests to FP16 (#159538)
Differential Revision: [D79317343](https://our.internmc.facebook.com/intern/diff/D79317343/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159538
Approved by: https://github.com/mlazos
2025-08-04 23:01:55 +00:00
19f1f9960d [mps] Turn on inductor dynamic shapes tests (#159456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159456
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-08-04 22:44:31 +00:00
fd6655a0f5 Feature: Implement support for cudnn_batch_norm_out kernel to replace the autogen approach. (#123020)
Fixes #115611

Autogen kernel may cause redundant copy, so we develop the kernel to improve efficiency.

Test Case:

```c++
#include <torch/torch.h>
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

int main() {
    auto input = torch::rand({2, 3, 4, 4}, torch::device(torch::kCUDA));
    auto weight = torch::randn({3}, torch::device(torch::kCUDA));
    auto bias = torch::randn({3}, torch::device(torch::kCUDA));
    auto running_mean = torch::zeros({3}, torch::device(torch::kCUDA));
    auto running_var = torch::ones({3}, torch::device(torch::kCUDA));

    bool training = true;
    double exponential_average_factor = 0.1;
    double epsilon = 1e-5;

    auto output = torch::empty_like(input);
    auto save_mean = torch::empty({3}, torch::device(torch::kCUDA));
    auto save_var = torch::empty({3}, torch::device(torch::kCUDA));
    auto reserve = torch::empty({0}, torch::device(torch::kCUDA)); // empty place-holder

    at::native::cudnn_batch_norm_out(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, output, save_mean, save_var, reserve);
    auto outputs = at::native::cudnn_batch_norm(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon);

    bool is_close_output = torch::allclose(output, std::get<0>(outputs));
    bool is_close_save_mean = torch::allclose(save_mean, std::get<1>(outputs));
    bool is_close_save_var = torch::allclose(save_var, std::get<2>(outputs));
    bool is_close_reserve = torch::allclose(reserve, std::get<3>(outputs));

    std::cout << "Is output close: " << is_close_output << std::endl;
    std::cout << "Is save_mean close: " << is_close_save_mean << std::endl;
    std::cout << "Is save_var close: " << is_close_save_var << std::endl;
    std::cout << "Is reserve close: " << is_close_reserve << std::endl;

    return 0;
}
```

Please CC @albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123020
Approved by: https://github.com/andrewor14, https://github.com/eqy, https://github.com/albanD
2025-08-04 22:40:33 +00:00
a7f3bdf550 [Dynamo][Better Engineering] Type coverage for torch/_dynamo/utils.py (#159580)
As part of better engineering effort, we would like to improve out type support to improve dev experience in dynamo

This PR adds strict typing support to `torch/_dynamo/utils.py`

Running
```
mypy torch/_dynamo/utils.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  2163 | 4792 | 45.14% | 121 | 268 | 45.15% |
| This PR | 4818 | 4818 | 100.00% | 268 | 268 | 100.00% |
| Delta    | +2655 | +26 | +54.84% | +147 | 0 | +54.85% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159580
Approved by: https://github.com/williamwen42
2025-08-04 21:51:53 +00:00
510e8b4ae0 [inductor] use writable temp file on windows (#159738)
Use `WritableTempFile` on Windows, reference to: https://github.com/pytorch/pytorch/pull/159342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159738
Approved by: https://github.com/angelayi, https://github.com/Skylion007
2025-08-04 21:51:02 +00:00
83ba3f1101 Revert "[inductor] allocate non-blocking copy destinations in pinned memory (#155121) (#158758)"
This reverts commit 6085bf7565fec0d2ed26e8590001f09c05adbbe4.

Reverted https://github.com/pytorch/pytorch/pull/158758 on behalf of https://github.com/davidberard98 due to I need to revert #158462 (it causes device-side asserts), and this PR causes a merge conflict in the test file. Sorry about that! ([comment](https://github.com/pytorch/pytorch/pull/158758#issuecomment-3152490371))
2025-08-04 21:47:11 +00:00
1fad16aacb Revert "[inductor] move all cpu scalars using pinned memory for graph partition (#155360) (#158983)"
This reverts commit 444e2381d07a14cb501c00d11f9e63a3f1d2c86e.

Reverted https://github.com/pytorch/pytorch/pull/158983 on behalf of https://github.com/davidberard98 due to I need to revert #158462 (it causes device-side asserts), and this PR causes a merge conflict in the test file. Sorry about that! ([comment](https://github.com/pytorch/pytorch/pull/158758#issuecomment-3152490371))
2025-08-04 21:47:11 +00:00
444e2381d0 [inductor] move all cpu scalars using pinned memory for graph partition (#155360) (#158983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158983
Approved by: https://github.com/eellison
ghstack dependencies: #158758
2025-08-04 21:42:05 +00:00
6085bf7565 [inductor] allocate non-blocking copy destinations in pinned memory (#155121) (#158758)
Fixes #155121

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158758
Approved by: https://github.com/EikanWang, https://github.com/eellison
2025-08-04 21:22:11 +00:00
8201dbf4bc check driver to be >=12.4 to use fabric handles (#159697)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159697
Approved by: https://github.com/malfet
2025-08-04 21:05:39 +00:00
26d045bb60 Linux py 3.14 wheel builds (#157559)
Related to https://github.com/pytorch/pytorch/issues/156856

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157559
Approved by: https://github.com/malfet, https://github.com/albanD
2025-08-04 20:55:19 +00:00
356ac3103a Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"
This reverts commit 310f901a71e53688866b14bb2f2b4c8eef9979b3.

Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3152337518))
2025-08-04 20:37:39 +00:00
d4109a0f99 [MPS] Add max_unpool1d/2d/3d (#159789)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159789
Approved by: https://github.com/malfet
2025-08-04 20:00:59 +00:00
7ea789ccfb Revert #156868: Bring back symint check for sharding propagation cache (#159671)
Fixes #159601

Unfortunately #156868 introduced a couple regressions (see #159590 and #159601). This reverts the commit while I am working on a permanent fix. This means the `in_compiled_autograd_initial_trace` global flag will be removed and the `_are_we_tracing()` will instead be replaced with the symint preprocessing step during sharding prop post init.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159671
Approved by: https://github.com/xmfan
2025-08-04 19:58:48 +00:00
7e8197e34d Revert "Migrate ScalarType to headeronly (#159416)"
This reverts commit 1371a98b0e727f8a8916dd473b6dd0cff78c0449.

Reverted https://github.com/pytorch/pytorch/pull/159416 on behalf of https://github.com/izaitsevfb due to breaking internal builds, see D79452481 ([comment](https://github.com/pytorch/pytorch/pull/159416#issuecomment-3152138508))
2025-08-04 19:55:09 +00:00
50eac811a6 [typing] Constrain OrderedSet generic to be Hashable (#159684)
Ran across this typing bug while creating an OrderedSet from a type I didn't realize wasn't hashable, which failed at runtime. With this constraint, typing would've failed pre-runtime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159684
Approved by: https://github.com/Skylion007
2025-08-04 18:08:01 +00:00
4e0f179d0b Update the signature and test of torch.hamming_window() (#152682)
Fixes #146590

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152682
Approved by: https://github.com/albanD
2025-08-04 17:50:42 +00:00
36e59d9b12 [c10d][nvshmem] fix missing override compilation error for nvshmem symmetric code (#159557)
Summary:
Fix error when compiling nvshmem code section `NVSHMEMSymmetricMemory.cu` with BUCK

```
fbcode/caffe2/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu:154:20: error: 'get_buffer' overrides a member function but is not marked 'override' [-Werror,-Winconsistent-missing-override]
  154 | virtual at::Tensor get_buffer(int
      |                    ^
fbcode/caffe2/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp:56:20: note: overridden virtual function is here
   56 | virtual at::Tensor get_buffer(int rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storage_offset) = 0;
```

Test Plan:
Build test + CI

Rollback Plan:

Differential Revision: D78813586

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159557
Approved by: https://github.com/kwen2501
2025-08-04 17:46:30 +00:00
fc340d0ca3 [export] Allow comparing device w/o index with device w/ index (#159665)
In the case where we have expected device "cuda" and given device "cuda:0" I think we should succeed?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159665
Approved by: https://github.com/yushangdi
2025-08-04 17:00:07 +00:00
53e47af0f7 [dynamo][guards] Read the attr name from GetAttrGuardAccessor (#159754)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159754
Approved by: https://github.com/jansel
ghstack dependencies: #159752
2025-08-04 16:51:27 +00:00
66ad881fc7 [dynamo][guards][refactor] Simplify type extraction from GuardManager (#159752)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159752
Approved by: https://github.com/jansel
2025-08-04 16:51:27 +00:00
1d3eef27ac [ROCm CI] Migrate to MI325 Capacity (#159649)
Migrate mi300s to gfx942.

Related to https://github.com/pytorch/pytorch/pull/159059

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159649
Approved by: https://github.com/huydhn
2025-08-04 16:48:12 +00:00
dd95900cec [AOTI] normalize_path_separator file path for Windows. (#159726)
`normalize_path_separator` file path for Windows.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159726
Approved by: https://github.com/angelayi, https://github.com/jansel
2025-08-04 15:57:19 +00:00
1cdd665526 fix test_verbose_logs_dynamic_shapes with MSVC (#159573)
Operator `typeid` have different outputs in different compiler. There is a good example in [cppreference](https://www.en.cppreference.com/w/cpp/language/typeid.html).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159573
Approved by: https://github.com/angelayi, https://github.com/jansel
2025-08-04 15:56:53 +00:00
7cb2dcd2dd [c10d][nvshmem] modify is_nvshmem_available runtime check to work with static-linked library (#159558) (#159561)
Summary:

Currently this function rely on the logic that we load `libnvshmem_device.a` statically and load `libnvshmem_host.so` at runtime. For loading `libnvshmem.a` (the combine 2 thing together) statically this will fail. Add a section to check if the symbol from host API exist at runtime to check if nvshmem is loaded statically

Test Plan:
CI + sample run

Rollback Plan:

Differential Revision: D79177525

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159561
Approved by: https://github.com/kwen2501
2025-08-04 15:40:29 +00:00
e5a81aa7ba Fix conversion of values in libtorch agnostic tests (#155115)
Due to different byteorder,
when copying data, it has to be put into last bytes to ensure that int32_t converted to int64_t keeps same value. Same has to be done when it's converted back.

This change fixes test
TestLibtorchAgnosticCPU::test_my_ones_like_cpu
from
cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py on s390x.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155115
Approved by: https://github.com/huydhn
2025-08-04 13:40:22 +00:00
3e2aa4b0e3 Update pin to include Python 3.14 support (#159725)
Update Triton Pin to top of rel/3.4 branch : https://github.com/triton-lang/triton/tree/rel/3.4 . This is the same as release/3.4.x branch but also includes Python 3.14 support

This should unblock enablement of Python 3.14 support in this PR: https://github.com/pytorch/pytorch/pull/157559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159725
Approved by: https://github.com/davidberard98
2025-08-04 13:30:12 +00:00
6646461764 S390X: fix detection of magic number placeholder in inductor (#157784)
This change fixes multiple tests in
test/inductor/test_aot_inductor_arrayref.py
such as
test_cond_with_parameters_cpu_with_stack_allocation,
test_issue_140766_cpu_with_stack_allocation,
test_model_modified_weights_cpu_with_stack_allocation,
test_nested_tensor_from_jagged_cpu_with_stack_allocation.

Enable tests in test/inductor/test_aot_inductor_arrayref.py

This change is split off from https://github.com/pytorch/pytorch/pull/150116

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157784
Approved by: https://github.com/huydhn
2025-08-04 12:42:31 +00:00
f74da2a136 [xla hash update] update the pinned xla hash (#159758)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned xla hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159758
Approved by: https://github.com/pytorchbot
2025-08-04 11:21:45 +00:00
eqy
d35b27dde5 [CUDA] Add some more missing @serialTest decorators (#159672)
Seems to fix #159663

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159672
Approved by: https://github.com/Skylion007
2025-08-04 07:44:35 +00:00
a9dc1566d4 [MTIA Aten Backend] Migrate arange.start_out (#159540)
Differential Revision: [D79317519](https://our.internmc.facebook.com/intern/diff/D79317519/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159540
Approved by: https://github.com/malfet, https://github.com/nautsimon
2025-08-04 07:38:05 +00:00
33a1996714 Fix perf downgrad by reverting template use in use_mkldnn_matmul (#159024)
This PR is to fix the performance downgrad by reverting template use in `use_mkldnn_matmul` in #157520 . Fix https://github.com/pytorch/pytorch/issues/159031 and https://github.com/pytorch/pytorch/issues/159551.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159024
Approved by: https://github.com/mingfeima
2025-08-04 05:49:46 +00:00
ee62177c19 [dynamo] Be consistent with storing func source for UserMethodVariable (#159696)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159696
Approved by: https://github.com/jansel
ghstack dependencies: #159534
2025-08-04 05:12:44 +00:00
64cbaa876c [dynamo][guards] Make class members go through obj.__class__.__dict__ (#159534)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159534
Approved by: https://github.com/jansel
2025-08-04 05:12:44 +00:00
4516c59f5f [dynamo][source] Add special source for __code__ and __closure__ (#159722)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159722
Approved by: https://github.com/jansel
2025-08-04 05:02:05 +00:00
8bc843a9ec [vllm hash update] update the pinned vllm hash (#159610)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vllm hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159610
Approved by: https://github.com/pytorchbot
2025-08-04 04:06:09 +00:00
e39a62c70d Fix warnings in triton_helpers.py (#159719)
```
  /home/jansel/pytorch/torch/_inductor/runtime/triton_helpers.py:152: UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead
    equal |= a_isnan and b_isnan
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159719
Approved by: https://github.com/Skylion007
2025-08-04 03:21:09 +00:00
978e3a9142 refresh expected results (#159727)
Just regular update due to recent <10% changes CI is stable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159727
Approved by: https://github.com/anijain2305
2025-08-03 22:47:50 +00:00
e2a5c42e7e [BE][MPS] Build metal kernels of MacOS-14+ (#159733)
Which makes `#if __METAL_VERSION__ >= 310` guards for `bfloat` use support unnecessary.
Rename `kernels_bfloat.metallib` into `kernels_basic` and remove custom build/selection logic.

Part of https://github.com/pytorch/pytorch/issues/159275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159733
Approved by: https://github.com/dcci
ghstack dependencies: #159731, #159732
2025-08-03 20:53:58 +00:00
5116c49b52 [BE] Remove macos-13 guard from bench_mps_ops (#159732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159732
Approved by: https://github.com/dcci
ghstack dependencies: #159731
2025-08-03 20:53:58 +00:00
fecdebe385 [CI][MPS] Fix compile benchmark correctness (#159731)
By passing `fullgraph=True` attribute and increasing cache size limit to 2**16

Otherwise, compiler might decide not to fall back to eager to avoid recompilations
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159731
Approved by: https://github.com/dcci
2025-08-03 20:53:50 +00:00
e136a9175b [BE] Fix dev warning in Dependencies.cmake (#159702)
Namely
```
CMake Warning (dev) in cmake/Dependencies.cmake:
  A logical block opening on the line

    /Users/nshulga/git/pytorch/pytorch/cmake/Dependencies.cmake:261 (if)

  closes on the line

    /Users/nshulga/git/pytorch/pytorch/cmake/Dependencies.cmake:263 (endif)

  with mis-matching arguments.
```

Introduced by https://github.com/pytorch/pytorch/pull/143846

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159702
Approved by: https://github.com/cyyever, https://github.com/Skylion007
2025-08-03 18:45:07 +00:00
9a680e14b7 [bucketing] Reduce CPU overhead for reduce_scatter_merge_fn_to_trace (#159723)
The previous implementation was creating `n_gpu * n_tensors` intermediate tensors, which was adding a lot of CPU overhead, specially given that inductor was generating a number of individual tensor copy kernels for `torch.cat` .

This PR changes the implementation so that only `n_tensors` are created, making the CPU overhead proportional to the number of tensors being bucketed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159723
Approved by: https://github.com/IvanKobzarev
2025-08-03 09:16:55 +00:00
805a102beb Revert "[dynamo][guards] Make class members go through obj.__class__.__dict__ (#159534)"
This reverts commit 1616777cd2a3170ff76afa3e7860b0969420c445.

Reverted https://github.com/pytorch/pytorch/pull/159534 on behalf of https://github.com/malfet due to Broke some inductor test and lint among other things, see 9c18901bfd/1 ([comment](https://github.com/pytorch/pytorch/pull/159534#issuecomment-3146983186))
2025-08-03 04:58:32 +00:00
6e8d705a22 Revert "[dynamo] Be consistent with storing func source for UserMethodVariable (#159696)"
This reverts commit be71000ff5292293d1976f313218e2df4d5046d3.

Reverted https://github.com/pytorch/pytorch/pull/159696 on behalf of https://github.com/malfet due to Broke some inductor test and lint among other things, see 9c18901bfd/1 ([comment](https://github.com/pytorch/pytorch/pull/159534#issuecomment-3146983186))
2025-08-03 04:58:32 +00:00
9c18901bfd [MTIA Aten Backend] Migrate all.out (#159539)
Differential Revision: [D79317033](https://our.internmc.facebook.com/intern/diff/D79317033/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159539
Approved by: https://github.com/malfet
ghstack dependencies: #159098
2025-08-03 02:08:35 +00:00
a29ed5e1ac Add torch compile force disable caches alias (#158072)
Bunch of people keep thinking current alias only disables inductor cache because it has the name inductor in it. lets globalize the name

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158072
Approved by: https://github.com/ezyang
2025-08-02 23:23:17 +00:00
d2792f51b2 [bucketing] Use max of input/output size for bucketing (#159717)
The output of a reduce_scatter is n_gpu times smaller than its input, while the output of an all_gather is n_gpu times larger than its input. This means that in the current heuristic for bucketing reduce_scatter, we would need to use a bucket size which is n_gpu times larger than the bucket for all_gather, making it gpu-dependent and less intuitive. This PRs propose to use instead the max between the input and output sizes, so that one can use the same bucket_size value for both passes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159717
Approved by: https://github.com/wconstab
2025-08-02 22:42:22 +00:00
be71000ff5 [dynamo] Be consistent with storing func source for UserMethodVariable (#159696)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159696
Approved by: https://github.com/jansel
ghstack dependencies: #159186, #159534
2025-08-02 21:40:38 +00:00
3f86076775 gc before warming up benchmarking (#159670)
#158649 turned off automatic GCs during cudagraph recording. This is causing a small uptick in some internal benchmark numbers because of memory the benchmark is leaving around before the benchmark starts - so GC before warming up the model.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159670
Approved by: https://github.com/oulgen
2025-08-02 19:37:24 +00:00
1616777cd2 [dynamo][guards] Make class members go through obj.__class__.__dict__ (#159534)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159534
Approved by: https://github.com/jansel
ghstack dependencies: #159186
2025-08-02 18:04:35 +00:00
38895c0ac2 Update RuntimeError message in is_nonzero(input) method from bool to Boolean (#159712)
RuntimeError message updated in is_nonzero(input) method from bool to Boolean.

**Case 1:**
t = torch.tensor([])
torch.is_nonzero(t)

**Case 2:**
t = torch.tensor([1,2])
torch.is_nonzero(t)

**Existing Error message in documentation:**

for case 1: RuntimeError: bool value of Tensor with no values is ambiguous
for case 2: RuntimeError: bool value of Tensor with more than one value is ambiguous

**Proposed Error message in documentation:**

for case 1: RuntimeError: Boolean value of Tensor with no values is ambiguous
for case 2: RuntimeError: Boolean value of Tensor with more than one value is ambiguous

Fixes #159710
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159712
Approved by: https://github.com/malfet
2025-08-02 17:23:45 +00:00
310f901a71 Stop parsing command line arguments every time common_utils is imported. (#156703)
Last PR in the series to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs:

https://github.com/pytorch/pytorch/pull/154612
https://github.com/pytorch/pytorch/pull/154628
https://github.com/pytorch/pytorch/pull/154715
https://github.com/pytorch/pytorch/pull/154716
https://github.com/pytorch/pytorch/pull/154725
https://github.com/pytorch/pytorch/pull/154728

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156703
Approved by: https://github.com/clee2000
2025-08-02 16:38:54 +00:00
e11b1cd97e [ROCm] fix nightly wheel due to rocBLAS environment variable (#159570)
Fixes #159070

The TunableOp failure is due to missing rocBLAS files in our manywheels packaging. This bug has been present since June 7-8 time frame. It was caused by a typo in the rocBLAS environment variable that stores the list of files. It was introduced in this PR: https://github.com/pytorch/pytorch/pull/155388

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159570
Approved by: https://github.com/malfet
2025-08-02 06:54:43 +00:00
b599d91738 Log autotune choices and benchmark result to scuba/chrome trace (#159496)
Summary:
Report the kernel choices and benchmark data to better understand how kernels are selected and the performance gap between the best kernel (likely a CUDA kernel) and Triton kernels.

**Example**

Event: mm_template_autotuning
Column: autotune_choices

```json
{
  "num_choices": 52,
  "num_triton_choices": 19,
  "best_kernel": "cutlass_f6c25cf2",
  "best_kernel_desc": "cutlass3x_sm90_tensorop_gemm_f16_f16_f32_void_f16_128x256x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=8",
  "best_time": 0.6283040046691895,
  "best_triton_pos": 26,
  "best_triton_time": 0.6832960247993469,
  "best_triton_kernel": "triton_mm_17",
  "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0"
}
```

Test Plan:
```
TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS =1 buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt
```

Rollback Plan:

Differential Revision: D79235037

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159496
Approved by: https://github.com/masnesral
2025-08-02 05:34:17 +00:00
fd6a6658c3 Enable _int_mm on Intel GPU (#157769)
# Moativation

This PR is used to enable _int_mm on Intel GPU. And _int_mm is used by int8 quantization on torchao.

# Model Test Result:
We run meta-llama/Llama-3.1-8B-Instruct on Intel GPU and A100 using torchao int8-dynamic-quantization. The model configs as below:
Precision : torch.bfloat16
quantization configuration : Int8DynamicActivationInt8WeightConfig
dataset : wikitext

Result:
The perplexity values for Intel GPU and A100 are 9.582953453063965 and 9.57755184173584, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157769
Approved by: https://github.com/EikanWang, https://github.com/desertfire
2025-08-02 05:16:01 +00:00
04973496a8 [audio hash update] update the pinned audio hash (#159611)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159611
Approved by: https://github.com/pytorchbot
2025-08-02 05:15:47 +00:00
1548b011ea Fix rand_like decomposition to preserve strides (#159294)
Summary: Like https://github.com/pytorch/pytorch/pull/158898, the rand_like variants are not preserving strides. Followed the pattern established in https://github.com/pytorch/pytorch/pull/158898.

Test Plan: New unit test (fails before this PR; but fixed after)

Differential Revision: [D79472604](https://our.internmc.facebook.com/intern/diff/D79472604)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159294
Approved by: https://github.com/eellison
2025-08-02 03:54:41 +00:00
e57a92734d [export] Fix nn_module_stack of assert_tensor_metadata nodes (#159625)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159625
Approved by: https://github.com/yushangdi
2025-08-02 02:52:42 +00:00
79ff3b320b Back out "[ez] get rid of unused var" (#159677)
Summary: turns out i added this to reduce the frequency we'd call try_update_max_size_at_index when a new maximum is found before the replan is called. oops.

Test Plan:
backout

Rollback Plan:

Differential Revision: D79474114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159677
Approved by: https://github.com/georgiaphillips
2025-08-02 01:50:16 +00:00
426f249f20 Fix launch grid calculation (#159497)
Summary:

The launch grid calculation code is using a python trick to achieve CeilDiv() through negative integer division with FloorDiv(). This is language dependent behaviour that doesn't apply to all languages.

In the FXIR backend we negate this behaviour and replace the experssion with CeilDiv() operation so the computation is correct regardless of language used. Not directly directly changing the orginal computation as it leads to a performance degredation.

Test Plan:
CI

Rollback Plan:

Differential Revision: D79275534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159497
Approved by: https://github.com/blaine-rister
2025-08-02 01:12:58 +00:00
d33a484763 Use boxed_nop_preserve_node_meta for aot_export_joint_with_descriptors (#159545)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159545
Approved by: https://github.com/xmfan, https://github.com/wconstab
ghstack dependencies: #159336, #159337
2025-08-02 00:33:41 +00:00
a81ffbc5f5 improve shape checks for grouped_mm (#159666)
Check that contraction dimension matches between tensors if it's known, and do device-side checks for correct offsets
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159666
Approved by: https://github.com/danielvegamyhre, https://github.com/eqy
2025-08-02 00:12:25 +00:00
465fe4d9f7 Enable sample nightly PT2 benchmark on B200 (#158011)
Per the discussion with @nWEIdia, this resumes the work on https://github.com/pytorch/pytorch/pull/157870 to enable PT2 benchmark on B200

### Testing

https://github.com/pytorch/pytorch/actions/runs/16615101382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158011
Approved by: https://github.com/nWEIdia, https://github.com/atalman
2025-08-01 23:47:44 +00:00
9477af1063 fix compilation on cuda < 12.3 (#159657)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159657
Approved by: https://github.com/kwen2501
2025-08-01 23:40:55 +00:00
dcc36e38bb [Graph Breaks] Remove unsupported Additional Info field (#159658)
Race condition when landing PR#158800 caused us to add this field when it is deprecated, so remove it

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159658
Approved by: https://github.com/williamwen42
2025-08-01 23:25:50 +00:00
efd78584a8 [EZ] Add linux-aarch64.yml workflow to the viable/strict blocking set (#159668)
Since it's required to be run on every PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159668
Approved by: https://github.com/malfet
2025-08-01 23:19:08 +00:00
135762ea20 Unpin helion (#159579)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159579
Approved by: https://github.com/jansel
2025-08-01 23:08:06 +00:00
e2ee9cfaa2 [NativeRT] Turn on enableStaticCPUKernels by default (#159422)
Summary: As title.

Test Plan:
Need to manual test on production models.

Rollback Plan:

Differential Revision: D78747742

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159422
Approved by: https://github.com/dolpm
2025-08-01 22:27:07 +00:00
06d28de17a Update CK Kernel generation and update ck submodule (#157964)
changes required to reduce the number of ck kernels generated. This change depends on https://github.com/ROCm/composable_kernel/pull/2480 to be merged first.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157964
Approved by: https://github.com/842974287
2025-08-01 22:24:27 +00:00
df9720b8b5 [MTIA Aten Backend] Migrate all foreach ops (#159098)
# Context

See the first PR https://github.com/pytorch/pytorch/pull/153670

# This diff

 Migrate all foreach operators to in-tree, including:
  - _foreach_abs
  - _foreach_abs_
  - _foreach_add.List
  - _foreach_add_.List
  - _foreach_add_.Scalar
  - _foreach_add_.Tensor
  - _foreach_addcmul.Scalar
  - _foreach_addcmul_.Scalar
  - _foreach_copy
  - _foreach_copy_
  - _foreach_mul.List
  - _foreach_mul_.List
  - _foreach_mul_.Scalar
  - _foreach_mul.Tensor
  - _foreach_mul_.Tensor
  - _foreach_norm.Scalar
  - _foreach_sqrt_

Differential Revision: [D78913847](https://our.internmc.facebook.com/intern/diff/D78913847/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159098
Approved by: https://github.com/malfet
2025-08-01 22:10:12 +00:00
85e74d5ace [inductor] Add logging for distributed collective ops for multi‑rank diagnostics (#159190)
This change introduces structured logging of the collective communication schedule, enabling downstream tools (e.g. TLParse) to ingest and analyze per‑rank collective‐order information for multi‑rank jobs.

- Iterates over scheduler.nodes, filters for _CollectiveKernel nodes
- Extracts each op’s python_kernel_name
- Emits a structured JSON payload under the inductor_collective_schedule artifact name
- Dumps the full schedule list to collective_schedule.json via the PyTorch trace‑structured artifact
- Added comprehensive unit tests for collective schedule tracing: Created test_collective_schedule_empty() and test_collective_schedule_real() tests to verify structured trace logging works correctly for both empty collective schedules and real collective operations (like all_reduce and wait_tensor from _c10d_functional ops).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159190
Approved by: https://github.com/yushangdi, https://github.com/xmfan
2025-08-01 21:51:42 +00:00
0450f05658 Output tensor meta data for FX graph node (#159311)
FX graph segment in CompiledFxGraph does not include tensor meta data, for example, tensor shape, tensor stride, tensor data type, tensor device. AI system co-design team requested to include these information in FX graph segment so they can use FX graph segment to project the performance on different hardware.
This DIFF is to modify the Graph::Node::format_node to include tensor meta data.
Before this DIFF, the triton kernel FX graph segment looks like the following:
```
# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]
# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})
# %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})
# %cos : cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})
# return %cos
After this DIFF:
# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]
# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})
# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})
# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})
# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})
# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})
# return %cos
```
If format_node can not be changed, I can copy the code to caffe2/torch/_inductor/utils.py.

Differential Revision: D77973076

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159311
Approved by: https://github.com/angelayi
2025-08-01 21:40:29 +00:00
595a65f5c2 [dynamo] Replace unimplemented with unimplemented_v2 in torch/_dynamo/variables/script_object.py (#159343)
Fixes part of #147913

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159343
Approved by: https://github.com/williamwen42

Co-authored-by: William Wen <william.wen42@gmail.com>
2025-08-01 21:30:41 +00:00
8c6c2e40eb Edit a test case to detect potential bugs in all-gathering noncontiguous inputs in the Gloo backend (#159542)
As suggested in the pull request #158903 by @H-huang, this pull request edits a test case to detect potential bugs in all-gathering noncontiguous inputs in the Gloo backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159542
Approved by: https://github.com/d4l3k, https://github.com/H-Huang
2025-08-01 21:20:25 +00:00
32840d19f9 [cutlass backend] skip stream k if shape is dynamic (#159442)
Differential Revision: [D79229210](https://our.internmc.facebook.com/intern/diff/D79229210/)

Motivation is workspace size is hard to determine, and varies for different shape. What I observed is sometimes the shape got smaller, but the workspace can increase. So it is hard to upper bound it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159442
Approved by: https://github.com/ColinPeppler
2025-08-01 20:42:24 +00:00
2040f00112 [BE][Easy] respect os.environ in subprocess calls in tools/nightly.py (#159572)
Respect parent shell's envvars, such as `UV_INDEX_STRATEGY`, `http{,s}_proxy`, etc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159572
Approved by: https://github.com/Skylion007
2025-08-01 20:40:31 +00:00
c137f9da0b [Dynamo][Better Engineering] Add type coverage to dynamo/compiled_autograd.py (#159518)
As part of better engineering effort, we would like to improve out type support to improve dev experience in dynamo

This PR adds strict typing support to `torch/_dynamo/compiled_autograd.py`

Running
```
mypy torch/_dynamo/compiled_autograd.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  425 | 1553 | 27.37% | 17 | 62 | 27.42% |
| This PR | 1623 | 1623 | 100.00% | 62 | 62 | 100.00% |
| Delta    | +1198| +0 | +72.63% | +45 | 0 | +72.58% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159518
Approved by: https://github.com/xmfan
2025-08-01 20:24:58 +00:00
5e8b95605f [PP] Support OVERLAP_F_B computation type (#158978)
Some changes to validation code and visualizer to support a new computation type that will be used in DualPipeV (see https://github.com/pytorch/pytorch/pull/159591)

The IR looks like:

```
[0F0, 0F1, 0F2, 0F3, 0F4, 0F5, 0F6, 7F0, 7I0, 7W0, 7F1, 7I1, 7W1, 7F2, 7I2, 7W2, 7F3, (0F7;7B3)OVERLAP_F_B, (7F4;0B0)OVERLAP_F_B, (0F8;7B4)OVERLAP_F_B, (7F5;0B1)OVERLAP_F_B, (0F9;7B5)OVERLAP_F_B, (7F6;0B2)OVERLAP_F_B, 7B6, (7F7;0B3)OVERLAP_F_B, 7B7, (7F8;0B4)OVERLAP_F_B, 7B8, (7F9;0B5)OVERLAP_F_B, 7B9, 0I6, 0W6, 0I7, 0W7, 0I8, 0W8, 0I9, 0W9]
[1F0, 1F1, 1F2, 1F3, 1F4, 6F0, 1F5, 6F1, 6I0, 6W0, 6F2, 6I1, 6W1, 6F3, (1F6;6B2)OVERLAP_F_B, (6F4;1B0)OVERLAP_F_B, (1F7;6B3)OVERLAP_F_B, (6F5;1B1)OVERLAP_F_B, (1F8;6B4)OVERLAP_F_B, (6F6;1B2)OVERLAP_F_B, (1F9;6B5)OVERLAP_F_B, (6F7;1B3)OVERLAP_F_B, 6B6, (6F8;1B4)OVERLAP_F_B, 6B7, (6F9;1B5)OVERLAP_F_B, 6B8, 1B6, 6I9, 1I7, 6W9, 1I8, 1W7, 1I9, 1W8, 1W9]
[2F0, 2F1, 2F2, 5F0, 2F3, 5F1, 2F4, 5F2, 5I0, 5W0, 5F3, (2F5;5B1)OVERLAP_F_B, (5F4;2B0)OVERLAP_F_B, (2F6;5B2)OVERLAP_F_B, (5F5;2B1)OVERLAP_F_B, (2F7;5B3)OVERLAP_F_B, (5F6;2B2)OVERLAP_F_B, (2F8;5B4)OVERLAP_F_B, (5F7;2B3)OVERLAP_F_B, (2F9;5B5)OVERLAP_F_B, (5F8;2B4)OVERLAP_F_B, 5B6, (5F9;2B5)OVERLAP_F_B, 5B7, 2B6, 5B8, 2I7, 5I9, 2I8, 2W7, 2I9, 5W9, 2W8, 2W9]
[3F0, 4F0, 3F1, 4F1, 3F2, 4F2, 3F3, 4F3, 3F4, 4B0, (4F4;3B0)OVERLAP_F_B, (3F5;4B1)OVERLAP_F_B, (4F5;3B1)OVERLAP_F_B, (3F6;4B2)OVERLAP_F_B, (4F6;3B2)OVERLAP_F_B, (3F7;4B3)OVERLAP_F_B, (4F7;3B3)OVERLAP_F_B, (3F8;4B4)OVERLAP_F_B, (4F8;3B4)OVERLAP_F_B, (3F9;4B5)OVERLAP_F_B, (4F9;3B5)OVERLAP_F_B, 4B6, 3B6, 4B7, 3B7, 4I8, 3I8, 4I9, 3I9, 4W8, 3W8, 4W9, 3W9]
```

In this PR, the schedule execution will just treat the OVERLAP_F_B as two separate operations of F and B (so there is no actual overlap). The next step is to allow users to create a custom function to plug in what this operation does.

814629043a/torch/distributed/pipelining/schedules.py (L1205-L1216)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158978
Approved by: https://github.com/wconstab
2025-08-01 20:22:30 +00:00
8ea86a6e31 Actually test STD_TORCH_CHECK, add testfile to CMake (#159603)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159603
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-08-01 19:53:41 +00:00
acad808545 Revert "[inductor] consolidate common GEMM triton param retrieval (#159383)"
This reverts commit e7cc42df58a86bee05944f6e80c535aa1d099443.

Reverted https://github.com/pytorch/pytorch/pull/159383 on behalf of https://github.com/jataylo due to sorry but rocm CI is broken due to this PR ([comment](https://github.com/pytorch/pytorch/pull/159383#issuecomment-3145604831))
2025-08-01 19:49:21 +00:00
c687446374 Revert "Fix rand_like decomposition to preserve strides (#159294)"
This reverts commit 2c46922ce4b33c39b1c48c302604805510a3f889.

Reverted https://github.com/pytorch/pytorch/pull/159294 on behalf of https://github.com/yangw-dev due to breaking internal test ([comment](https://github.com/pytorch/pytorch/pull/159294#issuecomment-3145541845))
2025-08-01 19:19:51 +00:00
dd22ba09b4 [C10D] Document barrier interaction with device_id (#159389)
Addresses #159262

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159389
Approved by: https://github.com/malfet, https://github.com/H-Huang, https://github.com/kwen2501, https://github.com/fduwjj
2025-08-01 18:12:21 +00:00
c0e0126399 Remove unused input parameter in ExpandableSegment (#159356)
# Motivation
While refactoring the caching allocator, I noticed that the `ExpandableSegment` constructor on CUDA had an unused parameter. This change removes that unused argument to avoid potential confusion.

# Additional Context
I noticed that `ExpandableSegment` is defined in cpp file, so it should be safe to make this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159356
Approved by: https://github.com/ngimel, https://github.com/albanD
ghstack dependencies: #159159
2025-08-01 17:47:51 +00:00
e4b123b5e4 Revert direct updates (#159654)
reverts:
```

commit 5711a8f06948eeee56ed5f53f171fa519f78491c (tag: trunk/5711a8f06948eeee56ed5f53f171fa519f78491c, origin/main, main)
Author: Jovian Anthony Jaison <38627145+jovianjaison@users.noreply.github.com>
Date:   Fri Aug 1 09:32:52 2025 -0700

    Update test_utils.py

commit b4b71d011ed07a41c2086ff0dec2988a63662877 (tag: trunk/b4b71d011ed07a41c2086ff0dec2988a63662877)
Author: Jovian Anthony Jaison <38627145+jovianjaison@users.noreply.github.com>
Date:   Fri Aug 1 09:27:54 2025 -0700

    Update utils.py

commit 52376b9b6fbf9fe24f5d82038dc520f0c64b6f8d (tag: trunk/52376b9b6fbf9fe24f5d82038dc520f0c64b6f8d)
Author: Jovian Anthony Jaison <38627145+jovianjaison@users.noreply.github.com>
Date:   Fri Aug 1 09:26:05 2025 -0700
```

(commits pushed directly to main by mistake)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159654
Approved by: https://github.com/atalman
2025-08-01 16:54:51 +00:00
5711a8f069 Update test_utils.py 2025-08-01 09:32:52 -07:00
b4b71d011e Update utils.py 2025-08-01 09:27:54 -07:00
52376b9b6f Update convert_frame.py 2025-08-01 09:26:05 -07:00
1371a98b0e Migrate ScalarType to headeronly (#159416)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159416
Approved by: https://github.com/albanD
ghstack dependencies: #159415, #159411
2025-08-01 16:07:01 +00:00
2a286cbdf4 Allow register_buffer with Tensor-like object (#159455)
As torch allows extending the tensor with `__torch_function__`, it would be desirable to allow registering it as a buffer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159455
Approved by: https://github.com/mikaylagawarecki
2025-08-01 15:31:38 +00:00
7c37b8e1e0 [ROCm][Windows] Switch __builtin_clz ifdef from WIN32 to MSC_VER. (#159273)
PyTorch with ROCm on Windows is built with clang-cl and not MSVC. This code path is specific to the MSVC compiler so it should be checking for MSC_VER, not just WIN32. The change here is similar to https://github.com/pytorch/pytorch/pull/146606.

This fixes downstream build errors using clang-cl like https://github.com/ROCm/TheRock/actions/runs/16569646709/job/46858176812 (patched and tested downstream at https://github.com/ROCm/TheRock/pull/1140):
```
[7099/7147] Building CXX object functorch\CMakeFiles\functorch.dir\csrc\dim\dim.cpp.obj
FAILED: functorch/CMakeFiles/functorch.dir/csrc/dim/dim.cpp.obj
C:\home\runner\_work\_tool\Python\3.11.9\x64\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\clang-cl.exe  /nologo -TP -DEXPORT_AOTI_FUNCTIONS -DFUNCTORCH_BUILD_MAIN_LIB -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DNOMINMAX -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DROCM_ON_WINDOWS -DROCM_USE_FLOAT16 -DROCM_VERSION=70000 -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -DTORCH_HIP_VERSION=700 -DUSE_EXTERNAL_MZCRC -DUSE_MIMALLOC -DUSE_PROF_API=1 -DWIN32_LEAN_AND_MEAN -D_CRT_SECURE_NO_DEPRECATE=1 -D_UCRT_LEGACY_INFINITY -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_AMD__=1 -Dfunctorch_EXPORTS -IB:\src\torch\build\aten\src -IB:\src\torch\aten\src -IB:\src\torch\build -IB:\src\torch -IB:\src\torch\nlohmann -IB:\src\torch\moodycamel -IB:\src\torch\third_party\mimalloc\include -IB:\src\torch\functorch -IB:\src\torch\torch\csrc\api -IB:\src\torch\torch\csrc\api\include -IB:\src\torch\c10\.. -IB:\src\torch\c10\hip\..\.. -IB:\src\torch\torch\.. -IB:\src\torch\torch\..\aten\src -IB:\src\torch\torch\..\aten\src\TH -IB:\src\torch\build\caffe2\aten\src -IB:\src\torch\build\third_party -IB:\src\torch\build\third_party\onnx -IB:\src\torch\torch\..\third_party\valgrind-headers -IB:\src\torch\torch\..\third_party\gloo -IB:\src\torch\torch\..\third_party\onnx -IB:\src\torch\torch\..\third_party\flatbuffers\include -IB:\src\torch\torch\..\third_party\kineto\libkineto\include -IB:\src\torch\torch\..\third_party\cpp-httplib -IB:\src\torch\torch\..\third_party\nlohmann\include -IB:\src\torch\torch\csrc -IB:\src\torch\torch\lib -IB:\src\torch\torch\standalone -IB:\src\torch\torch\lib\libshm_windows -imsvcC:\home\runner\_work\_tool\Python\3.11.9\x64\Lib\site-packages\_rocm_sdk_devel\include -imsvcB:\src\torch\third_party\protobuf\src -imsvcB:\src\torch\third_party\XNNPACK\include -imsvcB:\src\torch\third_party\ittapi\include -imsvcB:\src\torch\cmake\..\third_party\eigen -imsvcB:\src\torch\third_party\ideep\mkl-dnn\include\oneapi\dnnl -imsvcB:\src\torch\third_party\ideep\include -imsvcB:\src\torch\INTERFACE -imsvcB:\src\torch\third_party\nlohmann\include -imsvcB:\src\torch\third_party\concurrentqueue -imsvcC:\home\runner\_work\_tool\Python\3.11.9\x64\Lib\site-packages\_rocm_sdk_devel\include\hiprand -imsvcC:\home\runner\_work\_tool\Python\3.11.9\x64\Lib\site-packages\_rocm_sdk_devel\include\rocrand -imsvcB:\src\torch\cmake\..\third_party\pybind11\include -imsvcC:\home\runner\_work\_tool\Python\3.11.9\x64\include /DWIN32 /D_WINDOWS /EHsc /Zc:__cplusplus /bigobj /FS /utf-8 -DUSE_PTHREADPOOL -DNDEBUG -DUSE_FBGEMM -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE /wd4624 /wd4068 /wd4067 /wd4267 /wd4661 /wd4717 /wd4244 /wd4804 /wd4273 /O2 /Ob2 /DNDEBUG /bigobj -DNDEBUG -std:c++17 -MD -Z7 -Wmissing-prototypes -Werror=missing-prototypes /permissive- /d2implyavx512upperregs- /EHsc /bigobj -fms-runtime-lib=dll -D__HIP_PLATFORM_AMD__=1 -DCUDA_HAS_FP16=1 -DUSE_ROCM -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DTORCH_HIP_VERSION=700 -Wno-shift-count-negative -Wno-shift-count-overflow -Wno-duplicate-decl-specifier -DCAFFE2_USE_MIOPEN -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP -std=c++17 -DHIPBLAS_V2 -DHIP_ENABLE_WARP_SYNC_BUILTINS -fms-extensions -Wno-ignored-attributes /showIncludes /Fofunctorch\CMakeFiles\functorch.dir\csrc\dim\dim.cpp.obj /Fdfunctorch\CMakeFiles\functorch.dir\ -c -- B:\src\torch\functorch\csrc\dim\dim.cpp
clang-cl: warning: unknown argument ignored in clang-cl: '-std=c++17' [-Wunknown-argument]
clang-cl: warning: argument unused during compilation: '/d2implyavx512upperregs-' [-Wunused-command-line-argument]
In file included from B:\src\torch\functorch\csrc\dim\dim.cpp:36:
B:\src\torch\functorch\csrc\dim\arena.h(14,21): error: functions that differ only in their return type cannot be overloaded
   14 | inline unsigned int __builtin_clz(unsigned int x) {
      |        ~~~~~~~~~~~~ ^
C:\home\runner\_work\_tool\Python\3.11.9\x64\Lib\site-packages\_rocm_sdk_devel\lib\llvm\lib\clang\20\include\ia32intrin.h(60,15): note: '__builtin_clz' is a builtin with type 'int (unsigned int) noexcept'
   60 |   return 31 - __builtin_clz((unsigned int)__A);
      |               ^
1 error generated.
[7100/7147] Building CXX object caffe2\torch\CMakeFiles\torch_python.dir\csrc\utils\tensor_list.cpp.obj
```

> [!NOTE]
> I haven't been able to reproduce those errors locally, but we have CI jobs that consistently fail when building for Python 3.11 but not 3.12 or 3.13. I'm not sure what is different between those builds, but the code fix seems correct.

There are a few other variations on fixes to this floating around, such as:
* a97a957af0/lz4.c (L34-L43) (checking with `__has_builtin`)
* c98c55ec7e/lj92.c (L31-L46) (the same code as here, but with `_MSC_VER`)
* 2760e5a2bb/def.h (L23-L25) (using `__lzcnt` instead of a custom implementation)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159273
Approved by: https://github.com/Skylion007, https://github.com/m-gallus
2025-08-01 15:21:26 +00:00
ee2649219c Fix max_width computation in _tensor_str._Formatter (#126859)
Previous version of `torch._tensor_str._Formatter` was not using `PRINT_OPTS.sci_mode` for the `max_width` computation but was using it for the formatting of values leading to a weird discrepancy.

Now, the code first checks if it should be in sci_mode, then compute `max_width`

Here is an example to test the behavior:
```python
A = torch.tensor([10, 1e-1, 1e-2])
B = torch.tensor([10, 1e-1, 1e-1])

print("================= Default =================")
print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")

print("================= sci_mode=False =================")
with torch._tensor_str.printoptions(sci_mode=False):
    print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
    print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")

print("================= sci_mode=True =================")
with torch._tensor_str.printoptions(sci_mode=True):
    print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
    print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")
```

In the current version this prints:
```
================= Default =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=False =================
tensor([   10.0000,     0.1000,     0.0100]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=True =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 7
```

On can see that in `sci_mode=False`, the values of A are prefixed with unneeded 0 and does not have the same `max_width` as B (It keeps the `max_width` from `sci_mode = None`)

Also in `sci_mode = True`, for B, the `max_width` is 7 but each value takes 10 chars... (But it is fine as the code that uses `max_width` do not rely much on it, but still, this is missleading)

After this commit, this will print
```
================= Default =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=False =================
tensor([10.0000,  0.1000,  0.0100]) Formatter max_width: 7
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=True =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 10
```

This also allows to align A with B for `sci_mode=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126859
Approved by: https://github.com/malfet
2025-08-01 15:05:41 +00:00
b0b3e6e48b [PP] Refactor test_schedule_multiproc (#158780)
This refactors the pipelining schedule tests since a lot of them have the same repeated code of:
1. Create pipelined model and reference model
2. Run reference model and pipelined model
3. compare gradients

So this refactors those parts above into helper methods and reduces ~300 LOC. Also adds a better gradient check to resolve flakiness (fixes https://github.com/pytorch/pytorch/issues/154408).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158780
Approved by: https://github.com/wconstab
2025-08-01 15:02:18 +00:00
3967dbedf4 [ContextParallel][FlexAttention] Prototype of supporting FlexAttention in Context Parallel (#158692)
**Summary**
This PR adds an all-gather based FlexAttention and uses TorchFunctionMode to dispatch
`FlexAttentionHOP.__call__` to it.

This PR makes the following changes:

- add a user-facing API `create_cp_block_mask` for creating CP-specific `BlockMask`
which masks over the attention result of Q shard and KV global.
- add `_ContextParallelGlobalVars` to store all necessary global vars that CP FlexAttention
requires. `torch_function_mode` is critical to maintain singleton mode to avoid dynamo
recompilations.
- add a dispatch path for `FlexAttentionForwardHOP.__call__` (TorchFunctionMode dispatch
won't work correctly without this line)

What's not in this PR:
- QKV load balancing
- Test on other masking besides `causal_mask`.
- Support on small attention (i.e. qkv size is smaller than 128) because the block mask
rewrite function requires `Q_BLOCK_SIZE == KV_BLOCK_SIZE == 128`.

**Test**
`pytest test/distributed/tensor/test_attention.py -s -k test_ring_flex_attention`

**Followup**
1. create an issue to reproduce the error in `create_fw_bw_graph()` when trying to call `create_block_mask`
to re-write `block_mask` in `FlexAttentionHOP` dispatch in `TorchFunctionMode`.
2. Merge `_ContextParallelGlobalVars` and `_cp_options`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158692
Approved by: https://github.com/drisspg
2025-08-01 06:49:01 +00:00
4396b15aa7 remove co_lnotab in favor of co_linetable (#159227)
Fixes #158833
DeprecationWarning: remove co_lnotab in favor of co_linetable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159227
Approved by: https://github.com/ezyang
2025-08-01 06:34:38 +00:00
bb6766053b fix strategy hashing arg mismatch (#159506)
Reland https://github.com/pytorch/pytorch/pull/159289.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159506
Approved by: https://github.com/XilunWu
2025-08-01 05:42:40 +00:00
a4fc051c9a Fix a bug of distributed 'gather' with noncontiguous tensors on the NCCL backend. (#159549)
Fixes #159548

* Throw an error message when the input tensors for the distributed `gather` are noncontiguous. This behaviour is consistent with the distributed `all_gather`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159549
Approved by: https://github.com/d4l3k
2025-08-01 03:26:06 +00:00
5cc6a0abc1 Revert "Refactor CUDAAllocatorConfig to reuse AcceleratorAllocatorConfig (#150312)"
This reverts commit dfacf11f66d6512396382bdf5088f0ba9de00406.

Reverted https://github.com/pytorch/pytorch/pull/150312 on behalf of https://github.com/guangyey due to Static initialization order issue impact the downstream repo ([comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3142035444))
2025-08-01 03:24:54 +00:00
90f13f3b2a Revert "Deprecate overleap functions in CUDAAllocatorConfig, use AcceleratorAllocatorConfig instead (#156165)"
This reverts commit 1fc010a9d8ea95bb74e54b31d17eba56ef16c27c.

Reverted https://github.com/pytorch/pytorch/pull/156165 on behalf of https://github.com/guangyey due to Static initialization order issue impact the downstream repo ([comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3142035444))
2025-08-01 03:24:54 +00:00
cb9b74872b Revert "Generalize torch._C._set_allocator_settings to be generic (#156175)"
This reverts commit d3ce45012ed42cd1e13d5048b046b781f0feabe0.

Reverted https://github.com/pytorch/pytorch/pull/156175 on behalf of https://github.com/guangyey due to Static initialization order issue impact the downstream repo ([comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3142035444))
2025-08-01 03:24:54 +00:00
c964204829 [CI] Disable executorch jobs (#159595)
The current executorch pin needs to be updated

The next time the docker image gets rebuilt, the executorch docker build is going to fail like https://github.com/pytorch/pytorch/actions/runs/16626853655/job/47137807966

The failure is that the pin uses a version of the nightly that has been removed from the nightly index
```
#62 72.30 ERROR: Could not find a version that satisfies the requirement torch==2.8.0.dev20250601 (from versions: 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0, 2.7.0, 2.7.1, 2.8.0.dev20250602+cpu, 2.8.0.dev20250603+cpu, 2.8.0.dev20250604+cpu, 2.8.0.dev20250605+cpu, 2.8.0.dev20250606+cpu, 2.8.0.dev20250607+cpu, 2.8.0.dev20250608+cpu, 2.8.0.dev20250609+cpu, 2.8.0.dev20250610+cpu, 2.8.0.dev20250611+cpu, 2.8.0.dev20250612+cpu, 2.8.0.dev20250613+cpu, 2.8.0.dev20250614+cpu, 2.8.0.dev20250615+cpu, 2.8.0.dev20250616+cpu, 2.8.0.dev20250617+cpu, 2.8.0.dev20250618+cpu, 2.8.0.dev20250619+cpu, 2.8.0.dev20250620+cpu, 2.8.0.dev20250621+cpu, 2.8.0.dev20250622+cpu, 2.8.0.dev20250623+cpu, 2.8.0.dev20250624+cpu, 2.8.0.dev20250625+cpu, 2.8.0.dev20250626+cpu, 2.8.0.dev20250627+cpu, 2.9.0.dev20250628+cpu, 2.9.0.dev20250629+cpu, 2.9.0.dev20250630+cpu, 2.9.0.dev20250701+cpu, 2.9.0.dev20250702+cpu, 2.9.0.dev20250703+cpu, 2.9.0.dev20250704+cpu, 2.9.0.dev20250705+cpu, 2.9.0.dev20250706+cpu, 2.9.0.dev20250707+cpu, 2.9.0.dev20250708+cpu, 2.9.0.dev20250709+cpu, 2.9.0.dev20250710+cpu, 2.9.0.dev20250711+cpu, 2.9.0.dev20250712+cpu, 2.9.0.dev20250713+cpu, 2.9.0.dev20250714+cpu, 2.9.0.dev20250715+cpu, 2.9.0.dev20250716+cpu, 2.9.0.dev20250717+cpu, 2.9.0.dev20250718+cpu, 2.9.0.dev20250719+cpu, 2.9.0.dev20250720+cpu, 2.9.0.dev20250722+cpu, 2.9.0.dev20250723+cpu, 2.9.0.dev20250724+cpu, 2.9.0.dev20250725+cpu, 2.9.0.dev20250726+cpu, 2.9.0.dev20250727+cpu, 2.9.0.dev20250728+cpu, 2.9.0.dev20250729+cpu, 2.9.0.dev20250730+cpu, 2.9.0.dev20250731+cpu)
#62 72.30 ERROR: No matching distribution found for torch==2.8.0.dev20250601
```

The executorch hash update currently fails due to https://github.com/pytorch/pytorch/actions/runs/16636773244/job/47079169392
```
2025-07-31T01:56:57.0249165Z + echo 'expecting triton to not be installed, but it is'
2025-07-31T01:56:57.0249614Z expecting triton to not be installed, but it is
2025-07-31T01:56:57.0249969Z + exit 1
2025-07-31T01:58:27.6764352Z ##[error]Final attempt failed. Child_process exited with error code 1
```
I believe the cause is https://github.com/pytorch/executorch/pull/11653 where the nightly pytorch is installed from our index, but then requirements-examples installs timm from pypi, which reinstalls pytorch, except its the release build for cuda from pypi?  Which then causes triton to be installed.

I don't know what the intended behavior is so I'm disabling the executorch docker build, executorch build, and the nightly hash update, and apparently the test was already disabled because it was failing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159595
Approved by: https://github.com/malfet
2025-08-01 02:18:03 +00:00
2ac45c2752 Fix autocast context manager when there is exception (#159565)
Summary: When exception occurs inside context manager, we need to either return False OR properly propagage exceptions via __exit__(exc_type, exc_val). But previously while tracing, we don't actually run the exit node so we end up swallowing the exception in a very weird way as outlined in https://github.com/pytorch/pytorch/issues/153202. This PR fixes it

Test Plan:
new test case

Rollback Plan:

Differential Revision: D79348382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159565
Approved by: https://github.com/zou3519, https://github.com/yushangdi
2025-08-01 02:12:24 +00:00
83e2ea8135 [CPU] fix _weight_int8pack_mm with large output shape (#158341)
**Summary**
`_weight_int8pack_mm` on CPU may cause segmentation fault if output shape is large (i.e., M * N is large). It's because the kernel compute output buffer address by
```c++
auto* C_ptr = C_data + mb_start * N + nb_start;
```
where both `mb_start` and `N` are `int` and when they are large their product may overflow.
The solution is simple: declare these variables as `int64_t` so that the product won't overflow.

**Test plan**
```
pytest -sv test/test_linalg.py -k test__int8_mm_large_shape
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158341
Approved by: https://github.com/mingfeima, https://github.com/drisspg
2025-08-01 01:55:48 +00:00
d994027a41 [Doc fix] fix spelling of enough (#159587)
fixes typo in word `enought` to correct `enough` at 3 places in these files
```
aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu
aten/src/ATen/native/cuda/CuFFTPlanCache.h
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159587
Approved by: https://github.com/ezyang
2025-08-01 01:50:57 +00:00
cb4f41e125 Revert "[dynamo] [guard] Add caching for inside torch.compile.disable function to avoid unnecessary recompilation. (#157566)"
This reverts commit 8e07c9870d07c5a318ab21bb16b3fa27576851e6.

Reverted https://github.com/pytorch/pytorch/pull/157566 on behalf of https://github.com/yangw-dev due to failed an odd internal test, please reach out to metamate to fix it, D79112610 ([comment](https://github.com/pytorch/pytorch/pull/157566#issuecomment-3141840110))
2025-08-01 01:27:45 +00:00
690fc9cf88 [merge_rules] add some expected failure and skips (#159581)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159581
Approved by: https://github.com/anijain2305
2025-08-01 01:18:40 +00:00
eb853e222b [cutlass upgrade] Ignore unused-but-set-variable for AsyncMM.cu (#159578)
Fixes inductor-perf-nightly-h100. This was caused by cutlass upgrade https://github.com/pytorch/pytorch/pull/158854. I missed it in https://github.com/pytorch/pytorch/pull/159276

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159578
Approved by: https://github.com/Skylion007
2025-08-01 00:10:59 +00:00
06395276e4 Remove dynamo_timed from the CachingAutotuner.coordinate_descent_tuning() hot path. (#159588)
Summary: When coordinate_descent_tuning==True, CachingAutotuner.coordinate_descent_tuning() is called for every call of CachingAutotuner.run() (at least for Triton templates), but immediately returns the launcher. Move the dynamo_timed call after the check for triton template so we don't incur the context manager overhead on every call.

Fixes https://github.com/pytorch/pytorch/issues/159525

Test Plan: Used the repro in https://github.com/pytorch/pytorch/issues/159525 to make sure the overhead goes away.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159588
Approved by: https://github.com/eellison
2025-07-31 23:33:10 +00:00
8becf646ef [dynamo] Make filter handle None as filter function (#159500)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159500
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
ghstack dependencies: #158774, #159102
2025-07-31 23:28:57 +00:00
fa68216ca1 [itertools] Implement itertools.cycle with a polyfill (#159102)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159102
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
ghstack dependencies: #158774
2025-07-31 23:28:57 +00:00
25ef3d315d [aoti][mps] Dynamic reductions (#159355)
Dynamic kernel:
```cpp
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    constant long& r0_numel,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    int x0 = xindex;
    threadgroup float tmp_acc_0[32];
    float tmp_acc_1 = 0;
    for(auto r0_1_cnt = 0; r0_1_cnt < static_cast<int>(metal::floor(static_cast<float>(0.99902343750000000 + 0.00097656250000000000*r0_numel))); ++r0_1_cnt) {
        int r0_1 = 1024 * r0_1_cnt + r0_index;
        if (r0_1 >= r0_numel) break;
        auto tmp0 = in_ptr0[x0 + 5*r0_1];
        tmp_acc_1 += tmp0;
    }
    auto tmp1 = c10:🤘:threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, metal::min(static_cast<decltype(1024+r0_numel)>(1024), static_cast<decltype(1024+r0_numel)>(r0_numel)));
    if (r0_index == 0) out_ptr0[x0] = static_cast<float>(tmp1);
}

void AOTInductorModel::run_impl(...) {
    ...
    auto arg0_1_size = arg0_1.sizes();
    int64_t s77 = arg0_1_size[0];
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        aoti_torch_mps_set_arg_int(mps_lib_0_func_handle, 2, s77);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))}, {static_cast<uint64_t>(1), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```

Static kernel:
```cpp
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = xindex;
    auto tmp0 = in_ptr0[x0];
    auto tmp1 = in_ptr0[5 + x0];
    auto tmp3 = in_ptr0[10 + x0];
    auto tmp5 = in_ptr0[15 + x0];
    auto tmp2 = tmp0 + tmp1;
    auto tmp4 = tmp2 + tmp3;
    auto tmp6 = tmp4 + tmp5;
    out_ptr0[x0] = static_cast<float>(tmp6);
}

void AOTInductorModel::run_impl(...) {
    ...
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL)});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159355
Approved by: https://github.com/malfet
2025-07-31 23:15:02 +00:00
7e00f2ec9d [AOTI] add zero size consts asm handler (#159225)
Add `get_zero_consts_asm_code` to handle zero size consts to object.
This function is used to handle zero consts situation. Because cpp standard does not allow zero size array:
https://stackoverflow.com/questions/9722632/what-happens-if-i-define-a-0-size-array-in-c-c
1. On Windows, MSVC will report error C2466:
https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2466?view=msvc-170
So, we can use assmbely compiler to handle this situation.
2. On Windows, why not use Win32 asm to handle all path? Because ml64 only supports up to align `16`, it is
not aligned to pytorch's `64`. Reference: https://learn.microsoft.com/en-us/cpp/assembler/masm/ml-and-ml64-command-line-reference?view=msvc-170
```
Packs structures on the specified byte boundary. The alignment can be 1, 2, 4, 8, or 16.
```
3. It function can handle zero size case on both Windows and Linux, as that:
    A. On Linux, we added `-pedantic` to disable zero size array on C++ compiler. 8e07c9870d/torch/_inductor/cpp_builder.py (L580)
    B. On Windows, msvc is not support zero size array by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159225
Approved by: https://github.com/desertfire
2025-07-31 22:46:33 +00:00
490cb3f1a4 Revert "[inductor] Add logging for distributed collective ops for multi‑rank diagnostics (#159190)"
This reverts commit bb62e1f769ef51e2ec149d7256c135d09425aaa0.

Reverted https://github.com/pytorch/pytorch/pull/159190 on behalf of https://github.com/clee2000 due to broke [GH job link](https://github.com/pytorch/pytorch/actions/runs/16658705097/job/47150840171) [HUD commit link](bb62e1f769) on mac ([comment](https://github.com/pytorch/pytorch/pull/159190#issuecomment-3141513921))
2025-07-31 22:22:13 +00:00
b95cf5c91d Move complex to headeronly (#159411)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159411
Approved by: https://github.com/albanD
ghstack dependencies: #159415
2025-07-31 22:05:43 +00:00
5e2ef2a465 Move Float8 variations to headeronly (#159415)
This PR is a big copy pasta from `c10/util/Float8*` -> `torch/headeronly/util/` which is why we are breaking PR sanity :C (sorry @albanD!).

Why is it not a clean copy paste?
- For BC reasons, we have to keep the old c10 file around so that OSS devs relying on those files can still get the same APIs
- Because we reexpose APIs that are headeronly through torch::headeronly, so there is an extra chunk of code in the new torch::headeronly files to do that.

Outside of the copy paste, I:
- changed the tests to call torch::headeronly instead of c10
- updated header_only_apis.txt
- added `// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)` to pass lint (which was previously skipped for -inl.h files)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159415
Approved by: https://github.com/albanD
2025-07-31 22:05:43 +00:00
9f753f8c0d [DTensor] Improve sort strategy (#159189)
- Sort strategy now supports sharding on non sorted dim.
~~- Fix histc xfail.~~
  - ~~Previously `python test/distributed/tensor/test_dtensor_ops.py TestDTensorOpsCPU.test_dtensor_op_db_histc_cpu_float32` will fail with `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=18`. However, if we run `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=18 python test/distributed/tensor/test_dtensor_ops.py TestDTensorOpsCPU.test_dtensor_op_db_histc_cpu_float32`, the test will pass. This kind of error is due to DTensor reuses the strategy schema hashing. It turns out that not only the strategy,  the result correctness also depends on `static_argnum` or the op will reuse the previous args from hashed schema and output wrong results. I updated the document also.~~ (fixed in https://github.com/pytorch/pytorch/pull/159289)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159189
Approved by: https://github.com/XilunWu
2025-07-31 21:52:42 +00:00
db437690d1 Add myself as a reviewer for when someone touches headeronly or stable (#159583)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159583
Approved by: https://github.com/mikaylagawarecki
2025-07-31 21:30:05 +00:00
669009bcd1 [inductor] respect layout tags for ops with registered lowerings (#159134)
scaled_grouped_mm's kernel only supports column-major on the second operand. I -think- this is just for efficiency reasons. But inductor treats that buffer as flexible and may tweak the strides to be row-major instead, as seen in the issue.

~Tagging the op as "needs_fixed_stride_order"/"needs_exact_strides" does not work. Inductor only considers those tags for ops that don't have registered lowering (not sure if this is intended). scaled_grouped_mm does have a lowering, so we never check its tags.~ From discussion below, the op tags are expected to work.

FIXES https://github.com/pytorch/pytorch/issues/159097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159134
Approved by: https://github.com/eellison
2025-07-31 21:29:40 +00:00
e4e2701429 Add the RunLLM widget to the website (#152055)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152055
Approved by: https://github.com/albanD
2025-07-31 20:53:53 +00:00
64cc649275 [itertools] Fix accumulate (#158774)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158774
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
2025-07-31 20:32:02 +00:00
b1fb552974 Revert "Fix ep deepcopy when there is python builitin name (#159478)"
This reverts commit de7376537f2a11783169fee2b3bc276d266898bf.

Reverted https://github.com/pytorch/pytorch/pull/159478 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/159478#issuecomment-3141228423))
2025-07-31 20:20:53 +00:00
bb62e1f769 [inductor] Add logging for distributed collective ops for multi‑rank diagnostics (#159190)
This change introduces structured logging of the collective communication schedule, enabling downstream tools (e.g. TLParse) to ingest and analyze per‑rank collective‐order information for multi‑rank jobs.

- Iterates over scheduler.nodes, filters for _CollectiveKernel nodes
- Extracts each op’s python_kernel_name
- Emits a structured JSON payload under the inductor_collective_schedule artifact name
- Dumps the full schedule list to collective_schedule.json via the PyTorch trace‑structured artifact
- Added comprehensive unit tests for collective schedule tracing: Created test_collective_schedule_empty() and test_collective_schedule_real() tests to verify structured trace logging works correctly for both empty collective schedules and real collective operations (like all_reduce and wait_tensor from _c10d_functional ops).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159190
Approved by: https://github.com/yushangdi, https://github.com/xmfan
2025-07-31 19:58:07 +00:00
359 changed files with 12707 additions and 8114 deletions

View File

@ -144,16 +144,6 @@ case "$tag" in
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9)
CUDA_VERSION=12.6.3
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.12
@ -164,39 +154,6 @@ case "$tag" in
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.13
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.10
@ -219,18 +176,6 @@ case "$tag" in
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-py3.11-clang12)
ANACONDA_PYTHON_VERSION=3.11
CLANG_VERSION=12
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-py3.9-gcc9)
ANACONDA_PYTHON_VERSION=3.9
GCC_VERSION=9
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3)
if [[ $tag =~ "jammy" ]]; then
ANACONDA_PYTHON_VERSION=3.10

View File

@ -1 +1 @@
11ec6354315768a85da41032535e3b7b99c5f706
f7888497a1eb9e98d4c07537f0d0bcfe180d1363

View File

@ -68,8 +68,8 @@ function install_nvshmem {
# download, unpack, install
wget -q "${url}"
tar xf "${filename}.tar.gz"
cp -a "libnvshmem/include/"* /usr/local/include/
cp -a "libnvshmem/lib/"* /usr/local/lib/
cp -a "libnvshmem/include/"* /usr/local/cuda/include/
cp -a "libnvshmem/lib/"* /usr/local/cuda/lib64/
# cleanup
cd ..

View File

@ -15,11 +15,37 @@ function install_timm() {
commit=$(get_pinned_commit timm)
pip_install "git+https://github.com/huggingface/pytorch-image-models@${commit}"
# Clean up
conda_run pip uninstall -y torch torchvision triton
}
function install_torchbench() {
local commit
commit=$(get_pinned_commit torchbench)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"
python install.py --continue_on_fail
# TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488
# is regressing speedup metric. This needs to be investigated further
pip install transformers==4.38.1
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
chown -R jenkins torchbench
}
# Pango is needed for weasyprint which is needed for doctr
conda_install pango
# Stable packages are ok here, just to satisfy TorchBench check
pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
install_torchbench
install_huggingface
install_timm
# Clean up
conda_run pip uninstall -y torch torchvision torchaudio triton

View File

@ -103,5 +103,5 @@ fi
# It depends on torch and triton. We don't want to install
# triton and torch from production on Docker CI images
if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then
pip_install helion==0.0.10 --no-deps
pip_install helion --no-deps
fi

View File

@ -361,7 +361,6 @@ pwlf==2.2.1
#Pinned versions: 2.2.1
#test that import: test_sac_estimator.py
# To build PyTorch itself
pyyaml
pyzstd

View File

@ -1,7 +1,7 @@
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@722b7e6f9ca512fcc526ad07d62b3d28c50bb6cd#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
@ -50,7 +50,7 @@ IPython==8.12.0
#Pinned versions: 8.12.0
myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs

View File

@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface.txt huggingface.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt
# (optional) Install non-default Ninja version
ARG NINJA_VERSION

View File

@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface.txt huggingface.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt
ARG TRITON
ARG TRITON_CPU

View File

@ -194,7 +194,7 @@ ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library
ROCBLAS_LIB_DST=lib/rocblas/library
ROCBLAS_ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH)
ROCBLAS_OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx)
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $OTHER_FILES)
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $ROCBLAS_OTHER_FILES)
# hipblaslt library files
HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library

View File

@ -229,7 +229,6 @@ function install_torchrec_and_fbgemm() {
pip_install tabulate # needed for newer fbgemm
pip_install patchelf # needed for rocm fbgemm
pushd /tmp
local wheel_dir=dist/fbgemm_gpu
local found_whl=0
@ -245,7 +244,7 @@ function install_torchrec_and_fbgemm() {
if [ "${found_whl}" == "0" ]; then
git clone --recursive https://github.com/pytorch/fbgemm
pushd fbgemm/fbgemm_gpu
git checkout "${fbgemm_commit}"
git checkout "${fbgemm_commit}" --recurse-submodules
python setup.py bdist_wheel \
--build-variant=rocm \
-DHIP_ROOT_DIR="${ROCM_PATH}" \
@ -264,7 +263,6 @@ function install_torchrec_and_fbgemm() {
done
rm -rf fbgemm
popd
else
pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec
pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu
@ -283,30 +281,6 @@ function clone_pytorch_xla() {
fi
}
function checkout_install_torchbench() {
local commit
commit=$(get_pinned_commit torchbench)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"
if [ "$1" ]; then
python install.py --continue_on_fail models "$@"
else
# Occasionally the installation may fail on one model but it is ok to continue
# to install and test other models
python install.py --continue_on_fail
fi
# TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488
# is regressing speedup metric. This needs to be investigated further
pip install transformers==4.38.1
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
}
function install_torchao() {
local commit
commit=$(get_pinned_commit torchao)

View File

@ -157,6 +157,29 @@ test_jit_hooks() {
assert_git_not_dirty
}
# Shellcheck doesn't like it when you pass no arguments to a function
# that can take args. See https://www.shellcheck.net/wiki/SC2120
# shellcheck disable=SC2120
checkout_install_torchbench() {
local commit
commit=$(cat .ci/docker/ci_commit_pins/torchbench.txt)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"
if [ "$1" ]; then
python install.py --continue_on_fail models "$@"
else
# Occasionally the installation may fail on one model but it is ok to continue
# to install and test other models
python install.py --continue_on_fail
fi
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
}
torchbench_setup_macos() {
git clone --recursive https://github.com/pytorch/vision torchvision
git clone --recursive https://github.com/pytorch/audio torchaudio
@ -179,8 +202,6 @@ torchbench_setup_macos() {
USE_OPENMP=0 python setup.py develop
popd
# Shellcheck doesn't like it when you pass no arguments to a function that can take args. See https://www.shellcheck.net/wiki/SC2120
# shellcheck disable=SC2119,SC2120
checkout_install_torchbench
}

View File

@ -627,6 +627,8 @@ test_perf_for_dashboard() {
device=cuda_a10g
elif [[ "${TEST_CONFIG}" == *h100* ]]; then
device=cuda_h100
elif [[ "${TEST_CONFIG}" == *b200* ]]; then
device=cuda_b200
elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
device=rocm
fi
@ -801,6 +803,16 @@ test_dynamo_benchmark() {
if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then
test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@"
elif [[ "${TEST_CONFIG}" == *perf* ]]; then
# TODO (huydhn): Just smoke test some sample models
if [[ "${TEST_CONFIG}" == *b200* ]]; then
if [[ "${suite}" == "huggingface" ]]; then
export TORCHBENCH_ONLY_MODELS="DistillGPT2"
elif [[ "${suite}" == "timm_models" ]]; then
export TORCHBENCH_ONLY_MODELS="inception_v3"
elif [[ "${suite}" == "torchbench" ]]; then
export TORCHBENCH_ONLY_MODELS="hf_Bert"
fi
fi
test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"
else
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
@ -1672,13 +1684,11 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then
elif [[ "${TEST_CONFIG}" == cachebench ]]; then
install_torchaudio
install_torchvision
checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco
PYTHONPATH=$(pwd)/torchbench test_cachebench
PYTHONPATH=/torchbench test_cachebench
elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then
install_torchaudio
install_torchvision
checkout_install_torchbench nanogpt
PYTHONPATH=$(pwd)/torchbench test_verify_cachebench
PYTHONPATH=/torchbench test_verify_cachebench
elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
install_torchaudio
install_torchvision
@ -1687,28 +1697,22 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
# https://github.com/opencv/opencv-python/issues/885
pip_install opencv-python==4.8.0.74
if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then
checkout_install_torchbench hf_Bert hf_Albert timm_vision_transformer
PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf
PYTHONPATH=/torchbench test_inductor_torchbench_smoketest_perf
elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then
checkout_install_torchbench timm_vision_transformer phlippe_densenet basic_gnn_edgecnn \
llama_v2_7b_16h resnet50 timm_efficientnet mobilenet_v3_large timm_resnest \
functorch_maml_omniglot yolov3 mobilenet_v2 resnext50_32x4d densenet121 mnasnet1_0
PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_cpu_smoketest_perf
PYTHONPATH=/torchbench test_inductor_torchbench_cpu_smoketest_perf
elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then
checkout_install_torchbench
TORCHBENCHPATH=$(pwd)/torchbench test_torchbench_gcp_smoketest
TORCHBENCHPATH=/torchbench test_torchbench_gcp_smoketest
else
checkout_install_torchbench
# Do this after checkout_install_torchbench to ensure we clobber any
# nightlies that torchbench may pull in
if [[ "${TEST_CONFIG}" != *cpu* ]]; then
install_torchrec_and_fbgemm
fi
PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id"
PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"
fi
elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then
install_torchvision
PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
if [[ "$SHARD_NUMBER" -eq "1" ]]; then
test_inductor_aoti
fi

View File

@ -1 +1 @@
bf305f538005f2e900f8850ed57146024a8bc559
6fbc710b617f79b992ef2ebc7f95e818aa390293

View File

@ -1 +1 @@
ca9e2be3ed6320b51f52f536595cd24e254f8bb2
6a39ba85fe0f2fff9494b5eccea717c93510c230

View File

@ -1 +1 @@
29ae4c76c026185f417a25e841d2cd5e65f087a3
b6a5b82b9948b610fa4c304d0d869c82b8f17db1

View File

@ -488,6 +488,10 @@
- torch/_dynamo/**
- torch/csrc/dynamo/**
- test/dynamo/**
- test/dynamo_expected_failures/**
- test/dynamo_skips/**
- test/inductor_expected_failures/**
- test/inductor_skips/**
approved_by:
- guilhermeleobas
mandatory_checks_name:

View File

@ -193,7 +193,7 @@ LIBTORCH_CONTAINER_IMAGES: dict[str, str] = {
"cpu": "libtorch-cxx11-builder:cpu",
}
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t"]
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"]
def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
@ -315,6 +315,11 @@ def generate_wheels_matrix(
# TODO: Enable python 3.13t on cpu-s390x
if gpu_arch_type == "cpu-s390x" and python_version == "3.13t":
continue
# TODO: Enable python 3.14 on non linux OSes
if os != "linux" and (
python_version == "3.14" or python_version == "3.14t"
):
continue
if use_split_build and (
arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux"

View File

@ -96,7 +96,7 @@ jobs:
steps:
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
if: ${{ matrix.runner != 'B200' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
if: ${{ !contains(matrix.runner, 'b200') && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
instructions: |
@ -109,7 +109,7 @@ jobs:
no-sudo: true
- name: Setup Python
if: matrix.runner == 'B200'
if: contains(matrix.runner, 'b200')
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.12'
@ -117,7 +117,7 @@ jobs:
- name: Setup Linux
uses: ./.github/actions/setup-linux
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && matrix.runner != 'B200'
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && !contains(matrix.runner, 'b200')
- name: configure aws credentials
if: ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
@ -128,7 +128,7 @@ jobs:
aws-region: us-east-1
- name: Login to Amazon ECR
if: ${{ inputs.aws-role-to-assume != '' && matrix.runner == 'B200' }}
if: ${{ inputs.aws-role-to-assume != '' && contains(matrix.runner, 'b200') }}
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
@ -166,17 +166,17 @@ jobs:
uses: pytorch/test-infra/.github/actions/setup-nvidia@main
with:
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }}
- name: Setup GPU_FLAG for docker run
id: setup-gpu-flag
run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}"
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || matrix.runner == 'B200') }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }}
- name: Setup SCCACHE_SERVER_PORT environment for docker run when on container
id: setup-sscache-port-flag
run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}"
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && matrix.runner != 'B200' }}
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && !contains(matrix.runner, 'b200') }}
- name: Lock NVIDIA A100 40GB Frequency
run: |
@ -277,8 +277,8 @@ jobs:
NO_TD: ${{ steps.keep-going.outputs.ci-no-td }}
TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }}
# Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs
SCCACHE_BUCKET: ${{ matrix.runner != 'B200' && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ matrix.runner != 'B200' && 'us-east-1' || '' }}
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
@ -403,7 +403,7 @@ jobs:
job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }}
- name: Authenticate with AWS
if: ${{ matrix.runner == 'B200' }}
if: ${{ contains(matrix.runner, 'b200') }}
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results

View File

@ -34,7 +34,8 @@ jobs:
contents: read
pull-requests: write
name: Check labels
if: github.repository_owner == 'pytorch'
# Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved
if: github.repository_owner == 'pytorch' && false
runs-on: linux.24_04.4x
steps:
- name: Checkout PyTorch

View File

@ -7,7 +7,8 @@ on:
jobs:
ghstack-mergeability-check:
if: github.repository_owner == 'pytorch'
# Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved
if: github.repository_owner == 'pytorch' && false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

View File

@ -51,17 +51,12 @@ jobs:
docker-image-name: [
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm,
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
pytorch-linux-jammy-py3.9-clang12,
pytorch-linux-jammy-py3.11-clang12,
pytorch-linux-jammy-py3.12-clang12,
pytorch-linux-jammy-py3.13-clang12,
pytorch-linux-jammy-rocm-n-py3,
pytorch-linux-noble-rocm-n-py3,
@ -76,7 +71,8 @@ jobs:
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter,
pytorch-linux-jammy-py3-clang12-executorch,
# Executorch pin needs update
# pytorch-linux-jammy-py3-clang12-executorch,
pytorch-linux-jammy-py3.12-triton-cpu
]
include:

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,154 @@
name: inductor-perf-b200
on:
schedule:
- cron: 0 7 * * 1-6
- cron: 0 7 * * 0
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
workflow_dispatch:
inputs:
training:
description: Run training (on by default)?
required: false
type: boolean
default: true
inference:
description: Run inference (on by default)?
required: false
type: boolean
default: true
default:
description: Run inductor_default?
required: false
type: boolean
default: false
dynamic:
description: Run inductor_dynamic_shapes?
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
type: boolean
default: true
freezing_cudagraphs:
description: Run inductor_cudagraphs with freezing for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
type: boolean
default: false
maxautotune:
description: Run inductor_max_autotune?
required: false
type: boolean
default: false
benchmark_configs:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf_cuda_b200,inductor_timm_perf_cuda_b200,inductor_torchbench_perf_cuda_b200
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
build:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
# Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100
# or newer GPUs, so it doesn't benefit much from existing compiler cache
# from trunk. Also use a memory-intensive runner here because memory is
# usually the bottleneck
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
{ config: "inductor_timm_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
{ config: "inductor_torchbench_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
]}
selected-test-configs: ${{ inputs.benchmark_configs }}
build-additional-packages: "vision audio fbgemm torchao"
secrets: inherit
test-periodically:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 1-6'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
test-weekly:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
timeout-minutes: 1440
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
test:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -81,21 +81,21 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
]}
secrets: inherit

View File

@ -75,10 +75,11 @@ jobs:
repo-owner: pytorch
branch: main
pin-folder: .github/ci_commit_pins
- repo-name: executorch
repo-owner: pytorch
branch: main
pin-folder: .ci/docker/ci_commit_pins
# executorch jobs are disabled since it needs some manual work for the hash update
# - repo-name: executorch
# repo-owner: pytorch
# branch: main
# pin-folder: .ci/docker/ci_commit_pins
- repo-name: triton
repo-owner: triton-lang
branch: main

View File

@ -51,37 +51,6 @@ jobs:
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-cuda12_4-py3_10-gcc11-sm89-build:
name: linux-jammy-cuda12.4-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_4-py3_10-gcc11-sm89-test:
name: linux-jammy-cuda12.4-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_4-py3_10-gcc11-sm89-build
- target-determination
with:
build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89
docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_4-py3_10-gcc11-build:
name: linux-jammy-cuda12.4-py3.10-gcc11
uses: ./.github/workflows/_linux-build.yml

View File

@ -292,13 +292,14 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
@ -402,38 +403,8 @@ jobs:
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-sm89-build:
name: linux-jammy-cuda12.8-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-sm89-test:
name: linux-jammy-cuda12.8-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-sm89-build
- target-determination
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3-clang12-executorch-build:
if: false # Docker build needs pin update
name: linux-jammy-py3-clang12-executorch
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type

View File

@ -10,6 +10,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-default-label-prefix:
if: github.repository_owner == 'pytorch'

View File

@ -205,7 +205,7 @@ jobs:
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.9-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "verify_cachebench", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },

View File

@ -23,7 +23,7 @@ jobs:
with:
repository: pytorch/pytorch
stable-branch: viable/strict
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\"]'
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\", \"linux-aarch64\"]'
secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }}
clickhouse-url: ${{ secrets.CLICKHOUSE_URL }}
clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }}

View File

@ -1 +1,18 @@
- This is the only AGENTS.md, there are no recursive AGENTS.md
- When you are working on a bug, first create a standalone file that
reproduces the bug and verify it fails in the expected way. Use this to
test if your changes work. Once the change is passing, find an appropriate
test file to add the test to and make sure to follow local conventions on
the test file.
- If you are running the real test suite, DO NOT run the entire test suite.
Instead run only a single test case, e.g., 'python test/test_torch.py TestTorch.test_dir'
- Do NOT run setup.py, you do not have a working build environment
- Do NOT run pre-commit, it is not setup
- To run lint, run 'lintrunner -a' (which will autoapply changes). lintrunner
ONLY accepts this flag, do not try to run on individual files.
- Do NOT attempt to install dependencies, you do not have Internet access
- When you are ready to make a PR, do exactly these steps:
- git stash -u
- git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch
- git stash pop
- Resolve conflicts if necessary

View File

@ -14,7 +14,6 @@
/torch/csrc/autograd/ @albanD @soulitzer
/torch/autograd/ @albanD @soulitzer
/tools/autograd/ @albanD @soulitzer
/torch/header_only_apis.txt @janeyx99
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
/torch/optim/ @albanD @janeyx99
/test/test_public_bindings.py @albanD
@ -196,3 +195,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed
/torch/utils/_cxx_pytree.py @XuehaiPan
/torch/utils/pytree/ @XuehaiPan
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
# Relating to libtorch ABI
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99

View File

@ -276,7 +276,7 @@ conda install pkg-config libuv
pip install mkl-static mkl-include
# Add these packages if torch.distributed is needed.
# Distributed package support on Windows is a prototype feature and is subject to changes.
conda install -c conda-forge libuv=1.39
conda install -c conda-forge libuv
```
#### Install PyTorch

View File

@ -439,6 +439,7 @@ if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
_pytorch_rocm_generate_ck_conf()
@ -703,21 +704,17 @@ if(USE_MPS)
if(CAN_COMPILE_METAL)
foreach(SHADER ${native_mps_metal})
cmake_path(GET SHADER STEM TGT_STEM)
string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air")
string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air")
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
list(APPEND AIR_BASIC ${TGT_BASIC})
list(APPEND AIR_BFLOAT ${TGT_BFLOAT})
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0")
metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1")
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1")
endforeach()
air_to_metallib(kernels_basic.metallib ${AIR_BASIC})
air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT})
add_custom_command(
COMMAND echo "// $$(date)" > metallib_dummy.cpp
DEPENDS kernels_basic.metallib kernels_bfloat.metallib
DEPENDS kernels_basic.metallib
OUTPUT metallib_dummy.cpp
COMMENT "Updating metallibs timestamp")
add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp)
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
else()
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
foreach(SHADER ${native_mps_metal})

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
@ -72,6 +73,27 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index);
// original device index that was active before the change.
TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index);
TORCH_API inline void emptyCache() {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->emptyCache();
}
TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
return at::getDeviceAllocator(device_type)->getDeviceStats(device_index);
}
TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index);
}
TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}
} // namespace at::accelerator
namespace at {

View File

@ -2,7 +2,6 @@
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <cstddef>

View File

@ -2,6 +2,7 @@
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>

View File

@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
}
}
template<>
void _assert_match<c10::Device, std::optional<c10::Device>>(
const c10::Device& original,
const std::optional<c10::Device>& compared,
const std::string& name) {
if (compared) {
const c10::Device& expected = compared.value();
if (original.type() != expected.type()) {
std::stringstream msg;
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
throw std::runtime_error(msg.str());
}
// If the expected device doesn't have an index (e.g., just "cuda"),
// or if both devices have the same index, consider them equal
if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
std::stringstream msg;
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
throw std::runtime_error(msg.str());
}
}
}
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
_assert_match(tensor.sym_sizes(), sizes, "sizes");
_assert_match(tensor.sym_strides(), strides, "strides");

View File

@ -367,27 +367,27 @@ void int8pack_mm_kernel_(
auto* C_data = C.data_ptr<T>();
const auto* S_data = scales.const_data_ptr<T>();
int M = A.size(0);
int N = B.size(0);
int K = A.size(1);
int lda = A.stride(0);
constexpr int BLOCK_M = 4;
constexpr int BLOCK_N = 4;
int64_t M = A.size(0);
int64_t N = B.size(0);
int64_t K = A.size(1);
int64_t lda = A.stride(0);
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 4;
const int MB = (M + BLOCK_M - 1) / BLOCK_M;
const int NB = (N + BLOCK_N - 1) / BLOCK_N;
const int64_t MB = (M + BLOCK_M - 1) / BLOCK_M;
const int64_t NB = (N + BLOCK_N - 1) / BLOCK_N;
at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
int mb{0}, nb{0};
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
for (const auto i : c10::irange(begin, end)) {
(void)i;
int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(BLOCK_N, N - nb_start);
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
const auto* A_ptr = A_data + mb_start * lda;
const auto* B_ptr = B_data + nb_start * K;

View File

@ -526,7 +526,7 @@ namespace {
// we are dealing with packed tensor here. max index is the same as numel.
// TODO: to really support input tensor large enought to go beyond int32,
// TODO: to really support input tensor large enough to go beyond int32,
// we will need to restrict out shared memory usage and adjust the launch
// config;
AT_ASSERT(input_.numel() < std::numeric_limits<int32_t>::max());
@ -681,7 +681,7 @@ namespace {
const dim3 grid(grid_x, grid_y, grid_z);
// we are dealing with packed tensor here. max index is the same as numel.
// TODO: to really support input tensor large enought to go beyond int32,
// TODO: to really support input tensor large enough to go beyond int32,
// we will need to restrict out shared memory usage and adjust the launch
// config;
AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());

View File

@ -1634,6 +1634,9 @@ bool use_fast_accum) {
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
if (!a_is_2d || !b_is_2d) {
TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
TORCH_CHECK(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
@ -1716,6 +1719,9 @@ std::optional<c10::ScalarType> out_dtype) {
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
if (!a_is_2d || !b_is_2d) {
TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
// check that the strides are valid, the fn will throw an error if not
check_valid_strides_and_return_transposed(mat_a);

View File

@ -223,7 +223,7 @@ inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bo
class CuFFTConfig {
public:
// Only move semantics is enought for this class. Although we already use
// Only move semantics is enough for this class. Although we already use
// unique_ptr for the plan, still remove copy constructor and assignment op so
// we don't accidentally copy and take perf hit.
CuFFTConfig(const CuFFTConfig&) = delete;

View File

@ -241,6 +241,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
Strides tensor_StrideA = make_strides(mat_a.strides());
Strides tensor_StrideB = make_strides(mat_b.strides());
Strides tensor_StrideOutput = make_strides(out.strides());
Strides tensor_ShapeA = make_strides(mat_a.sizes());
Strides tensor_ShapeB = make_strides(mat_b.sizes());
at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>(
reinterpret_cast<DtypeA*>(mat_a.data_ptr()),
@ -264,6 +266,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
tensor_StrideA,
tensor_StrideB,
tensor_StrideOutput,
tensor_ShapeA,
tensor_ShapeB,
0,
0,
a_row_major,

View File

@ -38,18 +38,20 @@ __global__ void prepare_grouped_gemm_data(
Strides tensor_StrideA,
Strides tensor_StrideB,
Strides tensor_StrideOutput,
Strides tensor_ShapeA,
Strides tensor_ShapeB,
int64_t a_scale_stride,
int64_t b_scale_stride,
bool a_row_major = true,
bool b_row_major = false) {
int32_t tid = threadIdx.x;
int32_t delta = 0;
int32_t offset = 0;
if (offs != nullptr) {
int32_t start = tid == 0 ? 0 : offs[tid - 1];
delta = offs[tid] - start;
if (K < 0) {
CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
}
offset = offs[tid];
delta = offset - start;
CUDA_KERNEL_ASSERT(delta >=0 && "expected gemm dimension to be greater or equal 0\n");
// TMA transfers require global memory tensor addresses to be
// aligned to 16 bytes.
@ -84,6 +86,7 @@ __global__ void prepare_grouped_gemm_data(
int64_t lda, ldb, ldoutput;
if (M < 0) {
// A and output is 2d
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[0] && "expected offset to be less than tensor size\n");
M = delta;
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2];
@ -96,6 +99,7 @@ __global__ void prepare_grouped_gemm_data(
output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput;
B_ptrs[tid] = B + tid * tensor_StrideB[0];
} else if (N < 0) {
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeB[1] && "expected offset to be less than tensor size\n");
N = delta;
lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2];
ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed
@ -108,6 +112,7 @@ __global__ void prepare_grouped_gemm_data(
inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1];
}
} else if (K < 0) {
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[1] && offset <= tensor_ShapeB[0] && "expected offset to be less than tensor size\n");
// A, B is 2d, output is 3d
K = delta;
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];

View File

@ -644,7 +644,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta))
// As above, there may be better configurations to use.
constexpr int max_threads = std::is_same_v<scalar_t, float> ? 1024 : 896; // we need 72 or so 32 bit registers for double
constexpr int max_threads_ = std::is_same_v<scalar_t, float> ? 1024 : 896; // we need 72 or so 32 bit registers for double
int max_threads = max_threads_;
// Blackwell launch bounds
if (at::cuda::getCurrentDeviceProperties()->major >= 10) {
max_threads = 512;
}
int threads_target = max_threads;
while (threads_target / 2 >= 2*max_target_length+1) {
threads_target /= 2;

View File

@ -298,6 +298,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
Strides tensor_StrideA = make_strides(mat_a.strides());
Strides tensor_StrideB = make_strides(mat_b.strides());
Strides tensor_StrideOutput = make_strides(out.strides());
Strides tensor_ShapeA = make_strides(mat_a.sizes());
Strides tensor_ShapeB = make_strides(mat_b.sizes());
// scale stride will be used inside the kernel only if needed,
// so for 1d scales the "1" assigned here won't be used
int64_t a_scale_stride = scale_a.stride(0);
@ -325,6 +328,8 @@ void f8f8bf16_grouped_gemm_impl_sm90(
tensor_StrideA,
tensor_StrideB,
tensor_StrideOutput,
tensor_ShapeA,
tensor_ShapeB,
a_scale_stride,
b_scale_stride);

View File

@ -0,0 +1,74 @@
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace at::native {
__global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const float* scale, float* out, int B, int K, int N) {
// one thread per output element: [B, N]
int b = blockIdx.y * blockDim.y + threadIdx.y;
int n = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B || n >= N) return;
float acc = 0.0f;
for (int k = 0; k < K; ++k) {
acc += x[b * K + k] * static_cast<float>(w[n * K + k]);
}
out[b * N + n] = acc * scale[n];
}
void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8, const Tensor& scale, Tensor& out) {
const int B = x.size(0);
const int K = x.size(1);
const int N = w_int8.size(0);
const dim3 block(16, 16);
const dim3 grid((N + block.x - 1) / block.x, (B + block.y - 1) / block.y);
auto stream = at::cuda::getCurrentCUDAStream();
weight_int8pack_mm_kernel<<<grid, block, 0, stream>>>(
x.data_ptr<float>(),
w_int8.data_ptr<int8_t>(),
scale.data_ptr<float>(),
out.data_ptr<float>(),
B, K, N);
}
// Main GPU entry point
at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int8, const at::Tensor& scale) {
// --- Check inputs ---
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(w_int8.is_cuda(), "w must be a CUDA tensor");
TORCH_CHECK(scale.is_cuda(), "scale must be a CUDA tensor");
TORCH_CHECK(x.dim() == 2, "x must be 2D");
TORCH_CHECK(w_int8.dim() == 2, "w must be 2D");
TORCH_CHECK(scale.dim() == 1, "scale must be 1D");
TORCH_CHECK(x.size(1) == w_int8.size(1), "K dimension mismatch: x.size(1) != w.size(1)");
TORCH_CHECK(w_int8.size(0) == scale.size(0), "Output dim mismatch: w.size(0) != scale.size(0)");
// --- Determine shapes ---
auto B = x.size(0); // batch size
auto N = w_int8.size(0); // output dim
// Ensure inputs are in the correct types for the kernel
auto x_f32 = x.to(at::kFloat);
auto w_int8_contiguous = w_int8.contiguous();
auto scale_f32 = scale.to(at::kFloat);
// --- Allocate output ---
auto out = at::empty({B, N}, x.options().dtype(at::kFloat));
// --- Launch kernel ---
launch_weight_int8pack_mm_cuda_kernel(x_f32, w_int8_contiguous, scale_f32, out);
return out;
}
} // namespace at::native

View File

@ -28,6 +28,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support");
}
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
const Tensor& input,
const Tensor& weight,
const std::optional<Tensor>& bias,
const std::optional<Tensor>& running_mean,
const std::optional<Tensor>& running_var,
bool training,
double exponential_average_factor,
double epsilon,
Tensor& out,
Tensor& save_mean,
Tensor& save_var,
Tensor& reserve) {
AT_ERROR("cudnn_batch_norm_out: ATen not compiled with cuDNN support");
}
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
const Tensor& input,
const Tensor& grad_output,
@ -120,7 +136,12 @@ size_t _get_cudnn_batch_norm_reserve_space_size(
return reserve_size;
}
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
// Param `reserve` is a placeholder, just passing an empty tensor.
// usage:
// auto reserve = torch::empty({0}, torch::device(torch::kCUDA));
// at::native::cudnn_batch_norm_out(..., epsilon, output, save_mean, save_var,
// reserve);
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_t_opt,
@ -128,7 +149,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
const std::optional<Tensor>& running_var_t_opt,
bool training,
double exponential_average_factor,
double epsilon) {
double epsilon,
Tensor& output_t,
Tensor& save_mean,
Tensor& save_var,
Tensor& reserve) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_t_maybe_owned =
at::borrow_from_optional_tensor(bias_t_opt);
@ -168,9 +193,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
training, input->suggest_memory_format(), input->dim());
auto output_t =
at::empty_like(*input, input->options(), input->suggest_memory_format());
TensorArg output{output_t, "output", 0};
auto handle = getCudnnHandle();
@ -182,15 +204,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
Constant one(dataType, 1);
Constant zero(dataType, 0);
Tensor save_mean, save_var;
Tensor reserve;
if (training) {
int64_t num_features = input_t.size(1);
save_mean = at::empty({num_features}, weight_t.options());
save_var = at::empty({num_features}, weight_t.options());
auto op = CUDNN_BATCHNORM_OPS_BN;
size_t workspace_size;
AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
@ -238,9 +253,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
reserve_size));
} else {
reserve = at::empty({0}, input->options().dtype(kByte));
// This keeps a consistent output with native_batch_norm
save_mean = at::empty({0}, weight_t.options());
save_var = at::empty({0}, weight_t.options());
AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
handle,
mode,
@ -261,10 +273,48 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
// save_mean and save_var can be undefined
// If this causes problems, we can initialize them to empty tensors
// of the correct type
return std::tuple<Tensor, Tensor, Tensor, Tensor>{
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>{
output_t, save_mean, save_var, reserve};
}
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_t_opt,
const std::optional<Tensor>& running_mean_t_opt,
const std::optional<Tensor>& running_var_t_opt,
bool training,
double exponential_average_factor,
double epsilon) {
auto output_t = at::empty_like(
input_t, input_t.options(), input_t.suggest_memory_format());
Tensor save_mean, save_var, reserve;
if (training) {
int64_t num_features = input_t.size(1);
save_mean = at::empty({num_features}, weight_t.options());
save_var = at::empty({num_features}, weight_t.options());
} else {
// This keeps a consistent output with native_batch_norm
save_mean = at::empty({0}, weight_t.options());
save_var = at::empty({0}, weight_t.options());
}
return cudnn_batch_norm_out(
input_t,
weight_t,
bias_t_opt,
running_mean_t_opt,
running_var_t_opt,
training,
exponential_average_factor,
epsilon,
output_t,
save_mean,
save_var,
reserve);
}
// NB: CuDNN only implements the backward algorithm for batchnorm
// in training mode (evaluation mode batchnorm has a different algorithm),
// which is why this doesn't accept a 'training' parameter.

View File

@ -1,7 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/Matmul.h>
@ -428,56 +427,74 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
}
}
template <typename T>
bool use_mkldnn_typed_matmul(
bool use_mkldnn_bf16_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
bool dtype_check = false;
if constexpr (std::is_same_v<T, c10::BFloat16>) {
#if defined(__aarch64__)
if (mkldnn_bf16_device_check_arm()) {
// onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
// Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
// inputs, allow it for float as well
dtype_check = use_mkldnn_bf16_matmul() &&
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
}
#else
dtype_check = dtype_check && use_mkldnn_bf16_matmul() &&
(mat1.scalar_type() == kBFloat16);
if (mkldnn_bf16_device_check_arm()) {
// onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
// Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
// inputs, allow it for float as well
return (
use_mkldnn_bf16_matmul() &&
(mat1.scalar_type() == mat2.scalar_type()) &&
(!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
} else
#endif
} else if constexpr (std::is_same_v<T, c10::Half>) {
dtype_check = dtype_check && use_mkldnn_fp16_matmul() &&
(mat1.scalar_type() == kHalf);
} else if constexpr (std::is_same_v<T, float>) {
dtype_check = dtype_check &&
(use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
(mat1.scalar_type() == kFloat);
{
return (
use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 &&
mat2.scalar_type() == kBFloat16 &&
(!result.defined() || result.scalar_type() == kBFloat16) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
}
if (!dtype_check) {
return false;
}
bool size_check =
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
(!result.defined() || result.scalar_type() == mat1.scalar_type());
return dtype_check && size_check;
}
bool use_mkldnn_fp16_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return (
use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf &&
mat2.scalar_type() == kHalf &&
(!result.defined() || result.scalar_type() == kHalf) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
}
bool use_mkldnn_bf32_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return (
use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat &&
mat2.scalar_type() == kFloat &&
(!result.defined() || result.scalar_type() == kFloat) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
}
bool use_mkldnn_tf32_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return (
use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat &&
mat2.scalar_type() == kFloat &&
(!result.defined() || result.scalar_type() == kFloat) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
}
bool use_mkldnn_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
auto mat1_type = mat1.scalar_type();
if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) {
return false;
}
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
});
return false;
return (
use_mkldnn_bf16_matmul(mat1, mat2, result) ||
use_mkldnn_fp16_matmul(mat1, mat2, result) ||
use_mkldnn_bf32_matmul(mat1, mat2, result) ||
use_mkldnn_tf32_matmul(mat1, mat2, result));
}
static void _mkldnn_matmul_i8i8i32_with_primitive(

View File

@ -469,4 +469,94 @@ Tensor _weight_int4pack_mm_xpu(
return C;
}
Tensor& _int_mm_out_xpu(
const Tensor& self,
const Tensor& mat2,
Tensor& result) {
TORCH_CHECK(
self.dim() == 2,
"Expected self to be of dimension 2 but got ",
self.dim());
TORCH_CHECK(
mat2.dim() == 2,
"Expected mat2 to be of dimension 2 but got ",
mat2.dim());
TORCH_CHECK(
self.size(1) == mat2.size(0),
"self.size(1) needs to match mat2.size(0) but got ",
self.size(1),
" and ",
mat2.size(0));
TORCH_CHECK(
self.dtype() == at::kChar,
"Expected self dtype to be of type int8 but got ",
self.dtype());
TORCH_CHECK(
mat2.dtype() == at::kChar,
"Expected mat2 dtype to be of type int8 but got ",
mat2.dtype());
TORCH_CHECK(
result.dtype() == at::kInt,
"Expected result dtype to be of type kInt but got ",
result.dtype());
TORCH_CHECK(
result.size(0) == self.size(0),
"Expected result.size(0) to be ",
self.size(0),
" but got ",
result.size(0));
TORCH_CHECK(
result.size(1) == mat2.size(1),
"Expected result.size(1) to be ",
mat2.size(1),
" but got ",
result.size(1));
TORCH_CHECK(
result.dim() == 2,
"Expected result to be of dimension 2 but got ",
result.dim());
TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
if (result.numel() == 0 || self.size(1) == 0) {
return result.zero_();
}
Tensor bias = at::Tensor();
Tensor mat2_scales = at::ones({1}, mat2.options().dtype(at::kFloat));
Tensor mat2_zero_points = at::Tensor();
auto post_op_args = torch::List<std::optional<at::Scalar>>();
at::native::onednn::quantized_matmul(
self.contiguous(),
1.0,
0,
mat2.contiguous(),
mat2_scales,
mat2_zero_points,
bias,
result,
1.0,
0,
result.scalar_type(),
/*other*/ std::nullopt,
/*other scale*/ 1.0,
/*other zp*/ 0,
/*binary post op*/ "none",
/*binary alpha*/ 1.0,
/*post_op_name*/ "none",
post_op_args,
/*post_op_algorithm*/ "none",
/*m2_trans*/ true);
return result;
}
Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
Tensor result =
at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
return _int_mm_out_xpu(self, mat2, result);
}
} // namespace at::native

View File

@ -953,8 +953,7 @@ class BundledShaderLibary : public MetalShaderLibrary {
if (C10_UNLIKELY(!library)) {
auto device = MPSDevice::getInstance()->device();
NSError* error = nil;
auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic";
library = [device newLibraryWithData:getSectionData(section_name) error:&error];
library = [device newLibraryWithData:getSectionData("metal_basic") error:&error];
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
}
return library;

View File

@ -33,21 +33,15 @@ struct shrink_backward_functor {
REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
#endif
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
#endif
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
#endif
struct hardsigmoid_functor {
template <typename T>
@ -67,15 +61,11 @@ struct hardsigmoid_backward_functor {
REGISTER_UNARY_OP(hardsigmoid, float, float);
REGISTER_UNARY_OP(hardsigmoid, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
#endif
REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
#endif
struct hardswish_functor {
template <typename T>
@ -103,15 +93,11 @@ struct hardswish_backward_functor {
REGISTER_UNARY_OP(hardswish, float, float);
REGISTER_UNARY_OP(hardswish, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_OP(hardswish, bfloat, bfloat);
#endif
REGISTER_BINARY_OP(hardswish_backward, float, float);
REGISTER_BINARY_OP(hardswish_backward, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
#endif
struct leaky_relu_functor {
template <typename T>
@ -135,12 +121,8 @@ struct leaky_relu_backward_functor {
REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float);
REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat);
#endif
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float);
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
#endif

View File

@ -113,18 +113,12 @@ kernel void ampUpdateScale(
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float);
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat);
#endif
INSTANTIATE_AMP_UPDATE_SCALE(float);
INSTANTIATE_AMP_UPDATE_SCALE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_UPDATE_SCALE(bfloat);
#endif
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float);
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat);
#endif

View File

@ -590,9 +590,7 @@ kernel void attention(
INSTANTIATE_SDPA_VECTOR_HEADS(float);
INSTANTIATE_SDPA_VECTOR_HEADS(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
#endif
#define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \
template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \
@ -621,6 +619,4 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
INSTANTIATE_ATTN_SHAPES_HELPER(float);
INSTANTIATE_ATTN_SHAPES_HELPER(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_ATTN_SHAPES_HELPER(bfloat);
#endif

View File

@ -209,38 +209,9 @@ struct hermite_polynomial_he_functor {
};
struct nextafter_functor {
#if __METAL_VERSION__ < 310
template <typename U>
struct bit_type {};
template <>
struct bit_type<float> {
using type = int;
};
template <>
struct bit_type<half> {
using type = short;
};
#endif
template <typename T>
inline T operator()(const T a, const T b) {
#if __METAL_VERSION__ >= 310
return static_cast<T>(::metal::nextafter(a, b));
#else
using U = typename bit_type<T>::type;
if (a == b) {
return a;
}
if (::metal::isunordered(a, b)) {
return NAN;
}
if (a == 0) {
constexpr auto eps = as_type<T>(static_cast<U>(1));
return b > 0 ? eps : -eps;
}
auto bits = as_type<U>(a);
(a > 0) ^ (a > b) ? bits++ : bits--;
return as_type<T>(bits);
#endif
}
};
@ -344,13 +315,6 @@ struct fmod_functor {
}
};
// Some helper defines
#if __METAL_VERSION__ >= 310
#define _METAL_310_PLUS(x) x
#else
#define _METAL_310_PLUS(x)
#endif
#define REGISTER_INTEGER_BINARY_OP(NAME) \
REGISTER_BINARY_OP(NAME, long, long); \
REGISTER_BINARY_OP(NAME, int, int); \
@ -370,12 +334,12 @@ struct fmod_functor {
#define REGISTER_FLOAT_BINARY_OP(NAME) \
REGISTER_BINARY_OP(NAME, float, float); \
REGISTER_BINARY_OP(NAME, half, half); \
_METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat))
REGISTER_BINARY_OP(NAME, bfloat, bfloat)
#define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \
REGISTER_OPMATH_BINARY_OP(NAME, float, float); \
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
_METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat))
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
REGISTER_FLOAT_BINARY_OP(copysign);
REGISTER_INT2FLOAT_BINARY_OP(copysign);
@ -447,11 +411,9 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
#endif
// Complex binary functions
REGISTER_BINARY_OP(polar, float, float2);

View File

@ -180,10 +180,8 @@ REGISTER_SEARCHSORTED_OP(float, int);
REGISTER_SEARCHSORTED_OP(float, long);
REGISTER_SEARCHSORTED_OP(half, int);
REGISTER_SEARCHSORTED_OP(half, long);
#if __METAL_VERSION__ >= 310
REGISTER_SEARCHSORTED_OP(bfloat, int);
REGISTER_SEARCHSORTED_OP(bfloat, long);
#endif
REGISTER_SEARCHSORTED_OP(char, int);
REGISTER_SEARCHSORTED_OP(char, long);
REGISTER_SEARCHSORTED_OP(uchar, int);

View File

@ -96,6 +96,4 @@ kernel void col2im_kernel(
INSTANTIATE_COL2IM(bool);
INSTANTIATE_COL2IM(float);
INSTANTIATE_COL2IM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_COL2IM(bfloat);
#endif

View File

@ -20,9 +20,7 @@ REGISTER_CROSS_FUNC(short);
REGISTER_CROSS_FUNC(char);
REGISTER_CROSS_FUNC(uchar);
REGISTER_CROSS_FUNC(bool);
#if __METAL_VERSION__ >= 310
REGISTER_CROSS_FUNC(bfloat);
#endif
template <typename T, typename U>
kernel void cross(
@ -68,6 +66,4 @@ REGISTER_CROSS_OP(short);
REGISTER_CROSS_OP(char);
REGISTER_CROSS_OP(uchar);
REGISTER_CROSS_OP(bool);
#if __METAL_VERSION__ >= 310
REGISTER_CROSS_OP(bfloat);
#endif

View File

@ -1,11 +1,9 @@
#include <metal_stdlib>
using metal::max;
#if __METAL_VERSION__ >= 310
bfloat max(bfloat a, bfloat b) {
return a > b ? a : b;
}
#endif
#define kmaxThreadGroups 32
#define kmaxTensors 32
@ -306,11 +304,9 @@ REGISTER_ADAM_OPS_QUART(float, float);
REGISTER_ADAM_OPS_QUART(float, half);
REGISTER_ADAM_OPS_QUART(half, float);
REGISTER_ADAM_OPS_QUART(half, half);
#if __METAL_VERSION__ >= 310
REGISTER_ADAM_OPS_QUART(float, bfloat);
REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
REGISTER_ADAM_OPS_QUART(bfloat, float);
#endif
template <typename T>
inline void sgd_momentum_math(
@ -460,7 +456,5 @@ REGISTER_FUSED_SGD_OP(float);
REGISTER_FUSED_SGD_OP(half);
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_FUSED_SGD_OP(bfloat);
REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
#endif

View File

@ -106,9 +106,7 @@ kernel void polygamma(
constant int64_t& order [[buffer(2)]], \
uint id [[thread_position_in_grid]]);
#if __METAL_VERSION__ >= 310
INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat);
#endif
INSTANTIATE_GAMMA_KERNELS(half, half);
INSTANTIATE_GAMMA_KERNELS(float, float);
INSTANTIATE_GAMMA_KERNELS(bool, float);

View File

@ -76,6 +76,4 @@ INSTANTIATE_IM2COL(float);
INSTANTIATE_IM2COL(float2);
INSTANTIATE_IM2COL(half);
INSTANTIATE_IM2COL(half2);
#if __METAL_VERSION__ >= 310
INSTANTIATE_IM2COL(bfloat);
#endif

View File

@ -240,9 +240,7 @@ REGISTER_INDEX_OP(put_accumulate, short, short);
REGISTER_INDEX_OP(put_accumulate, char, char);
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
REGISTER_INDEX_OP(put_accumulate, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
#endif
template <typename StridesT, typename DataT>
kernel void kernel_index_offsets(
@ -477,10 +475,8 @@ INSTANTIATE_INDEX_COPY(char, long);
INSTANTIATE_INDEX_COPY(uchar, int);
INSTANTIATE_INDEX_COPY(uchar, long);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INDEX_COPY(bfloat, int);
INSTANTIATE_INDEX_COPY(bfloat, long);
#endif
INSTANTIATE_INDEX_COPY(float2, int);
INSTANTIATE_INDEX_COPY(float2, long);
INSTANTIATE_INDEX_COPY(half2, int);

View File

@ -288,7 +288,6 @@ kernel void layer_norm_looped(
#define instantiate_layer_norm(DTYPE) \
instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE)
instantiate_layer_norm(float) instantiate_layer_norm(half)
#if __METAL_VERSION__ >= 310
instantiate_layer_norm(bfloat)
#endif
instantiate_layer_norm(float);
instantiate_layer_norm(half);
instantiate_layer_norm(bfloat);

View File

@ -635,9 +635,7 @@ kernel void applyPivots(
INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_NAIVE_MM(bfloat);
#endif
// Integral MM
INSTANTIATE_NAIVE_MM(short);

View File

@ -48,3 +48,14 @@ struct PoolingBackwardParams {
::c10::metal::array<idx_type_t, N> grad_output_strides;
::c10::metal::array<idx_type_t, N> indices_strides;
};
template <unsigned N = 5, typename idx_type_t = int32_t>
struct MaxUnpoolingParams {
int32_t dims;
int32_t pooling_dims;
::c10::metal::array<idx_type_t, N> input_sizes;
::c10::metal::array<idx_type_t, N> input_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N> indices_strides;
};

View File

@ -168,6 +168,16 @@ PoolOffsets find_pool_offsets(
leading_dims,
return_indices,
tid);
case 3:
return find_pool_offsets_dim_specific<3>(
output_sizes,
output_strides,
indices_strides,
input_strides,
pooling_dim_indices,
leading_dims,
return_indices,
tid);
}
return PoolOffsets();
}
@ -292,6 +302,68 @@ kernel void max_pool_backward(
pooling_dims);
}
template <typename T>
void max_unpool_impl(
device T* output,
T input_element,
int32_t input_index,
constant int32_t* output_sizes,
constant int32_t* output_strides,
int32_t pooling_dims) {
int32_t size_prod = 1;
int32_t pool_offset = 0;
for (auto dim = pooling_dims - 1; dim >= 0; dim--) {
auto next_size_prod = output_sizes[dim] * size_prod;
pool_offset +=
output_strides[dim] * ((input_index % next_size_prod) / size_prod);
size_prod *= output_sizes[dim];
}
output[pool_offset] = input_element;
}
// Kernel computes one element of the grad input per kernel call.
template <typename T>
kernel void max_unpool(
device T* output [[buffer(0)]],
constant T* input [[buffer(1)]],
constant int64_t* indices [[buffer(2)]],
constant MaxUnpoolingParams<5>& params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto pooling_dims = params.pooling_dims;
auto dims = params.dims;
auto input_sizes = params.input_sizes.data();
auto input_strides = params.input_strides.data();
auto output_sizes = params.output_sizes.data();
auto output_strides = params.output_strides.data();
auto indices_strides = params.indices_strides.data();
auto leading_dims = dims - pooling_dims;
// NOTE: Since we're doing unpooling, the variable names "input" and "output"
// are reversed compared to the pooling operations. So in `find_pool_offsets`,
// we need to map "input" -> "output" and "output" -> "input".
PoolOffsets offsets = find_pool_offsets(
/*output_sizes=*/input_sizes,
/*output_strides=*/input_strides,
indices_strides,
/*input_strides=*/output_strides,
/*pooling_dim_indices=*/nullptr,
dims,
leading_dims,
/*return_indices=*/true,
tid);
max_unpool_impl<T>(
output + offsets.input_leading,
input[offsets.output],
indices[offsets.indices],
output_sizes + leading_dims,
output_strides + leading_dims,
pooling_dims);
}
template <typename T>
struct AvgPoolIterBounds {
T start;
@ -358,7 +430,6 @@ void avg_pool_3d_input_iter(
auto divisor = has_divisor_override
? divisor_override
: (bounds0.count) * (bounds1.count) * (bounds2.count);
auto size12 = input_sizes[1] * input_sizes[2];
for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
auto offset0 = input_strides[0] * i0;
@ -376,6 +447,64 @@ void avg_pool_3d_input_iter(
*output = value_sum / static_cast<T>(divisor);
}
template <typename T>
void avg_pool_backward_3d_input_iter(
device AtomicType_t<T>* grad_input,
constant T* grad_output,
constant int32_t* grad_input_sizes,
constant int32_t* grad_input_strides,
int32_t grad_input_leading_offset,
thread int32_t (&pooling_dim_indices)[3],
constant int32_t* kernel_size,
constant int32_t* stride,
constant int32_t* padding,
bool count_include_pad,
bool has_divisor_override,
int32_t divisor_override) {
auto bounds0 = get_avg_pool_input_iter_bounds<0>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto bounds1 = get_avg_pool_input_iter_bounds<1>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto bounds2 = get_avg_pool_input_iter_bounds<2>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto divisor = has_divisor_override
? divisor_override
: (bounds0.count) * (bounds1.count) * (bounds2.count);
auto grad_val = *grad_output / static_cast<T>(divisor);
for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
auto offset0 = grad_input_strides[0] * i0;
for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
auto offset1 = grad_input_strides[1] * i1;
for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
auto offset2 = grad_input_strides[2] * i2;
auto pool_offset = offset0 + offset1 + offset2;
AtomicType<T>::atomic_add(
grad_input, grad_input_leading_offset + pool_offset, grad_val);
}
}
}
}
// Kernel computes one element of the output per kernel call.
template <typename T>
kernel void avg_pool(
@ -428,31 +557,97 @@ kernel void avg_pool(
params.divisor_override);
}
#define REGISTER_POOL_OP(DTYPE) \
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant PoolingParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant AvgPoolingParams<5> & params [[buffer(2)]], \
template <typename T>
kernel void avg_pool_backward(
device AtomicType_t<T>* grad_input [[buffer(0)]],
constant T* grad_output [[buffer(1)]],
constant AvgPoolingParams<5>& params [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
auto pooling_dims = params.pooling_dims;
auto dims = params.dims;
auto grad_input_sizes = params.input_sizes.data();
auto grad_input_strides = params.input_strides.data();
auto grad_output_sizes = params.output_sizes.data();
auto grad_output_strides = params.output_strides.data();
auto kernel_size = params.kernel_size.data();
auto stride = params.stride.data();
auto padding = params.padding.data();
auto leading_dims = dims - pooling_dims;
// This buffer keeps track of the pooling dimension indices of this thread's
// element of the output. We need to fill it with the proper values below.
int32_t pooling_dim_indices[3];
PoolOffsets offsets = find_pool_offsets(
grad_output_sizes,
grad_output_strides,
/*indices_strides=*/nullptr,
grad_input_strides,
pooling_dim_indices,
dims,
leading_dims,
/*return_indices=*/false,
tid);
grad_output += offsets.output;
grad_input_sizes += leading_dims;
grad_input_strides += leading_dims;
avg_pool_backward_3d_input_iter<T>(
grad_input,
grad_output,
grad_input_sizes,
grad_input_strides,
offsets.input_leading,
pooling_dim_indices,
kernel_size,
stride,
padding,
params.count_include_pad,
params.has_divisor_override,
params.divisor_override);
}
#define REGISTER_POOL_OP(DTYPE) \
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant PoolingParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("max_unpool_" #DTYPE)]] kernel void max_unpool<DTYPE>( \
device DTYPE * output [[buffer(0)]], \
constant DTYPE * input [[buffer(1)]], \
constant int64_t* indices [[buffer(2)]], \
constant MaxUnpoolingParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant AvgPoolingParams<5> & params [[buffer(2)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
#define REGISTER_POOL_BACKWARD_OP(DTYPE) \
template [[host_name("max_pool_backward_" #DTYPE)]] \
kernel void max_pool_backward<DTYPE>( \
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
constant DTYPE * grad_output_ [[buffer(1)]], \
constant int64_t* grad_indices_ [[buffer(2)]], \
constant PoolingBackwardParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("avg_pool_backward_" #DTYPE)]] \
kernel void avg_pool_backward<DTYPE>( \
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
constant DTYPE * grad_output [[buffer(1)]], \
constant AvgPoolingParams<5> & params [[buffer(2)]], \
uint tid [[thread_position_in_grid]]);
REGISTER_POOL_OP(float);
REGISTER_POOL_OP(half);
REGISTER_POOL_OP(bfloat);
REGISTER_POOL_OP(int);
REGISTER_POOL_OP(long);
REGISTER_POOL_OP(short);
@ -460,10 +655,6 @@ REGISTER_POOL_OP(char);
REGISTER_POOL_OP(uchar);
REGISTER_POOL_OP(bool);
REGISTER_MAX_POOL_BACKWARD_OP(float);
REGISTER_MAX_POOL_BACKWARD_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_POOL_OP(bfloat);
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
#endif
REGISTER_POOL_BACKWARD_OP(float);
REGISTER_POOL_BACKWARD_OP(half);
REGISTER_POOL_BACKWARD_OP(bfloat);

View File

@ -197,12 +197,10 @@ INSTANTIATE_INT4MV(float, 128);
INSTANTIATE_INT4MV(half, 128);
INSTANTIATE_INT4MV(float, 256);
INSTANTIATE_INT4MV(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT4MV(bfloat, 32);
INSTANTIATE_INT4MV(bfloat, 64);
INSTANTIATE_INT4MV(bfloat, 128);
INSTANTIATE_INT4MV(bfloat, 256);
#endif
// ------------------------------ int8 MM For M >= 12 ------------------------------------
/**
@ -234,12 +232,10 @@ template <> struct BlockType<half> {
using simdgroup_type8x8 = simdgroup_half8x8;
using type4 = half4;
};
#if __METAL_VERSION__ >= 310
template <> struct BlockType<bfloat> {
using simdgroup_type8x8 = simdgroup_bfloat8x8;
using type4 = bfloat4;
};
#endif
template<typename T>
float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) {
@ -490,9 +486,7 @@ kernel void kernel_mul_mm<DTYPE, WDTYPE, DEQUANT_FUNC>( \
INSTANTIATE_MM(float, char, get_scale_zero_q8);
INSTANTIATE_MM(half, char, get_scale_zero_q8);
#if __METAL_VERSION__ >= 310
INSTANTIATE_MM(bfloat, char, get_scale_zero_q8);
#endif
// ------------------------------ int8 MM For M < 12 ------------------------------------
/* Matrix vector multiplication, used for small M size for matrix multiplication as well.
@ -646,6 +640,4 @@ kernel void kernel_mul_mv<DTYPE>(
INSTANTIATE_MV(float);
INSTANTIATE_MV(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_MV(bfloat);
#endif

View File

@ -192,6 +192,4 @@ template <typename T>
instantiate_rms(float)
instantiate_rms(half)
#if __METAL_VERSION__ >= 310
instantiate_rms(bfloat)
#endif // clang-format on

View File

@ -23,6 +23,4 @@ kernel void renorm(
REGISTER_RENORM_OP(float);
REGISTER_RENORM_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_RENORM_OP(bfloat);
#endif

View File

@ -25,379 +25,6 @@ struct LogAddExp {
};
};
#if __METAL_VERSION__ < 310
template <typename T, typename acc_t = accum_t<T>>
struct CumMinOp {
static acc_t apply(acc_t a, acc_t b) {
return metal::min(a, b);
}
static acc_t identity() {
return static_cast<acc_t>(
metal::is_floating_point_v<T> ? metal::numeric_limits<T>::infinity()
: metal::numeric_limits<T>::max());
}
};
template <typename T, typename acc_t = accum_t<T>>
struct CumMaxOp {
static acc_t apply(acc_t a, acc_t b) {
return metal::max(a, b);
}
static acc_t identity() {
return static_cast<acc_t>(
metal::is_floating_point_v<T> ? -metal::numeric_limits<T>::infinity()
: metal::numeric_limits<T>::lowest());
}
};
template <typename T, typename acc_t = accum_t<T>>
struct LogCumSumExpOp {
static acc_t apply(acc_t x, acc_t y) {
return LogAddExp{}(x, y);
}
static acc_t identity() {
return -metal::numeric_limits<acc_t>::infinity();
}
};
// Inclusive scan along innermost dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_contiguous_innermost_dim(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant uint& num_rows [[buffer(2)]],
constant uint& row_size [[buffer(3)]],
uint row [[thread_position_in_grid]]) {
if (row >= num_rows)
return;
const uint offset = row * row_size;
acc_t accumulator = Op::identity();
for (uint col = 0; col < row_size; col++) {
T val = input[offset + col];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[offset + col] = static_cast<T>(accumulator);
}
}
// Inclusive scan along outer dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_contiguous_outer_dim(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant uint& num_orows [[buffer(2)]],
constant uint& num_irows [[buffer(3)]],
constant uint& row_size [[buffer(4)]],
uint thread_index [[thread_position_in_grid]]) {
const uint orow = thread_index / num_irows;
const uint irow = thread_index % num_irows;
if (orow >= num_orows)
return;
acc_t accumulator = Op::identity();
const uint idx_base = orow * row_size * num_irows + irow;
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
T val = input[idx];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[idx] = static_cast<T>(accumulator);
}
}
// Inclusive scan with indices along innermost dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_contiguous_innermost_dim(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant uint& num_rows [[buffer(3)]],
constant uint& row_size [[buffer(4)]],
uint row [[thread_position_in_grid]]) {
if (row >= num_rows)
return;
const uint offset = row * row_size;
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
for (uint col = 0; col < row_size; col++) {
T val = input[offset + col];
acc_t accum_val = static_cast<acc_t>(val);
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = col;
}
values[offset + col] = static_cast<T>(accumulator);
indices[offset + col] = best_idx;
}
}
// Inclusive scan with indices along outer dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_contiguous_outer_dim(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant uint& num_orows [[buffer(3)]],
constant uint& num_irows [[buffer(4)]],
constant uint& row_size [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]) {
const uint orow = thread_index / num_irows;
const uint irow = thread_index % num_irows;
if (orow >= num_orows)
return;
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
const uint idx_base = orow * row_size * num_irows + irow;
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
T val = input[idx];
acc_t accum_val = static_cast<acc_t>(val);
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = col;
}
values[idx] = static_cast<T>(accumulator);
indices[idx] = best_idx;
}
}
// Shared utility functions for strided kernels
inline long calculate_non_scan_elements(
constant long* sizes,
uint ndim,
uint scan_dim) {
long total = 1;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
total *= sizes[i];
}
}
return total;
}
inline void thread_index_to_coordinates(
uint index,
int pos[c10::metal::max_ndim],
constant long* sizes,
uint ndim,
uint scan_dim) {
long remaining_index = index;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
pos[i] = remaining_index % sizes[i];
remaining_index /= sizes[i];
} else {
pos[i] = 0;
}
}
}
inline long calculate_base_offset(
int pos[c10::metal::max_ndim],
constant long* strides,
uint ndim,
uint scan_dim) {
long offset = 0;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
offset += pos[i] * strides[i];
}
}
return offset;
}
// Generic strided scan kernel
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_strided(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant long* sizes [[buffer(2)]],
constant long* input_strides [[buffer(3)]],
constant long* output_strides [[buffer(4)]],
constant uint& ndim [[buffer(5)]],
constant uint& scan_dim [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
const long total_non_scan_elements =
calculate_non_scan_elements(sizes, ndim, scan_dim);
if (thread_index >= total_non_scan_elements) {
return;
}
int pos[c10::metal::max_ndim];
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
const long input_base_offset =
calculate_base_offset(pos, input_strides, ndim, scan_dim);
const long output_base_offset =
calculate_base_offset(pos, output_strides, ndim, scan_dim);
acc_t accumulator = Op::identity();
const long scan_size = sizes[scan_dim];
const long input_scan_stride = input_strides[scan_dim];
const long output_scan_stride = output_strides[scan_dim];
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
const long output_offset =
output_base_offset + scan_idx * output_scan_stride;
T val = input[input_offset];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[output_offset] = static_cast<T>(accumulator);
}
}
// Generic strided scan with indices kernel
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_strided(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant long* sizes [[buffer(3)]],
constant long* input_strides [[buffer(4)]],
constant long* values_strides [[buffer(5)]],
constant long* indices_strides [[buffer(6)]],
constant uint& ndim [[buffer(7)]],
constant uint& scan_dim [[buffer(8)]],
uint thread_index [[thread_position_in_grid]]) {
const long total_non_scan_elements =
calculate_non_scan_elements(sizes, ndim, scan_dim);
if (thread_index >= total_non_scan_elements) {
return;
}
int pos[c10::metal::max_ndim];
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
const long input_base_offset =
calculate_base_offset(pos, input_strides, ndim, scan_dim);
const long values_base_offset =
calculate_base_offset(pos, values_strides, ndim, scan_dim);
const long indices_base_offset =
calculate_base_offset(pos, indices_strides, ndim, scan_dim);
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
const long scan_size = sizes[scan_dim];
const long input_scan_stride = input_strides[scan_dim];
const long values_scan_stride = values_strides[scan_dim];
const long indices_scan_stride = indices_strides[scan_dim];
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
const long values_offset =
values_base_offset + scan_idx * values_scan_stride;
const long indices_offset =
indices_base_offset + scan_idx * indices_scan_stride;
T val = input[input_offset];
acc_t accum_val = static_cast<acc_t>(val);
if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = scan_idx;
}
values[values_offset] = static_cast<T>(accumulator);
indices[indices_offset] = best_idx;
}
}
#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
scan_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant uint & num_rows [[buffer(2)]], \
constant uint & row_size [[buffer(3)]], \
uint row [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
scan_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant uint & num_orows [[buffer(2)]], \
constant uint & num_irows [[buffer(3)]], \
constant uint & row_size [[buffer(4)]], \
uint thread_index [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
scan_strided<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant long* sizes [[buffer(2)]], \
constant long* input_strides [[buffer(3)]], \
constant long* output_strides [[buffer(4)]], \
constant uint& ndim [[buffer(5)]], \
constant uint& scan_dim [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]]);
#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
scan_with_indices_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant uint& num_rows [[buffer(3)]], \
constant uint& row_size [[buffer(4)]], \
uint row [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
scan_with_indices_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant uint& num_orows [[buffer(3)]], \
constant uint& num_irows [[buffer(4)]], \
constant uint& row_size [[buffer(5)]], \
uint thread_index [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
scan_with_indices_strided<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant long* sizes [[buffer(3)]], \
constant long* input_strides [[buffer(4)]], \
constant long* values_strides [[buffer(5)]], \
constant long* indices_strides [[buffer(6)]], \
constant uint& ndim [[buffer(7)]], \
constant uint& scan_dim [[buffer(8)]], \
uint thread_index [[thread_position_in_grid]]);
// Simple scan operations
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float);
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half);
// Scan operations with indices
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);
#else // __METAL_VERSION__ >= 310
C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
// The reminder of this file contains cummin and cummax implementations adapted
@ -1159,5 +786,3 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4);
#endif

View File

@ -89,6 +89,4 @@ REGISTER_SPECIAL(short, float);
REGISTER_SPECIAL(int, float);
REGISTER_SPECIAL(long, float);
REGISTER_SPECIAL(half, half);
#if __METAL_VERSION__ >= 310
REGISTER_SPECIAL(bfloat, bfloat);
#endif

View File

@ -100,9 +100,7 @@ kernel void triul(
INSTANTIATE_TRIUL_KERNELS(float, int);
INSTANTIATE_TRIUL_KERNELS(half, int);
#if __METAL_VERSION__ >= 310
INSTANTIATE_TRIUL_KERNELS(bfloat, int);
#endif
INSTANTIATE_TRIUL_KERNELS(float2, int);
INSTANTIATE_TRIUL_KERNELS(half2, int);

View File

@ -556,11 +556,9 @@ REGISTER_UNARY_OP(abs, half, half);
REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0)
#if __METAL_VERSION__ >= 310
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
REGISTER_UNARY_OP(neg, bfloat, bfloat);
REGISTER_UNARY_OP(abs, bfloat, bfloat);
#endif
INSTANTIATE_UNARY_KERNELS2(half, half);
INSTANTIATE_UNARY_KERNELS2(float, float);
INSTANTIATE_UNARY_KERNELS2(float, bool);
@ -600,6 +598,4 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float);
REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float);
REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat);
#endif

View File

@ -70,6 +70,4 @@ kernel void unfold_backward(
INSTANTIATE_UNFOLD_BACKWARD(float);
INSTANTIATE_UNFOLD_BACKWARD(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_UNFOLD_BACKWARD(bfloat);
#endif

View File

@ -852,6 +852,4 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar);
INSTANTIATE_UPSAMPLE_3D(uchar);
INSTANTIATE_UPSAMPLE_ALL(float);
INSTANTIATE_UPSAMPLE_ALL(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_UPSAMPLE_ALL(bfloat);
#endif

View File

@ -14,6 +14,7 @@
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#include <ATen/ops/avg_pool3d_backward_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#include <ATen/ops/max_pool2d_backward_native.h>
#include <ATen/ops/max_pool2d_native.h>
@ -21,6 +22,8 @@
#include <ATen/ops/max_pool2d_with_indices_native.h>
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
#include <ATen/ops/max_pool3d_with_indices_native.h>
#include <ATen/ops/max_unpool2d_native.h>
#include <ATen/ops/max_unpool3d_native.h>
#endif
namespace at::native {
@ -492,6 +495,60 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
});
}
static void max_unpool_out_mps_template(const Tensor& input,
const Tensor& indices,
IntArrayRef output_size_,
IntArrayRef stride,
IntArrayRef padding,
Tensor& output,
const int32_t pooling_dims,
const std::string& op_name) {
auto dims = input.dim();
auto leading_dims = input.dim() - pooling_dims;
const auto memory_format = input.suggest_memory_format();
std::vector<int64_t> output_size(dims);
for (int dim : c10::irange(leading_dims)) {
output_size[dim] = input.sizes()[dim];
}
for (int dim : c10::irange(pooling_dims)) {
output_size[leading_dims + dim] = output_size_[dim];
}
output.resize_(output_size, memory_format);
output.fill_(0);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const auto numThreads = input.numel();
MaxUnpoolingParams<5> params;
params.dims = dims;
params.pooling_dims = pooling_dims;
for (const auto dim : c10::irange(dims)) {
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
}
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto PSO = lib.getPipelineStateForFunc("max_unpool_" + scalarToMetalTypeString(input));
getMPSProfiler().beginProfileKernel(PSO, op_name, {input});
[computeEncoder setComputePipelineState:PSO];
mtl_setArgs(computeEncoder, output, input, indices, params);
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
getMPSProfiler().endProfileKernel(PSO);
}
});
}
static void avg_pool2d_template(const Tensor& input,
const Tensor& output,
const std::optional<Tensor>& grad_output_opt,
@ -669,6 +726,64 @@ static void avg_pool_out_mps_template(const Tensor& output,
});
}
static void avg_pool_backward_out_mps_template(const Tensor& grad_input,
const Tensor& input,
const Tensor& grad_output,
IntArrayRef _kernel_size,
IntArrayRef _stride,
IntArrayRef _padding,
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override,
const int32_t pooling_dims,
const std::string& op_name) {
auto [dims, _, kernel_size, stride, padding, __] =
process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name);
const auto memory_format = input.suggest_memory_format();
grad_input.resize_(input.sizes(), memory_format);
grad_input.fill_(0);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const auto numThreads = grad_output.numel();
AvgPoolingParams<5> params;
params.dims = dims;
params.pooling_dims = pooling_dims;
params.count_include_pad = count_include_pad;
params.has_divisor_override = divisor_override.has_value();
if (divisor_override.has_value()) {
params.divisor_override = safe_downcast<int32_t, int64_t>(divisor_override.value());
}
for (const auto dim : c10::irange(dims)) {
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_output.size(dim));
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(grad_output.stride(dim));
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_input.size(dim));
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(grad_input.stride(dim));
}
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto PSO = lib.getPipelineStateForFunc("avg_pool_backward_" + scalarToMetalTypeString(input));
getMPSProfiler().beginProfileKernel(PSO, op_name, {grad_output});
[computeEncoder setComputePipelineState:PSO];
mtl_setArgs(computeEncoder, grad_input, grad_output, params);
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
getMPSProfiler().endProfileKernel(PSO);
}
});
}
} // namespace mps
Tensor mps_max_pool2d(const Tensor& input,
@ -896,6 +1011,68 @@ Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output,
return grad_input;
}
Tensor& max_unpooling2d_forward_out_mps(const Tensor& self,
const Tensor& indices,
IntArrayRef output_size,
Tensor& output) {
mps::max_unpool_out_mps_template(self,
indices,
output_size,
/*stride=*/{},
/*padding=*/{},
output,
/*pooling_dims=*/2,
"max_unpool2d");
return output;
}
Tensor max_unpooling2d_forward_mps(const Tensor& self, const Tensor& indices, IntArrayRef output_size) {
auto output = at::empty({0}, self.options());
mps::max_unpool_out_mps_template(self,
indices,
output_size,
/*stride=*/{},
/*padding=*/{},
output,
/*pooling_dims=*/2,
"max_unpool2d");
return output;
}
Tensor& max_unpooling3d_forward_out_mps(const Tensor& self,
const Tensor& indices,
IntArrayRef output_size,
IntArrayRef stride,
IntArrayRef padding,
Tensor& output) {
mps::max_unpool_out_mps_template(self,
indices,
output_size,
stride,
padding,
output,
/*pooling_dims=*/3,
"max_unpool3d");
return output;
}
Tensor max_unpooling3d_forward_mps(const Tensor& self,
const Tensor& indices,
IntArrayRef output_size,
IntArrayRef stride,
IntArrayRef padding) {
auto output = at::empty({0}, self.options());
mps::max_unpool_out_mps_template(self,
indices,
output_size,
stride,
padding,
output,
/*pooling_dims=*/3,
"max_unpool3d");
return output;
}
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
(const Tensor& input,
int64_t kH,
@ -965,4 +1142,26 @@ TORCH_IMPL_FUNC(avg_pool3d_out_mps)
"avg_pool3d");
}
TORCH_IMPL_FUNC(avg_pool3d_backward_out_mps)(const Tensor& grad_output,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override,
const Tensor& grad_input) {
mps::avg_pool_backward_out_mps_template(grad_input,
input,
grad_output,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
/*pooling_dims=*/3,
"avg_pool3d_backward");
}
} // namespace at::native

View File

@ -719,6 +719,7 @@
dispatch:
CPU, CUDA: all_out
MPS: all_out_mps
MTIA: all_out_mtia
- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -808,6 +809,7 @@
CPU, Meta: arange_out
CUDA: arange_cuda_out
MPS: arange_mps_out
MTIA: arange_mtia_out
cpp_no_default_args: ['step']
# This function is a temporary hack to allow tracing of arange like constructs with dynamic
@ -1889,7 +1891,10 @@
- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
dispatch:
CUDA: cudnn_batch_norm
autogen: cudnn_batch_norm.out
- func: cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
dispatch:
CUDA: cudnn_batch_norm_out
# NB: You can only use this if you used cudnn_batch_norm training=True
- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
@ -4182,11 +4187,13 @@
dispatch:
CPU: _int_mm_cpu
CUDA: _int_mm_cuda
XPU: _int_mm_xpu
- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: _int_mm_out_cpu
CUDA: _int_mm_out_cuda
XPU: _int_mm_out_xpu
- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
dispatch:
@ -4223,6 +4230,7 @@
- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
dispatch:
CPU: _weight_int8pack_mm_cpu
CUDA: _weight_int8pack_mm_cuda
MPS: _weight_int8pack_mm_mps
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
@ -7124,18 +7132,21 @@
dispatch:
CPU: _scaled_mm_cpu
CUDA: _scaled_mm_cuda
tags: needs_exact_strides
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CPU: _scaled_mm_out_cpu
CUDA: _scaled_mm_out_cuda
tags: needs_exact_strides
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_grouped_mm_cuda
tags: needs_exact_strides
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
variants: function
@ -10487,6 +10498,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_
CUDA: foreach_tensor_add_scalar_kernel_cuda_
MTIA: foreach_tensor_add_scalar_kernel_mtia_
autogen: _foreach_add.Scalar_out
- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
@ -10495,6 +10507,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow
CUDA: foreach_tensor_add_list_kernel_cuda
MTIA: foreach_tensor_add_list_kernel_mtia
- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10502,6 +10515,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_
CUDA: foreach_tensor_add_list_kernel_cuda_
MTIA: foreach_tensor_add_list_kernel_mtia_
autogen: _foreach_add.List_out
- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
@ -10532,6 +10546,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_
CUDA: foreach_tensor_add_tensor_kernel_cuda_
MTIA: foreach_tensor_add_tensor_kernel_mtia_
autogen: _foreach_add.Tensor_out
- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
@ -10592,6 +10607,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_
CUDA: foreach_tensor_mul_scalar_kernel_cuda_
MTIA: foreach_tensor_mul_scalar_kernel_mtia_
autogen: _foreach_mul.Scalar_out
- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]
@ -10600,6 +10616,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow
CUDA: foreach_tensor_mul_list_kernel_cuda
MTIA: foreach_tensor_mul_list_kernel_mtia
- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10607,6 +10624,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_
CUDA: foreach_tensor_mul_list_kernel_cuda_
MTIA: foreach_tensor_mul_list_kernel_mtia_
autogen: _foreach_mul.List_out
- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
@ -10630,6 +10648,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow
CUDA: foreach_tensor_mul_tensor_kernel_cuda
MTIA: foreach_tensor_mul_tensor_kernel_mtia
- func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10637,6 +10656,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_
CUDA: foreach_tensor_mul_tensor_kernel_cuda_
MTIA: foreach_tensor_mul_tensor_kernel_mtia_
autogen: _foreach_mul.Tensor_out
- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
@ -10933,6 +10953,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow
CUDA: foreach_tensor_addcmul_scalar_cuda
MTIA: foreach_tensor_addcmul_scalar_mtia
- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10954,6 +10975,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_
CUDA: foreach_tensor_addcmul_scalar_cuda_
MTIA: foreach_tensor_addcmul_scalar_mtia_
autogen: _foreach_addcmul.Scalar_out
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
@ -10978,6 +11000,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_abs_slow
CUDA: foreach_tensor_abs_cuda
MTIA: foreach_tensor_abs_mtia
- func: _foreach_abs_(Tensor(a!)[] self) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10985,6 +11008,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_abs_slow_
CUDA: foreach_tensor_abs_cuda_
MTIA: foreach_tensor_abs_mtia_
autogen: _foreach_abs.out
- func: _foreach_acos(Tensor[] self) -> Tensor[]
@ -11319,6 +11343,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_norm_slow
CUDA: foreach_tensor_norm_cuda
MTIA: foreach_tensor_norm_mtia
autogen: _foreach_norm.Scalar_out
- func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
@ -11491,6 +11516,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_sqrt_slow_
CUDA: foreach_tensor_sqrt_cuda_
MTIA: foreach_tensor_sqrt_mtia_
autogen: _foreach_sqrt.out
- func: _foreach_tan(Tensor[] self) -> Tensor[]
@ -11552,6 +11578,7 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_
CUDA: foreach_tensor_copy_list_kernel_cuda_
MTIA: foreach_tensor_copy_list_kernel_mtia_
autogen: _foreach_copy.out
- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
@ -11559,6 +11586,7 @@
variants: function
dispatch:
CompositeExplicitAutograd: _foreach_copy
MTIA: foreach_tensor_copy_list_kernel_mtia
- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
dispatch:
@ -12351,6 +12379,7 @@
dispatch:
CPU: avg_pool3d_backward_out_cpu
CUDA: avg_pool3d_backward_out_cuda
MPS: avg_pool3d_backward_out_mps
MkldnnCPU: mkldnn_avg_pool3d_backward_out
- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
@ -12476,24 +12505,28 @@
dispatch:
CPU: max_unpooling2d_forward_out_cpu
CUDA: max_unpooling2d_forward_out_cuda
MPS: max_unpooling2d_forward_out_mps
- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
python_module: nn
dispatch:
CPU: max_unpooling2d_forward_cpu
CUDA: max_unpooling2d_forward_cuda
MPS: max_unpooling2d_forward_mps
- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
dispatch:
CPU: max_unpooling3d_forward_out_cpu
CUDA: max_unpooling3d_forward_out_cuda
MPS: max_unpooling3d_forward_out_mps
- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
python_module: nn
dispatch:
CPU: max_unpooling3d_forward_cpu
CUDA: max_unpooling3d_forward_cuda
MPS: max_unpooling3d_forward_mps
- func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn

View File

@ -1,7 +1,7 @@
# generate a list of kernels, but not actually emit files at config stage
execute_process(
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
--api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
--api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
RESULT_VARIABLE ret
)
@ -11,7 +11,27 @@ endif()
execute_process(
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
--api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
--api fwd_splitkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_SPLITKV kernels via Python.")
endif()
execute_process(
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
--api fwd_appendkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_APPENDKV kernels via Python.")
endif()
execute_process(
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
--api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
RESULT_VARIABLE ret
)
@ -19,15 +39,29 @@ if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
endif()
# Generate the files for both fwd and bwd
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
# Generate the files for both fwd, fwd_splitkv, fwd_appendkv, and bwd
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.")
endif()
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_SPLITKV kernels.")
endif()
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_appendkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_APPENDKV kernels.")
endif()
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
)
@ -44,6 +78,22 @@ if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass")
endif()
execute_process(
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt"
RESULT_VARIABLE ret)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd_splitkv pass")
endif()
execute_process(
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt"
RESULT_VARIABLE ret)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd appendkv pass")
endif()
# Change make_kernel to make_kernel_pt for bwd
execute_process(
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt"

View File

@ -21,6 +21,8 @@ while IFS= read -r file; do
if [ -f "$file" ]; then
# Use sed to replace "make_kernel" with "make_kernel_pt" in place
sed -i 's/make_kernel/make_kernel_pt/g' "$file"
sed -i 's/\#include \"fmha_fwd.hpp\"/\#include \"fmha_fwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
sed -i 's/\#include \"fmha_bwd.hpp\"/\#include \"fmha_bwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
echo "Updated: $file"
else
echo "Skipping: $file (not found)"

View File

@ -1,100 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include <ck_tile/core.hpp>
#include <ck_tile/ops/fmha.hpp>
// keep sync with BlockAttentionBiasEnum
enum class bias_enum
{
no_bias = 0,
elementwise_bias = 1,
alibi = 2,
};
struct bias_info
{
bias_enum type;
/*
* simple dispatch logic
*
* if type == elementwise_bias:
* if rank_info == 0:
* bias is 1*1*s*s
* elif rank_info == 1:
* bias is 1*h*s*s
* elif rank_info == 2:
* bias is b*h*s*s
*
* elif type == alibi:
* if rank_info == 0:
* alibi in 1*h
* elif rank_info == 1:
* alibi in b*h
*/
int rank_info;
void serialize(std::ostream& os) const
{
if(type == bias_enum::no_bias)
os << "n";
else if(type == bias_enum::elementwise_bias)
{
os << "e";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
else if(type == bias_enum::alibi)
{
os << "alibi";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
}
static bias_info decode(std::string str)
{
bias_info info{bias_enum::no_bias, 0};
if(str == "0" || str == "n")
{
info.type = bias_enum::no_bias;
}
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
str.compare(0, 11, "elementwise") == 0)
{
info.type = bias_enum::elementwise_bias;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
str.compare(0, 5, "alibi") == 0)
{
info.type = bias_enum::alibi;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
{
bi.serialize(os);
return os;
}
};

View File

@ -1,457 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/ops/fmha.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <mask.hpp>
#include <bias.hpp>
#include <launch_kernel_pt.hpp>
#include <type_traits>
#include <utility>
#include <variant>
struct FmhaBwdFp16
{
};
struct FmhaBwdBf16
{
};
template <typename DataType>
struct FmhaBwdTypeConfig;
template <>
struct FmhaBwdTypeConfig<FmhaBwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using GemmDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::half_t;
using OGradDataType = ck_tile::half_t;
using QGradDataType = ck_tile::half_t;
using KGradDataType = ck_tile::half_t;
using VGradDataType = ck_tile::half_t;
using BiasGradDataType = ck_tile::half_t;
};
template <>
struct FmhaBwdTypeConfig<FmhaBwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using GemmDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::bf16_t;
using OGradDataType = ck_tile::bf16_t;
using QGradDataType = ck_tile::bf16_t;
using KGradDataType = ck_tile::bf16_t;
using VGradDataType = ck_tile::bf16_t;
using BiasGradDataType = ck_tile::bf16_t;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args, some will passed to karg, some will used to compute grids/blocks
struct fmha_bwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
const void* o_ptr;
const void* lse_ptr;
const void* do_ptr;
void* d_ptr;
void* rand_val_ptr;
void* dq_ptr;
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t max_seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
template <typename FmhaBwdDQDKDVKernel>
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.drop_seed_offset);
}
}();
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdOGradDotOKernel>
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
{
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed);
}
else
{ // create batch mode kernel arguments
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqlen_q,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_do,
args.batch_stride_o,
args.batch_stride_lsed);
}
}();
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
}
}();
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script
struct fmha_bwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api
};
template <int Version = 2>
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);

View File

@ -1,824 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/ops/fmha.hpp>
#include <bias.hpp>
#include <mask.hpp>
#include <rotary.hpp>
#include <launch_kernel_pt.hpp>
#include <type_traits>
#include <utility>
#include <variant>
struct FmhaFwdFp16
{
};
struct FmhaFwdBf16
{
};
struct FmhaFwdFp8
{
};
struct FmhaFwdBf8
{
};
struct FmhaFwdFp8Fp16
{
};
struct FmhaFwdFp8Bf16
{
};
template <typename DataType>
struct FmhaFwdTypeConfig;
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::half_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp8>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::fp8_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdBf8>
{
using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t;
using VDataType = ck_tile::bf8_t;
using BiasDataType = ck_tile::bf8_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf8_t;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args, some will passed to karg, some will used to compute grids/blocks
struct fmha_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
float scale_p;
float scale_o;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
bool s_randval;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
struct fmha_fwd_splitkv_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* lse_acc_ptr;
void* o_acc_ptr;
void* lse_ptr;
void* o_ptr;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
// nullptr.
const void* cache_batch_idx;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode: seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// or kargs.seqlen_k_ptr[b]
//
// batch mode (kvcache):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k_ptr[b]
// group mode (kvcache):
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
//
// when is_gappy=true:
// seqlen_k = kargs.seqlen_k_ptr[b]
// seqstart_k_ptr[b] now store local offset of each batch
//
// when is_gappy=false:
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// or kargs.seqlen_k_ptr[b]
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
ck_tile::index_t num_splits;
float scale_s;
float scale_p;
float scale_o;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_o_acc;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
};
struct fmha_fwd_appendkv_args
{
void* q_ptr;
void* k_ptr;
const void* knew_ptr;
void* v_ptr;
const void* vnew_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_knew;
ck_tile::index_t batch;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
ck_tile::index_t rotary_dim;
bool has_mask;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
ck_tile::index_t stride_v;
ck_tile::index_t stride_vnew;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_knew;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_vnew;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_knew;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_vnew;
};
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.scale_p,
args.scale_o,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.scale_p,
args.scale_o,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
if constexpr(FmhaKernel::kIsGroupMode)
{
dim3 grids = FmhaKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
return ck_tile::make_tuple(kargs, grids);
}
else
{
dim3 grids =
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
return ck_tile::make_tuple(kargs, grids);
}
}
template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.is_gappy,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_k, // only used for paged-kvcache
args.batch_stride_v, // only used for paged-kvcache
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.seqlen_q,
args.seqlen_k,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type);
}
}();
dim3 grids = Kernel::GridSize(
args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits);
return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel argumentszs
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.seqstart_q_ptr,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.seqlen_q,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.batch_stride_o,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqlen_q,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.has_mask,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.stride_q,
args.stride_k,
args.stride_knew,
args.stride_v,
args.stride_vnew,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_knew,
args.nhead_stride_v,
args.nhead_stride_vnew,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kN1_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kPadS_,
bool kPadDv_>
struct fmha_fwd_splitkv_combine_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
ck_tile::index_t kTileSizeS_,
ck_tile::index_t kTileSizeSk_,
ck_tile::index_t kTileSizeD_,
ck_tile::index_t kTileSizeDv_,
bool kIsVLayoutRowMajor_,
bool kPadS_,
bool kPadSk_,
bool kPadD_,
bool kPadDv_,
ck_tile::RotaryEmbeddingEnum RotaryEnum_,
bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSk = kPadSk_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool has_dropout;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
struct fmha_fwd_splitkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
fmha_fwd_splitkv_args,
const ck_tile::stream_config&);
struct fmha_fwd_appendkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_v_rowmajor;
rope_enum rope_type;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);

View File

@ -1,157 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include <ck_tile/core.hpp>
#include <ck_tile/ops/fmha.hpp>
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
no_mask = 0,
mask_top_left,
mask_bottom_right,
window_generic,
};
struct mask_info
{
mask_enum type;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
void serialize(std::ostream& os) const
{
if(type == mask_enum::no_mask)
os << "n";
else if(type == mask_enum::mask_top_left)
os << "t(" << left << ":" << right << ")";
else if(type == mask_enum::mask_bottom_right)
os << "b(" << left << ":" << right << ")";
else
{
os << "g(" << y << ":" << x << ")";
}
}
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
{
ck_tile::index_t x_total = seqlen_k;
ck_tile::index_t y_total = seqlen_q;
mask_info tmp;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string t = str.substr(0, found_0);
std::string v = str.substr(found_0 + 1);
if(t == "xt" || t == "xb")
{
// xformer style sliding window attn from top-left
ck_tile::index_t window_size = atoi(v.c_str());
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
if(window_size > 0)
{
left_size = window_size / 2;
right_size = window_size - 1 - left_size;
}
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, t == "xt");
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = left_size;
tmp.right = right_size;
}
else
{
auto found_1 = v.find(",");
if(found_1 == std::string::npos)
{
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
assert(0);
}
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
if(t == "t")
{
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
else
{
printf("not supported type %s, %s\n", t.c_str(), str.c_str());
assert(0);
}
}
}
else
{
auto set_causal_top_left = [&]() {
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
};
auto set_causal_bottom_right = [&]() {
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
};
if(str == "t")
set_causal_top_left();
else if(str == "b")
set_causal_bottom_right();
else
{
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::mask_top_left)
{
set_causal_top_left();
}
else if(tmp.type == mask_enum::mask_bottom_right)
{
set_causal_bottom_right();
}
}
}
return tmp;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;
}
};

View File

@ -22,6 +22,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
dtype,
false, // is_group_mode
true, // is_v_rowmajor
false, // has_logits_soft_cap
mask.type,
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
has_lse,
@ -85,6 +86,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
ck_tile::index_t stride_attn_bias = 0;
ck_tile::index_t batch_stride_bias = 0;
ck_tile::index_t nhead_stride_bias = 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
@ -94,7 +96,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
nhead_stride_bias = a_b.stride(1);
batch_stride_bias = a_b.stride(0);
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
@ -116,6 +117,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
0.0f, // logits_soft_cap
stride_q,
stride_k,
stride_v,
@ -139,6 +141,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
-1, // min_seqlen_q
p_dropout,
has_dropout_randval,
drop_seed_offset};

View File

@ -20,6 +20,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
dtype,
true, // is_group_mode
true, // is_v_rowmajor
false, // has_logits_soft_cap
mask.type,
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
has_lse,
@ -117,6 +118,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
0.0f, // logits_soft_cap
stride_q,
stride_k,
stride_v,
@ -140,6 +142,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
-1, // min_seqlen_q
p_dropout,
has_dropout_randval,
drop_seed_offset};

View File

@ -1,84 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep sync with RotaryEmbeddingEnum
enum class rope_enum
{
none = 0,
interleaved = 1,
half_rotated = 2,
};
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
ck_tile::index_t rotary_dim,
std::optional<unsigned> seed = std::nullopt)
{
// return dummy tensors if we won't apply RoPE at all
if(rotary_dim <= 0)
{
ck_tile::HostTensor<DataType> dummy({1, 1});
return std::make_tuple(dummy, dummy);
}
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
const ck_tile::index_t num_rows = seqlen * 2;
const ck_tile::index_t num_cols = rotary_dim / 2;
using std::begin, std::end;
ck_tile::HostTensor<float> angle({num_rows, num_cols});
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::cos(origin_value));
});
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::sin(origin_value));
});
return std::make_tuple(cos, sin);
}
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
const ck_tile::HostTensor<DataType>& sin,
ck_tile::index_t seqlen_offset,
ck_tile::index_t seqlen)
{
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
const ck_tile::index_t num_rows = seqlen;
const ck_tile::index_t num_cols = cos.get_length(1);
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
return std::make_tuple(cos_pt, sin_pt);
}

View File

@ -5,6 +5,12 @@ import os
import sys
# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]
# Note - hf and timm have their own version of this, torchbench does not
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> set[str]:
@ -17,6 +23,8 @@ def model_names(filename: str) -> set[str]:
if len(line_parts) == 1:
line_parts = line.split(",")
model_name = line_parts[0]
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
continue
names.add(model_name)
return names

View File

@ -9,6 +9,7 @@ import copy
import csv
import dataclasses
import functools
import gc
import importlib
import itertools
import json
@ -2387,6 +2388,7 @@ class BenchmarkRunner:
)
def warmup(fn, model, example_inputs, mode, niters=10):
gc.collect()
peak_mem = 0
start_stats = get_dynamo_stats()
try:
@ -2548,6 +2550,7 @@ class BenchmarkRunner:
return experiment(*self.maybe_cast(model, example_inputs))
def warmup(fn, model, example_inputs, mode, niters=5):
gc.collect()
peak_mem = 0
start_stats = get_dynamo_stats()
try:

View File

@ -106,6 +106,11 @@ finally:
# on A100 GPUs - 40 GB.
BATCH_SIZE_KNOWN_MODELS = {}
# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]
# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
# Get the list of models and their batch sizes
@ -116,6 +121,8 @@ with open(MODELS_FILENAME) as fh:
lines = [line.rstrip() for line in lines]
for line in lines:
model_name, batch_size = line.split(",")
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
continue
batch_size = int(batch_size)
BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
assert len(BATCH_SIZE_KNOWN_MODELS)

View File

@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1009000000,0.1
@ -82,7 +82,7 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,8787000000,0.1

1 add_loop_eager compile_time_instruction_count 3070000000 0.1
18 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3959000000 0.1
19 aotdispatcher_training_subclass_cpu compile_time_instruction_count 10650000000 0.1
20 mm_loop_inductor_gpu compile_time_instruction_count 4461000000 0.1
21 mm_loop_inductor_dynamic_gpu compile_time_instruction_count 8417000000 0.1
22 basic_NestedModule_eager compile_time_instruction_count 8348000000 8787000000 0.1
23 basic_InlineMod_eager compile_time_instruction_count 7464000000 0.1
24
82
83
84
85
86
87
88

View File

@ -39,13 +39,20 @@ finally:
from timm.models import create_model
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
with open(filename) as fh:
lines = fh.readlines()
lines = [line.rstrip() for line in lines]
for line in lines:
model_name, batch_size = line.split(" ")
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
continue
TIMM_MODELS[model_name] = int(batch_size)

View File

@ -224,7 +224,7 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
// check if the key is unrecognized.
if (device_config_parser_hook_) {
TORCH_CHECK(
keys_.find(key) != keys_.end(),
getKeys().find(key) != getKeys().end(),
"Unrecognized key '",
key,
"' in Accelerator allocator config.");

View File

@ -220,11 +220,24 @@ class C10_API AcceleratorAllocatorConfig {
return instance().last_allocator_settings_;
}
// Use `Construct On First Use Idiom` to avoid `Static Initialization Order`
// issue.
static std::unordered_set<std::string>& getMutableKeys() {
static std::unordered_set<std::string> keys{
"max_split_size_mb",
"max_non_split_rounding_mb",
"garbage_collection_threshold",
"roundup_power2_divisions",
"expandable_segments",
"pinned_use_background_threads"};
return keys;
}
// Returns the set of valid keys for the allocator configuration.
// This set is used to validate the presence and correctness of keys in
// device-specific configuration parsers.
static const std::unordered_set<std::string>& getKeys() {
return keys_;
return getMutableKeys();
}
// Registers a device-specific configuration parser hook and its key. This
@ -238,9 +251,10 @@ class C10_API AcceleratorAllocatorConfig {
std::function<void(const std::string&)>&& hook,
const std::unordered_set<std::string>& keys) {
device_config_parser_hook_ = std::move(hook);
auto& mutable_keys = getMutableKeys();
for (auto& key : keys) {
TORCH_CHECK(
keys_.insert(key).second,
mutable_keys.insert(key).second,
"Duplicated key '",
key,
"' found in device-specific configuration parser hook registration");
@ -326,17 +340,6 @@ class C10_API AcceleratorAllocatorConfig {
// their own environment configuration extensions.
inline static std::function<void(const std::string&)>
device_config_parser_hook_{nullptr};
// A set of valid configuration keys, including both common and
// device-specific options. This set is used to validate the presence and
// correctness of keys during parsing.
inline static std::unordered_set<std::string> keys_{
"max_split_size_mb",
"max_non_split_rounding_mb",
"garbage_collection_threshold",
"roundup_power2_divisions",
"expandable_segments",
"pinned_use_background_threads"};
};
C10_API inline void setAllocatorSettings(const std::string& env) {

View File

@ -0,0 +1,10 @@
#include <c10/core/CachingDeviceAllocator.h>
namespace c10 {
// Ensures proper DLL export of this pure virtual base class on Windows,
// since it's mainly used in other DLLs outside c10.dll.
DeviceAllocator::DeviceAllocator() = default;
DeviceAllocator::~DeviceAllocator() = default;
} // namespace c10

View File

@ -1,6 +1,7 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/core/Stream.h>
namespace c10::CachingDeviceAllocator {
@ -59,3 +60,55 @@ struct DeviceStats {
};
} // namespace c10::CachingDeviceAllocator
namespace c10 {
using CaptureId_t = unsigned long long;
// first is set if the instance is created by Graph mode capture_begin.
// second is set if the instance is created by Graph mode graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
struct C10_API DeviceAllocator : public c10::Allocator {
DeviceAllocator();
~DeviceAllocator() override;
// Returns true if the allocator has been properly initialized and is ready
// for use
virtual bool initialized() = 0;
// Releases all cached device memory from the specified memory pool back to
// the system
virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
// Associates a memory allocation with a stream to establish dependency
// tracking. Prevents memory reuse until all operations on the specified
// stream complete
virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0;
// Retrieves comprehensive memory statistics for the specified device,
// including allocation patterns, usage metrics
virtual CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) = 0;
// Resets cumulative allocation statistics for the specified device to zero
virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
// Resets peak memory usage statistics for the specified device
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
};
// This function is used to get the DeviceAllocator for a specific device type
// and keep backward compatibility with c10::GetAllocator.
C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) {
TORCH_CHECK(
t != DeviceType::CPU,
"getDeviceAllocator is not supported for CPU device type.");
auto* allocator = c10::GetAllocator(t);
auto* device_allocator = dynamic_cast<DeviceAllocator*>(allocator);
TORCH_INTERNAL_ASSERT(
device_allocator, "Allocator for ", t, " is not a DeviceAllocator.");
return device_allocator;
}
} // namespace c10

View File

@ -191,11 +191,17 @@ class C10_API Scalar {
isIntegral() const {
return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag;
}
bool isIntegral(bool includeBool) const {
return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag ||
(includeBool && isBoolean());
}
// See Note [Meaning of HAS_u]
bool isUnsigned() const {
return Tag::HAS_u == tag || (Tag::HAS_i == tag && v.i >= 0);
}
bool isComplex() const {
return Tag::HAS_z == tag;
}

View File

@ -19,25 +19,16 @@
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>
#include <unordered_map>
#include <torch/headeronly/core/ScalarType.h>
namespace c10 {
// dummy struct for uint1 to uint7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_uint1_7_t {};
// dummy struct for int1 to int7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_int1_7_t {};
// For the macros below:
// [dtype Macros note] For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information, you probably want one of the
@ -57,56 +48,6 @@ struct dummy_int1_7_t {};
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
_(c10::complex<float>, ComplexFloat) /* 9 */ \
_(c10::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */ \
_(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
_(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
_(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
_(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
_(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
_(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
_(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \
_(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \
_(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \
_(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \
_(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
@ -152,17 +93,6 @@ struct dummy_int1_7_t {};
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
enum class ScalarType : int8_t {
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
#undef DEFINE_ENUM_ST_ENUM_VAL_
Undefined,
NumOptions
};
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
namespace impl {
// These are used to map ScalarTypes to C++ types.

View File

@ -110,8 +110,22 @@ class C10_CUDA_API CUDAAllocatorConfig {
return instance().m_use_async_allocator;
}
// Use `Construct On First Use Idiom` to avoid `Static Initialization Order`
// issue.
static const std::unordered_set<std::string>& getKeys() {
return keys_;
static std::unordered_set<std::string> keys{
"backend",
// keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_cud"
"amalloc",
"pinned_use_cud"
"a_host_register",
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_hipmalloc",
"pinned_use_hip_host_register",
"pinned_num_register_threads"};
return keys;
}
static CUDAAllocatorConfig& instance() {
@ -163,18 +177,6 @@ class C10_CUDA_API CUDAAllocatorConfig {
std::atomic<bool> m_pinned_use_cuda_host_register{false};
std::atomic<bool> m_use_async_allocator{false};
std::atomic<bool> m_is_allocator_loaded{false};
inline static std::unordered_set<std::string> keys_{
"backend",
// keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_cud"
"amalloc",
"pinned_use_cud"
"a_host_register",
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_hipmalloc",
"pinned_use_hip_host_register",
"pinned_num_register_threads"};
};
// Keep this for backwards compatibility

Some files were not shown because too many files have changed in this diff Show More