Compare commits

..

135 Commits

Author SHA1 Message Date
2f10f1b888 add config 2025-10-02 12:49:06 -07:00
517f267085 add config 2025-10-02 12:47:06 -07:00
cac7242b91 add config 2025-10-02 10:42:53 -07:00
b54dc58cb5 add config 2025-10-02 09:38:22 -07:00
4efdd216bd add config 2025-10-02 09:36:24 -07:00
95654a32f5 add config 2025-10-02 09:35:34 -07:00
2f5c2ccf7a add config 2025-10-02 09:35:18 -07:00
813cae6074 add config 2025-10-02 09:34:38 -07:00
ef4730d5bb add config 2025-10-02 08:34:27 -07:00
3ad3df90c3 add config 2025-10-02 08:33:04 -07:00
257bf0e654 add config 2025-10-02 08:31:23 -07:00
02d16522d8 add config 2025-10-02 08:29:56 -07:00
e6d3372157 Would this work?
Why did it link with CUDA runtime here https://github.com/pytorch/pytorch/actions/runs/18187580810/job/51775137099?pr=164361#step:17:20314?
2025-10-02 03:21:49 -07:00
3f3d86adf2 Bring in https://github.com/vllm-project/vllm/pull/25730 2025-10-02 01:24:01 -07:00
58478b0ab8 Another tweak 2025-10-02 00:07:11 -07:00
98e554222f [no ci] Another tweak 2025-10-02 00:06:15 -07:00
700d608f4a Remove somewhat unnecessary change 2025-10-01 22:48:01 -07:00
1b27857415 Minor tweak 2025-10-01 22:44:50 -07:00
73995b1b5e add config 2025-10-01 18:04:39 -07:00
03d7c77071 add config 2025-10-01 16:07:45 -07:00
019d9cda40 add config 2025-10-01 15:53:13 -07:00
3620191a0a add config 2025-10-01 15:47:36 -07:00
5a722ca130 add config 2025-10-01 14:00:05 -07:00
8746e3cea2 add config 2025-10-01 12:45:52 -07:00
8cd1996b57 add config 2025-10-01 11:22:17 -07:00
73c23f3554 add config 2025-10-01 09:49:30 -07:00
1da3d6f595 add config 2025-10-01 08:32:45 -07:00
b1033789fe Use TMA loads always for Triton grouped MM kernel (#164256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164256
Approved by: https://github.com/ngimel
ghstack dependencies: #163895
2025-10-01 15:24:51 +00:00
07d896fa48 Revert "CUDACachingHostAllocatorImpl skip event query during capture (#164001)"
This reverts commit 4cf29004749714670fee9e7e3776778faf5ced25.

Reverted https://github.com/pytorch/pytorch/pull/164001 on behalf of https://github.com/yangw-dev due to failed internal error with multiple errors found: Not equal to tolerance rtol=0.1, atol=0.1.. ([comment](https://github.com/pytorch/pytorch/pull/164001#issuecomment-3356894787))
2025-10-01 15:11:21 +00:00
31681bcacc [PyTorch] Pull ARM's box-cox (#164152)
Summary:
ARM has provided an SVE128 box-cox implementation.

It uses the same underlying algorithm as the previous version, but it has better log and exp implementations.
These supplied mathematical functions have switches to adjust the precision/speed trade-off.

We've noted a slight precision improvement, along with roughly a 5% performance increase.

Before:

ZeroLambda1                                                61.66ns    16.22M
NonZeroLambda1                                            125.73ns     7.95M
NonZeroLambdaManyColumns                                    1.84ms    542.11
NonZeroLambdaEigenColumnar                                262.31us     3.81K
NonZeroLambdaEigenRowMajor                                275.17us     3.63K
NonZeroLambdaWithPyTorchColumnar                           97.43us    10.26K
NonZeroLambdaWithPyTorchRowMajor                           90.82us    11.01K
NonZeroLambdaWithPyTorchRowMajorFullBatch                  96.96us    10.31K
NonZeroLambdaBatch                                        151.84us     6.59K

After:

ZeroLambda1                                                57.85ns    17.29M
NonZeroLambda1                                            118.85ns     8.41M
NonZeroLambdaManyColumns                                    1.82ms    548.16
NonZeroLambdaEigenColumnar                                261.67us     3.82K
NonZeroLambdaEigenRowMajor                                274.53us     3.64K
NonZeroLambdaWithPyTorchColumnar                           89.12us    11.22K
NonZeroLambdaWithPyTorchRowMajor                           83.49us    11.98K
NonZeroLambdaWithPyTorchRowMajorFullBatch                  88.79us    11.26K
NonZeroLambdaBatch                                        144.74us     6.91K

Test Plan:
Correctness:

buck2 test @//mode/opt //koski/functions_contrib/df4ai/tests:batch_box_cox_test

Performance:

buck2 run @//mode/opt //koski/functions_contrib/df4ai/benchmark:boxcox_benchmark

Differential Revision:
D83485704

Privacy Context Container: L1196524

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164152
Approved by: https://github.com/ezyang
2025-10-01 15:00:03 +00:00
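For context, a minimal NumPy sketch of the Box-Cox transform that the kernel above vectorizes (reference math only, not the SVE128 kernel); the two branches correspond to the ZeroLambda/NonZeroLambda benchmark cases:

```python
import numpy as np

def box_cox(x: np.ndarray, lmbda: float) -> np.ndarray:
    """Reference Box-Cox transform; the SVE128 kernel computes the same
    formula with vectorized, precision-tunable log/exp approximations."""
    if lmbda == 0.0:
        return np.log(x)                       # "ZeroLambda" benchmark cases
    return (np.power(x, lmbda) - 1.0) / lmbda  # "NonZeroLambda" benchmark cases
```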
e901866dd7 Add a RECORD_FUNCTION for Python fallback so it shows in profile (#160573)
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160573
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2025-10-01 14:10:44 +00:00
70d1043bdf Fix non-TMA loads in grouped MM Triton kernel (#163895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163895
Approved by: https://github.com/lezcano
2025-10-01 12:21:13 +00:00
69fa26d9b4 Triton 3.5.x pin update (#164268)
Updates triton pin to latest: https://github.com/triton-lang/triton/commits/release/3.5.x/

This update contains 2 cherry-picks that remove Python 3.9 from the list of supported Python versions:
https://github.com/triton-lang/triton/pull/8288
https://github.com/triton-lang/triton/pull/8287
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164268
Approved by: https://github.com/aakhundov
2025-10-01 11:41:50 +00:00
d9c80ef97d Build and Install Arm Compute Library in manylinux docker image (#159737)
----

This PR is part of a series of PRs that aims to remove the `.ci/aarch64_linux` folder entirely, so that the aarch64 manylinux build happens as part of `.ci/manywheel/build.sh`, the same as for other platforms.

In this PR:

- We prebuild and install Arm Compute Library in the manylinux docker image (at /acl), instead of building it at build time for every PyTorch build. Also updated the jammy install path to /acl.
- We can therefore remove the build_ArmComputeLibrary functions from the CI build scripts.
- There is also some refactoring of install_openblas.sh and install_acl.sh to align them (similar formatting, similar variable names, same place for the version number update).
- We had 2 places defining the OpenBLAS version; this has been reduced to 1 (install_openblas.sh).
- ACL_VERSION and OPENBLAS_VERSION can now be overridden at the build.sh level by developers, but only 1 version of each is hardcoded for CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159737
Approved by: https://github.com/seemethere, https://github.com/aditew01
2025-10-01 11:33:51 +00:00
ac1bc51608 [dynamo] do not pop from framelocals dict in Python 3.10 (#164316)
Followup to https://github.com/pytorch/pytorch/pull/164038

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164316
Approved by: https://github.com/anijain2305
2025-10-01 10:20:46 +00:00
ed90040d33 Releases multicast object before releasing mapped buffers in CUDASymmetricMemory (#163750)
Fixes: https://github.com/pytorch/pytorch/issues/162429. In B200, cuMulticastUnbind can error if the mapped buffers are free'd before the multicast object is free'd. The only documentation I could find is here: e11d7f77c1/src/transport/nvls.cc (L113).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163750
Approved by: https://github.com/ngimel, https://github.com/Skylion007, https://github.com/kwen2501, https://github.com/nWEIdia, https://github.com/cyyever
ghstack dependencies: #163575
2025-10-01 09:07:48 +00:00
4dab208d97 Adds Issue#153109 as a test for CUDAPluggableAllocator (#163575)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163575
Approved by: https://github.com/ngimel
2025-10-01 09:07:48 +00:00
9fd53a2bdc Register MTIA kernel for all_all_out (#164293)
Reviewed By: srsuryadev

Differential Revision: D83517879

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164293
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-10-01 09:05:08 +00:00
17ab99463a [Easy] Add notes for setting up dev venv with specific Python version (#164214)
Resolves https://github.com/pytorch/pytorch/issues/164010#issuecomment-3340751377

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164214
Approved by: https://github.com/ezyang
ghstack dependencies: #162324
2025-10-01 08:25:13 +00:00
eca6ac2293 [BE][Easy] update CUDA and ROCm sources in nightly tool (#162324)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162324
Approved by: https://github.com/ezyang
2025-10-01 08:25:13 +00:00
12d4cb0122 Suppress FutureWarnings in torch.distributed.algorithms.ddp_comm_hooks (#163939)
Fixes #163938

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163939
Approved by: https://github.com/cyyever, https://github.com/kwen2501
2025-10-01 07:51:12 +00:00
590224f83c Improve repeat op to a single copy (#163842)
In #163455 , the `reshape` was not a pure view op.

The `permute` before it created a non-contiguous tensor, which would trigger a data copy during the reshape.

This PR improves the implementation by removing the `urtensor` intermediate tensor completely.
Simply expanding `xtensor` achieves the `repeat` effect.

Before this PR, there were two data copies (in `urtensor.copy_` and `urtensor.reshape`).
Now, there is only one data copy in the `.copy_()`.
Reshape would not copy data because it is on a contiguous tensor.

One more note is that we do want at least one copy, because we need to duplicate the elements for the repeats.
Users can then modify single elements in place without affecting the others.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163842
Approved by: https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-10-01 06:27:53 +00:00
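A minimal sketch of the single-copy strategy described above (illustrative only, not the ATen implementation; the helper name is made up):

```python
import torch

def repeat_via_expand(x: torch.Tensor, repeats) -> torch.Tensor:
    """Repeat by inserting a size-1 dim per original dim, expanding it to the
    repeat count (a view, no copy), then doing exactly one copy into a
    contiguous buffer so the final reshape is also a pure view."""
    assert len(repeats) == x.dim()
    expanded_shape, final_shape = [], []
    for size, rep in zip(x.shape, repeats):
        expanded_shape += [rep, size]
        final_shape.append(rep * size)
    view_shape = [1 if i % 2 == 0 else s for i, s in enumerate(expanded_shape)]
    xtensor = x.reshape(view_shape).expand(expanded_shape)       # no data copy
    out = torch.empty(expanded_shape, dtype=x.dtype, device=x.device)
    out.copy_(xtensor)                                           # the single copy
    return out.reshape(final_shape)                              # view on contiguous data

x = torch.arange(6).reshape(2, 3)
assert torch.equal(repeat_via_expand(x, (2, 3)), x.repeat(2, 3))
```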
cc8b14d09a [2/N] Simplify "in" operation for containers of a single item (#164323)
These issues are detected by ruff [FURB171](https://docs.astral.sh/ruff/rules/single-item-membership-test/#single-item-membership-test-furb171).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164323
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
2025-10-01 05:39:11 +00:00
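For illustration, the pattern FURB171 flags and its simplification (hypothetical names):

```python
def is_add(op: str) -> bool:
    # Flagged by ruff FURB171: membership test against a single-item container
    return op in ("add",)

def is_add_simplified(op: str) -> bool:
    # Equivalent direct comparison, no throwaway tuple
    return op == "add"

assert is_add("add") == is_add_simplified("add")
```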
96c3b9e275 [dynamo] Use strings instead of modules for fqn info tracking (#164272)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164272
Approved by: https://github.com/Skylion007, https://github.com/williamwen42, https://github.com/mlazos
2025-10-01 04:22:57 +00:00
9ddfc59b9b [BE] Delete stale non-ephemeral runners workarounds (#164285)
As all Windows runners are ephemeral, there is no need to clean up leftover processes
or uninstall PyTorch at the end of the test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164285
Approved by: https://github.com/Skylion007
2025-10-01 03:47:36 +00:00
6d4dfa0878 [CI] Push viable/strict/${time} tags (#164183)
Every time viable strict is updated
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164183
Approved by: https://github.com/seemethere
2025-10-01 03:41:10 +00:00
11ccb95ccb [PyTorch Pinned Allocator] Pinned memory stats and perf fixes around allocating blocks (#163777)
Summary: This diff adds bucket stats for pinned memory, plus a perf fix that avoids checking sizes when the background thread is enabled

Differential Revision: D83162186

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163777
Approved by: https://github.com/bbus
2025-10-01 03:28:58 +00:00
bd0907dc4c [BE][CI] Unify requirments (#163396)
The Linux, Windows, and macOS CI workflows should all use `.ci/docker/requirements-ci.txt`
TODOS:
 - Investigate why `choco install cmake` is needed to successfully detect MKL
 - Move `psutil` installation from specific scripts into requirements-ci.txt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163396
Approved by: https://github.com/Skylion007
2025-10-01 03:28:48 +00:00
8bb71c07c4 Skip symmetric memory tests calling _scaled_mm on CCC < 8.9 (#164251)
This avoids them failing on e.g. A100 GPUs with
> RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164251
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
2025-10-01 03:26:21 +00:00
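A hedged sketch of how such a hardware gate can be expressed with public PyTorch APIs (helper and class names are made up; the real tests use the shared test utilities, and the ROCm branch is simplified):

```python
import unittest
import torch

def supports_scaled_mm() -> bool:
    # torch._scaled_mm needs compute capability 9.0 or 8.9 (or ROCm MI300+);
    # on older NVIDIA parts such as A100 (SM 8.0) it raises the error quoted above.
    if not torch.cuda.is_available():
        return False
    if torch.version.hip is not None:
        return True  # simplification; a real check would inspect the gfx arch
    return torch.cuda.get_device_capability() >= (8, 9)

@unittest.skipUnless(supports_scaled_mm(), "requires compute capability >= 8.9")
class ScaledMMSymmetricMemoryTest(unittest.TestCase):
    def test_placeholder(self) -> None:
        pass
```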
fa90090735 Use dataclass features in two classes (#164221)
This PR completes two TODO items by using features of `dataclass`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164221
Approved by: https://github.com/Skylion007, https://github.com/mlazos

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-10-01 03:20:39 +00:00
591997490a [BE][Easy]: Add prims common TypeGuard (#164263)
Slightly improves typing by adding a TypeGuard.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164263
Approved by: https://github.com/albanD
2025-10-01 03:13:10 +00:00
531f3bf5e1 Adding check for square matrix for input tensor in matrix_exp backwar… (#163357)
…d op.

Fixes #146796

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163357
Approved by: https://github.com/lezcano
2025-10-01 03:12:30 +00:00
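An illustrative sketch of the invariant being checked: matrix_exp is only defined for (batched) square matrices, so the backward input must have equal trailing dimensions (the helper name is made up):

```python
import torch

def check_square(name: str, t: torch.Tensor) -> None:
    # matrix_exp and its backward are only defined for (batched) square matrices.
    torch._check(
        t.dim() >= 2 and t.size(-1) == t.size(-2),
        lambda: f"{name}: expected a (batched) square matrix, got shape {tuple(t.shape)}",
    )

a = torch.randn(4, 3, 3, requires_grad=True)
check_square("matrix_exp", a)
torch.linalg.matrix_exp(a).sum().backward()  # square input: fine
```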
2a5ce2feb4 Add algorithm in header (#164295)
Fixes #163307. Added ```#include <algorithm>``` to vulkan QueryPool for the std::for_each call

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164295
Approved by: https://github.com/Skylion007
2025-10-01 03:09:50 +00:00
3787a5a60e [export] Explicitly passing requires_grad to nn.Parameter() in deserialization (#164290)
Summary: `nn.Parameter()` by default has `requires_grad=True` and would cause issues when there are non-float parameters.

Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_non_float_weight

Differential Revision: D83598796

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164290
Approved by: https://github.com/angelayi
2025-10-01 02:55:20 +00:00
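An illustrative sketch of the issue (values made up): `nn.Parameter` defaults to `requires_grad=True`, which only floating-point/complex tensors support, so a non-float parameter needs the flag passed explicitly.

```python
import torch
from torch import nn

t = torch.tensor([1, 2, 3])                # non-float parameter data
# nn.Parameter(t)                          # raises: only float/complex tensors can require grad
p = nn.Parameter(t, requires_grad=False)   # explicit flag, as deserialization now passes
q = nn.Parameter(torch.randn(3))           # float case keeps the default requires_grad=True
```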
c66d18d24d [dynamo][sac] Support functools partial context_fn for sac (#164308)
Fixes https://github.com/pytorch/pytorch/issues/164300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164308
Approved by: https://github.com/Lucaskabela, https://github.com/soulitzer
2025-10-01 02:47:55 +00:00
e0f118585f skip non memory deps in memory estimator (#164294)
Differential Revision: [D83601030](https://our.internmc.facebook.com/intern/diff/D83601030)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164294
Approved by: https://github.com/mlazos
2025-10-01 02:44:58 +00:00
10a005e87f [torchfuzz] add layout operators (#164210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164210
Approved by: https://github.com/pianpwk
ghstack dependencies: #164034, #164209, #164211
2025-10-01 02:33:19 +00:00
1f3995cdc8 [torchfuzz] raise if Operator abstract method is not implemented (#164211)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164211
Approved by: https://github.com/pianpwk
ghstack dependencies: #164034, #164209
2025-10-01 02:33:19 +00:00
abfcce58a4 [torchfuzz] remove erroneous can_produce check (#164209)
can_produce is an abstract method that always returns false
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164209
Approved by: https://github.com/pianpwk
ghstack dependencies: #164034
2025-10-01 02:33:19 +00:00
5b1c39f5a1 Add smoke tests to verify that stable ABI FA3 wheel runs w/ newer torch (#163782)
Passing CI: https://github.com/pytorch/pytorch/actions/runs/18141589975/job/51635340255?pr=163782

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163782
Approved by: https://github.com/huydhn, https://github.com/mikaylagawarecki
2025-10-01 02:30:38 +00:00
8df3f2fa98 Revert new-test part of #163829 (#164259)
Summary:

New test sizes for `test_scaled_mm_vs_emulated_block_wise` all fail with

```
RuntimeError: Invalid scaling configuration
```

Disable these new tests for now (the remaining test is a parametrized
version of the original test case)

Test Plan:

`pytest test/test_scaled_matmul_cuda.py`

Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164259
Approved by: https://github.com/jananisriram
ghstack dependencies: #164266
2025-10-01 02:23:21 +00:00
7a9119948e Split scaled-mm tests into separate file (#164266)
Summary:

* Split scaled-mm-specific tests into `test/test_scaled_matmul.py`

Test Plan:

```
pytest test/test_matmul_cuda.py
pytest test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164266
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-10-01 02:23:21 +00:00
28c1d2f81b [aoti] AOTI mingw cross compilation (#163188)
To run this, you need to install `mingw64-gcc-c++` and download windows cuda library toolkit.

See design doc and demo instructions in https://docs.google.com/document/d/1iDaChqA5nNKkBFTzsdkmoomvQlXHbnlb1Z4yEp7xaJA/edit?tab=t.0

If cross_platform_target is windows, we do the following:

- do not link to `sleef`. This can be improved in the future if we need it. Currently I avoid it because that requires extra setup on the linux side
- Use `mingw64-gcc-c++` to compile
- Use `WINDOWS_CUDA_HOME` instead of `CUDA_HOME` when linking to cuda

```
 python test/inductor/test_aot_inductor_windows.py -k so
 ```

 Other changes:
 - de-couples compile_standalone config and dynamic link flag
 - create a new aot_inductor_mode config module, which is used to control configs in aot_inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163188
Approved by: https://github.com/desertfire
2025-10-01 02:22:06 +00:00
c4bbc6433e [PyTorch CCA] Add an API to get expandable segment sizes (#163771)
Summary: This diff adds an API to query the expandable segment size for each stream, so that we can use this info to warm up the segment in advance and don't incur any performance penalty during steady-state inference for new CUDA memory allocations.

Differential Revision: D76447308

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163771
Approved by: https://github.com/bbus
2025-10-01 02:16:58 +00:00
ad7e3c93b1 [ROCm][CD] librocroller.so missing from ROCm 7 wheel (#164244)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164244
Approved by: https://github.com/jeffdaily, https://github.com/Skylion007

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-01 00:02:34 +00:00
7f3dc45300 Migrate DeviceType to torch/headeronly (#163999)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163999
Approved by: https://github.com/mikaylagawarecki
2025-09-30 23:13:27 +00:00
ff715366aa [vllm hash update] update the pinned vllm hash (#164190)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vllm hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164190
Approved by: https://github.com/pytorchbot
2025-09-30 22:43:49 +00:00
60a4961ff4 [DTensor] Allow redistribute to Partial if src matches (#164253)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164253
Approved by: https://github.com/zpcore
2025-09-30 22:42:49 +00:00
bec6541d84 [CUDA][CUDAGraph] Reduce capture overhead in CUDA Graph memory reuse (#162186)
Previous work #158352 delivered CUDAGraph memory footprint reduction with no replay-time impact, but capture time regressed (up to 20× slower) due to repeated full-graph traversals. See previous benchmark results [here](https://github.com/pytorch/pytorch/pull/158352#issuecomment-3215947565)

This PR removes capture/reply overhead while preserving the memory savings:

1. **Terminals as free markers**
   We stop inserting empty nodes and instead record the current stream terminals as free markers. This avoids mutating the user’s graph and keeps semantics unchanged.

2. **Incremental, cached reachability**
   We add a **per-graph reuse context** that caches reverse-traversal state:

   * `graph_reuse_context[graph].visited[stream]` tracks nodes already seen from that stream’s terminal frontier.
   * On each allocation during capture, we resume traversal from the latest terminals and only visit unseen nodes.
   * A block is freed when all its recorded markers are in the visited set of its allocation stream—i.e., all markers are proven predecessors of future work.

See [the performance results here](https://docs.google.com/spreadsheets/d/e/2PACX-1vRPvdd9Xa8W87ixbiA0da_qvOhrUAjUpFz0G-_j-MsDnoeRyhEa4_ut_W3rqcg1VVZVFJ-gucwov-3b/pubhtml?gid=1468302443&single=true), we sweep synthetic multi-stream CUDA Graphs built by `capture_benchmark.py` (same as before, we generate random interleaving of alloc/free/join with given probabilities, see [gist here](https://gist.github.com/eee4017/e2092d215b1d4bd46534148939af39e3)), and we compare median capture/replay times and memory. On an NVIDIA H100 PCIe across 24 configs, the optimization preserves reserved memory reduction at ~24–98%, leaves allocated memory unchanged, and brings capture time back to baseline (range 0.96–1.04× vs. baseline) with replay time unchanged (range 0.97–1.11×).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162186
Approved by: https://github.com/eqy, https://github.com/ngimel
2025-09-30 22:28:46 +00:00
1f1de20ba9 [c10d][BE][ez] Update tensor ptr inside nccl.cpp (#164276)
This is mostly a cosmetic change that replaces the deprecated `data_ptr` API with the mutable or const variant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164276
Approved by: https://github.com/Skylion007, https://github.com/eqy, https://github.com/kwen2501
2025-09-30 22:05:12 +00:00
2810977d3a [FSDP][Replicate] tests replicate type casting behavior and edge cases in mixed precision (#162861)
**Summary:** Ensures that replicate can handle the same type casting behavior and edge cases that fully shard can when mixed precision is used

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_float16_on_one_submodule
2. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_submodules_with_external_inputs
3. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_bf16
4. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_fp16
5. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_clamp_reduce_dtype
6. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_dataclass_input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162861
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851, #162853, #162855
2025-09-30 22:03:23 +00:00
ae4fd4ea75 [FSDP2] support AC(FSDP) for torchtitan's MOE (#164009)
for fsdp2 + EP, titan has fully_shard(AC(layer)) and fully_shard(layer.moe.experts): https://github.com/pytorch/torchtitan/issues/1624

for implicit prefetching, backward order is
* _pre_backward unshard (norm, output)
* _backward_prefetch unshard layers.6
* post_backward reshard (norm, output)
* _pre_backward unshard layers.6 (no-op, unsharded already)
* _backward_prefetch unshard layers.6.moe.experts
* recompute_fn pre_forward unshard layers.6.moe.experts (no-op, unsharded already)
* ~~recompute_fn post_forward reshard layers.6.moe.experts~~ <----- this PR make it a no-op
* _pre_backward unshard layers.6.moe.experts (no-op, unsharded already)
* _backward_prefetch unshard layers.5
* post_backward reshard layers.6.moe.experts
* post_backward reshard layers.6

unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_comm.py -k test_set_modules_to_backward_prefetch_inside_ac`

before fix: `NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2`
```
[rank0]:[titan] 2025-09-30 11:43:01,714 - root - INFO - step:  1  loss: 12.0162  grad_norm:  1.7315  memory: 45.64GiB(48.05%)  tps: 1,028  tflops: 10.87  mfu: 1.10%
[rank0]:[titan] 2025-09-30 11:43:01,714 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-30 11:43:35,233 - root - INFO - [GC] Performing periodical GC collection 0.06 seconds
[rank0]:[titan] 2025-09-30 11:43:35,987 - root - INFO - step: 50  loss:  6.9302  grad_norm:  0.9985  memory: 59.66GiB(62.80%)  tps: 11,712  tflops: 123.89  mfu: 12.53%
```

after fix: `NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2`
```
[rank0]:[titan] 2025-09-30 11:38:57,377 - root - INFO - step:  1  loss: 12.0134  grad_norm:  1.6916  memory: 38.42GiB(40.45%)  tps: 805  tflops: 8.51  mfu: 0.86%
[rank0]:[titan] 2025-09-30 11:38:57,377 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-30 11:39:28,541 - root - INFO - [GC] Performing periodical GC collection 0.06 seconds
[rank0]:[titan] 2025-09-30 11:39:29,279 - root - INFO - step: 50  loss:  6.9346  grad_norm:  1.1875  memory: 52.58GiB(55.36%)  tps: 12,583  tflops: 133.10  mfu: 13.46%
```

for explicit prefetching, layers.6 backward prefetch layers.5 and layers.5.moe.experts. layers.6.moe.experts does not have explicit prefetch. backward order is like this
* _pre_backward unshard (norm, output)
* _prefetch_unshard layers.6
* post_backward reshard (norm, output)
* _pre_backward unshard layers.6 (no-op, unsharded already)
* _prefetch_unshard layers.5
* _prefetch_unshard layers.5.moe.experts
* recompute_fn pre_forward unshard layers.6.moe.experts
* ~~recompute_fn post_forward reshard layers.6.moe.experts~~ <----- this PR makes it a no-op
* _pre_backward unshard layers.6.moe.expert (no-op, unsharded already)
* post_backward reshard layers.6.moe.expert
* post_backward reshard layers.6

before fix: `NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2`
```
[rank0]:[titan] 2025-09-30 11:53:24,574 - root - INFO - step:  1  loss: 12.0180  grad_norm:  1.6948  memory: 45.77GiB(48.18%)  tps: 849  tflops: 8.98  mfu: 0.91%
[rank0]:[titan] 2025-09-30 11:53:24,574 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-30 11:53:57,768 - root - INFO - [GC] Performing periodical GC collection 0.07 seconds
[rank0]:[titan] 2025-09-30 11:53:58,515 - root - INFO - step: 50  loss:  6.9358  grad_norm:  1.0528  memory: 59.80GiB(62.95%)  tps: 11,827  tflops: 125.10  mfu: 12.65%
```

after fix: `NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2`
```
[rank0]:[titan] 2025-09-30 12:08:39,404 - root - INFO - step:  1  loss: 12.0143  grad_norm:  1.7030  memory: 38.55GiB(40.58%)  tps: 988  tflops: 10.45  mfu: 1.06%
[rank0]:[titan] 2025-09-30 12:08:39,404 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-30 12:09:10,482 - root - INFO - [GC] Performing periodical GC collection 0.06 seconds
[rank0]:[titan] 2025-09-30 12:09:11,168 - root - INFO - step: 50  loss:  6.9356  grad_norm:  0.9911  memory: 52.81GiB(55.59%)  tps: 12,637  tflops: 133.68  mfu: 13.52%
```


Pull Request resolved: https://github.com/pytorch/pytorch/pull/164009
Approved by: https://github.com/soulitzer
2025-09-30 22:02:24 +00:00
adc11a7634 [export] avoid checks during tracing of export verification (#164219)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164219
Approved by: https://github.com/Lucaskabela
2025-09-30 21:46:59 +00:00
99e28ffab3 [FSDP][Replicate] tests replicate core functionality with mixed precision (#162855)
**Summary:** Ensures that replicate functionality works the same as fully shard's when mixed precision is used

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k TestReplicateMixedPrecisionTraining

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162855
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851, #162853
2025-09-30 21:45:58 +00:00
01dd2c2b42 [FSDP][Replicate] tests replicate is composable with tp (#162853)
**Summary:** Proof that new replicate API is composable with TP

**Test Case**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_replicate_tp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162853
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851
2025-09-30 21:29:54 +00:00
d3bdf8c32e [FSDP][Replicate] tests replicate with custom forward method (#162851)
**Summary:** Tests that replicate works when users use custom forward methods

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_register_fsdp_forward_method

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162851
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839
2025-09-30 21:15:34 +00:00
1ce9563ff6 [FSDP][Replicate] tests replicate gradient accumulation and 1f1b microbatching (#162839)
**Summary:** In order to ensure that replicate acts as intended (a specialized version of hsdp) we need to make sure that it can pass the same tests that fully_shard can for training. The first test verifies Replicate works with gradient accumulation properly. The second verifies that replicate works correctly with a One-Forward-One-Backward (1F1B) pipeline parallelism schedule

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_gradient_accumulation
2. pytest test/distributed/_composable/test_replicate_training.py -k test_1f1b_microbatching

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162839
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836
2025-09-30 21:00:16 +00:00
9e631392dc Missing lambda in torch._check (#164225)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164225
Approved by: https://github.com/Skylion007
2025-09-30 20:32:38 +00:00
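For context, `torch._check` takes its failure message as a callable so the string is only built when the check fails; a minimal sketch (not the call site fixed by the PR):

```python
import torch

def check_non_empty(x: torch.Tensor) -> None:
    # The lambda defers construction of the message until the check actually fails.
    torch._check(x.numel() > 0, lambda: f"expected a non-empty tensor, got shape {tuple(x.shape)}")

check_non_empty(torch.ones(2))
```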
1cce6efdb8 Fix silent incorrectness for bmm/baddmm out_dtype overload (#164095)
Add input checks like meta functions for standard ops in `ATen/native/LinearAlgebra.cpp` for the `out_dtype` variants. Fixes silent incorrectness in https://github.com/pytorch/pytorch/issues/163816

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164095
Approved by: https://github.com/ngimel
2025-09-30 20:13:13 +00:00
5a93f00c79 [CI] Delete binary smoke workflows (#164260)
Those were very useful in the past, because:
- CI builder jobs did not generate wheels, but rather ran `python setup.py develop` and shared docker layers; this is no longer the case, as all CI jobs now produce wheels
- CD jobs were targeting pre-CXX11 ABI, but this is no longer the case after manylinux2_28 migration

Existing, but acceptable gaps:
 - Windows libtorch debug builds might sometimes fail, but IMO it's OK not to be able to produce those for a few days, as the number of libtorch users is somewhat small
 - All CD jobs are based on AlmaLinux, while CI are based on Ubuntu, but this could be adjusted if needed, besides AlmaLinux-9 and Ubuntu-22.04 are pretty close in terms of glibc and gcc versions
 - CD jobs build for all GPU architectures, while CI only for the one being tested, but there are now periodic H100 and B200 jobs, and not a lot of development happens for Voltas or Pascals

Besides, there are better tools to alert about nightly failures

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164260
Approved by: https://github.com/seemethere, https://github.com/atalman
2025-09-30 20:00:07 +00:00
e30f01b5b5 [1/N] Simplify "in" operation for containers of a single item (#164224)
These issues are detected by ruff [FURB171](https://docs.astral.sh/ruff/rules/single-item-membership-test/#single-item-membership-test-furb171).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164224
Approved by: https://github.com/rec, https://github.com/Skylion007
2025-09-30 19:59:43 +00:00
ffc645c870 half support for fused_moving_avg_obs_fake_quant() op (#164175)
Follow up to https://github.com/pytorch/pytorch/pull/162620.  Add half support, as well.  This fixes some failures in inductor benchmarks such as from this log https://github.com/pytorch/pytorch/actions/runs/18051942373/job/51376749459.

`NotImplementedError: "aminmax_kernel" not implemented for 'Half'`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164175
Approved by: https://github.com/malfet, https://github.com/jerryzh168
2025-09-30 19:35:17 +00:00
60f0a356fd Update persons of interest for XLA. The previous one is out of date. (#158652)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158652
Approved by: https://github.com/JackCaoG, https://github.com/albanD
2025-09-30 19:21:18 +00:00
d2c5f231f6 Fix the shape check inside gnll loss (#147522)
Fixes #147521
This modification allow user to put any size of var in GaussianNLLLoss if the var is broadcastable (to input/target's size)

Therefore, the demo code in #147521 will result in expected behaviour and correct output.

This allows all input sizes that match:
`input.size = (..., n, ...), var.size = (..., 1, ...)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147522
Approved by: https://github.com/mikaylagawarecki
2025-09-30 18:40:15 +00:00
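An illustrative example of the broadcastable-var case described above (shapes made up):

```python
import torch
from torch import nn

loss_fn = nn.GaussianNLLLoss()
input = torch.randn(8, 5, requires_grad=True)  # (..., n, ...)
target = torch.randn(8, 5)
var = torch.rand(8, 1) + 0.1                   # (..., 1, ...): broadcast over the last dim
loss = loss_fn(input, target, var)
loss.backward()
```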
cc5d74c366 Revert "[BE] Remove HermeticPyObjectTLS and Simplify PythonOpRegistrationTrampoline (#163464)"
This reverts commit 94195a37ae4eae9c486a81b0f67725c8970f74d6.

Reverted https://github.com/pytorch/pytorch/pull/163464 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/163464#issuecomment-3353307034))
2025-09-30 18:20:20 +00:00
a707042353 fix: inductor non_blocking test - warmup events to make test pass whether it is the first run or not (#164188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164188
Approved by: https://github.com/williamwen42
2025-09-30 18:20:17 +00:00
d615f6b935 [inductor] use hint_override in kernel benchmark args (#164207)
Summary: forward fix T239259207

Test Plan: test_multi_kernel

Differential Revision: D83539263

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164207
Approved by: https://github.com/bobrenjc93, https://github.com/mlazos
2025-09-30 18:09:29 +00:00
719b64ee8b Fix TMA transpose logic to handle 1D shapes + string differences (#163966)
Fixes #163702.

This fixes 2 issues:
1. The value may inconsistently be a shape or string. This normalizes to handle both of these.
2. 1D shapes should not transpose data. This fixes the order of operations to prevent this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163966
Approved by: https://github.com/eellison
2025-09-30 17:51:37 +00:00
1cf1b9138d [inductor][templates] Template hooks should be finalised inside a kernel context (#164229)
The prologue buffer added in https://github.com/pytorch/pytorch/pull/160480 is added to template code in the DEF_KERNEL [hook](29221b9828/torch/_inductor/select_algorithm.py (L742)). The lines in this buffer may be of type `DeferredLine`, and so require the correct kernel context to determine whether lines should be added or removed.

Test plan:

Tested with a custom template using tensor descriptors for prologue fused inputs, whose tensor descriptors need to be hoisted to the top of the kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164229
Approved by: https://github.com/njriasan
2025-09-30 17:50:59 +00:00
5ed4672477 [dynamo, 3.14] fix _detect_and_normalize_assert_statement for 3.14 (#164005)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164005
Approved by: https://github.com/anijain2305
ghstack dependencies: #161838, #161555, #161839, #163009, #163109, #163110, #163191, #163292, #163796, #163818, #163919, #163920, #164004
2025-09-30 17:43:03 +00:00
2600f8b3d1 [dynamo, 3.14] fix tracing typing.Union (#164004)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164004
Approved by: https://github.com/anijain2305, https://github.com/mlazos
ghstack dependencies: #161838, #161555, #161839, #163009, #163109, #163110, #163191, #163292, #163796, #163818, #163919, #163920
2025-09-30 17:43:03 +00:00
9ce31e4278 [3.14] make unbacked_sym[int/float]_counter integers (#163920)
3.14 removed copy/deepcopy/pickle support for `itertools` iterators: https://docs.python.org/3.14/whatsnew/3.14.html#itertools

Change unbacked_sym[int/float]_counter from `itertools.count` to regular integers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163920
Approved by: https://github.com/ezyang
ghstack dependencies: #161838, #161555, #161839, #163009, #163109, #163110, #163191, #163292, #163796, #163818, #163919
2025-09-30 17:42:55 +00:00
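A minimal sketch of the replacement pattern (class and attribute names are illustrative, not the actual ShapeEnv code):

```python
import copy

class CounterAsInt:
    """Plain-int counter: deep-copyable and picklable on all Python versions,
    unlike itertools.count objects on 3.14."""

    def __init__(self) -> None:
        self.unbacked_symint_counter = 0

    def next_id(self) -> int:
        value = self.unbacked_symint_counter
        self.unbacked_symint_counter += 1
        return value

env = CounterAsInt()
assert env.next_id() == 0 and env.next_id() == 1
assert copy.deepcopy(env).unbacked_symint_counter == 2
```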
0657de9c61 [dynamo, 3.14] support LOAD_COMMON_CONSTANT (#163919)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163919
Approved by: https://github.com/anijain2305, https://github.com/mlazos
ghstack dependencies: #161838, #161555, #161839, #163009, #163109, #163110, #163191, #163292, #163796, #163818
2025-09-30 17:42:47 +00:00
4ead8ebf70 [dynamo, 3.14] fix BUILD_TUPLE with 0 args (#163818)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163818
Approved by: https://github.com/anijain2305
ghstack dependencies: #161838, #161555, #161839, #163009, #163109, #163110, #163191, #163292, #163796
2025-09-30 17:42:40 +00:00
d4b785a6a7 [dynamo, 3.14] fix stack ref copy error (#163796)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163796
Approved by: https://github.com/anijain2305
ghstack dependencies: #161838, #161555, #161839, #163009, #163109, #163110, #163191, #163292
2025-09-30 17:42:33 +00:00
9278b18ec0 [dynamo, 3.14] fix WITH_EXCEPT_START (#163292)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163292
Approved by: https://github.com/anijain2305
ghstack dependencies: #161838, #161555, #161839, #163009, #163109, #163110, #163191
2025-09-30 17:42:26 +00:00
008b0a9425 [dynamo, 3.14] fix inactive ctx handling in resume functions (#163191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163191
Approved by: https://github.com/anijain2305
ghstack dependencies: #161838, #161555, #161839, #163009, #163109, #163110
2025-09-30 17:42:19 +00:00
44677ad917 [dynamo, 3.14] support LOAD_CONST on slice, codegen LOAD_CONST slice instead of BINARY/STORE_SLICE (#163110)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163110
Approved by: https://github.com/anijain2305
ghstack dependencies: #161838, #161555, #161839, #163009, #163109
2025-09-30 17:42:11 +00:00
1c9987fdf4 [dynamo, 3.14] fix context managers (#163109)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163109
Approved by: https://github.com/anijain2305, https://github.com/mlazos
ghstack dependencies: #161838, #161555, #161839, #163009
2025-09-30 17:42:03 +00:00
7cbc011700 [dynamo, 3.14] support some bytecodes, fix CALL_FUNCTION_EX (#163009)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163009
Approved by: https://github.com/anijain2305
ghstack dependencies: #161838, #161555, #161839
2025-09-30 17:41:56 +00:00
09c774145e [dynamo, 3.14] Python dynamo changes to get basic programs working (#161839)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161839
Approved by: https://github.com/Lucaskabela, https://github.com/anijain2305
ghstack dependencies: #161838, #161555
2025-09-30 17:41:49 +00:00
763ab2a6ed [dynamo, 3.14] compile actual code in C dynamo (#161555)
No 3.14 CI tests enabled yet, but this was enough to get Dynamo compiling locally and Python Dynamo is at least being called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161555
Approved by: https://github.com/anijain2305
ghstack dependencies: #161838
2025-09-30 17:41:42 +00:00
4b8fe795f8 [dynamo] format cpython_defs.c (#161838)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161838
Approved by: https://github.com/Skylion007, https://github.com/anijain2305
2025-09-30 17:41:35 +00:00
84e1cd7392 [inductor] fx comm overlap: align runtime estimations across dist ranks (#164226)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164226
Approved by: https://github.com/eellison
2025-09-30 17:29:18 +00:00
937869657e Exporting aten.sdpa with cuda under fake mode on a cuda-less machine (#164162)
Summary:
As titled.

SDPA selects a backend based on a hardware check, which fails when exporting with CUDA under fake mode on a CUDA-less machine.

We guard `at::cuda::is_available()` check before `at::cuda::getCurrentDeviceProperties()` and give warnings.

Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r nn_functional_scaled_dot_product_attention

Differential Revision: D83496154

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164162
Approved by: https://github.com/SherlockNoMad
2025-09-30 17:21:31 +00:00
7d7ae4d7b2 [submodule] upgrade cutlass version to 4.2.1 and completely resolved python/cutlass name collision (#164156)
Differential Revision: D83489362

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164156
Approved by: https://github.com/Skylion007, https://github.com/mlazos
2025-09-30 17:04:57 +00:00
906fe7b120 [ROCm][CI] no longer build almalinux image for ROCm 6.3 (#164201)
Missed during ROCm 7 upgrades.  We only build N and N-1.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164201
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-09-30 16:59:31 +00:00
7edd18f0fd [Inductor-FX] Generalize FloorDiv conversion to handle more complex launch grids. Remove python_slow grid mode. (#163828)
# Problem
Inductor's FX backend receives sympy expressions for Triton launch grids, and passes these to a tracer to generate equivalent FX IR. However, the tracer does not support all possible sympy expressions. In particular, it can't handle ops like `floor` and `Pow` which would be found in an expression like `floor(x / y)`. Instead, it expects `FloorDiv(x, y)`, which has the advantage that all intermediate values are integers, unlike `x / y`.

Inductor's Python backend uses a trick where `ceil(x / y)` is computed in Python as `-(x // -y)`, which is faster when evaluating Python launch grids at runtime. However, this trick generates more complex sympy expressions, so the FX backend introduced a `"python_slow"` mode using a more familiar form of ceil division. However, this mode is slower to evaluate, which increased production CPU usage. (Internal reviewers see T237853632.)

# Solution
To get the best of both worlds, this PR removes `"python_slow"` mode, and generalizes the `replace_floor_div` function  to handle the more complex expressions resulting from the `"python"` grid mode. The new algorithm is conceptually similar to the existing one, except instead of analyzing only the first argument to a `sympy.Mul` op, it checks all factors, so it can handle expressions containing both `Rational` and `Pow` ops, among other cases. It also uses `Mul.make_args` to handle the case when the argument to `floor` is not a `Mul`. Finally, it uses `expr.is_positive` to check the sign of symbolic exponents.

This new algorithm is guaranteed to convert all `floor` ops to an equivalent expression using `FloorDiv`. (To see this, consider that `floor(x) == FloorDiv(x, 1)`.) Note it may not remove all `Pow` ops, with a counterexample being `floor(x / (2 + z ** y))`, but it covers everything we've seen in practice for symbolic launch grids. In particular, it covers the typical case where `Pow` is a factor of the argument to `floor`, and the exponent is `-1`. Is this situation, we move the `Pow` to the denominator of `FloorDiv` and the exponent becomes `1`, eliminating the `Pow` op.

# Test plan
This PR adds an end-to-end test for static padding with dynamic outer dimensions, which creates a difficult sympy expression that the existing algorithm would not be able to handle.

This PR also adds some unit tests for the `replace_floor_div` function. It can be difficult to construct end-to-end tests that expose all the trickiest expressions, as those tests have to pass through a number of other systems handling dynamic shapes. Therefore, it's easier to expose the edge cases with these new unit tests. The tests check that we can replace all `floor` ops in the input expression with `FloorDiv`, then they expand `FloorDiv` back to `floor` and check equality with the original expression.

Note this PR also requires some MTIA changes to pass internal tests. Those will be stacked onto the imported diff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163828
Approved by: https://github.com/nandesuka, https://github.com/angelayi, https://github.com/jansel
2025-09-30 16:47:49 +00:00
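For context, a small illustration of the runtime ceil-division trick that the "python" grid mode relies on (integer-only, so it maps onto FloorDiv-style expressions); this only demonstrates the identity, not the sympy rewriting done by `replace_floor_div`:

```python
import math

def ceil_div(x: int, y: int) -> int:
    # ceil(x / y) computed with floor division only: -(x // -y)
    return -(x // -y)

for x in range(1, 50):
    for y in range(1, 17):
        assert ceil_div(x, y) == math.ceil(x / y)
```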
3564cd294c Fix TestExportOpInfo (#164184)
Fixes https://github.com/pytorch/pytorch/issues/163699

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164184
Approved by: https://github.com/yiming0416, https://github.com/tugsbayasgalan
2025-09-30 16:12:39 +00:00
1412a4a42f [precompile] Add option to disable guard check on aot-compiled function. (#163432)
Summary:
Under some circumstances it seems reasonable to return a callable directly, without a guard check, when the user calls aot_compile on a function with a single compilation result.

When having multiple entries (aot_compile_module), we should start enabling guard checks to differentiate the compiled functions.

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163432
Approved by: https://github.com/dolpm, https://github.com/mlazos
2025-09-30 16:10:15 +00:00
96330f490d [testing] Add upload for test status during test stat uploads (#164189)
Add test status (flaky, success, skipped, failure) upload for easier comparison between test status on two commits

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164189
Approved by: https://github.com/huydhn, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-09-30 15:53:53 +00:00
66abba8f49 [CUDA][Expandable Segments] Follow-up cleanups for even more expandable segments tests (#163297)
Gets the original setting even earlier in case of crashes, and fixes a previous get call where set should have been used

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163297
Approved by: https://github.com/Skylion007
2025-09-30 15:39:14 +00:00
e88cca0691 Update Sphinx theme (#164147)
Fix links in the top nav bar: 71e55749be

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164147
Approved by: https://github.com/albanD
2025-09-30 15:35:58 +00:00
5c020beba4 Update LPPool docs to clarify ceil_mode padding semantics when ceil_mode=True (#163186)
# Summary

- Add a note to each `nn.LPPool*d` docstring explaining how `ceil_mode=True` interacts with right padding.
- Mirror the same clarification in the `torch.nn.functional.lp_pool*` docstrings so the rendered functional docs stay in sync.

# Motivation

The current PyTorch spec for **LPPool** does not fully match runtime behavior, which has led to downstream confusion in other specs (e.g., ONNX) and runtimes (e.g., [onnxruntime issue #25848](https://github.com/microsoft/onnxruntime/issues/25848)). A corresponding clarification was also made in the ONNX spec: [onnx/onnx#5741](https://github.com/onnx/onnx/pull/5741).

PyTorch’s **LPPool** implementation calls into **AvgPool**, which enforces the rule that windows starting entirely in the right padded region are ignored when `ceil_mode=True`. As a result, **LPPool** inherits the same behavior.

This is an edge case where the output size formula shown in the LPPool docs/spec is not sufficient on its own. Without the added caveat, the documentation is technically incorrect. This PR brings the LPPool docs in line with actual behavior.

Note that this is a trivial fix to the spec as all major implementers of the spec adhere to this caveat.

For comparison, both **MaxPool** and **AvgPool** already include this clarification in their spec. Their docstrings explicitly state:

> *When `ceil_mode=True`, sliding windows are allowed to go off-bounds if they start within the left padding or the input. Sliding windows that would start in the right padded region are ignored.*

Adding the same note to LPPool ensures consistency across all pooling operators.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163186
Approved by: https://github.com/mikaylagawarecki
2025-09-30 15:22:46 +00:00
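An illustrative example of the behavior being documented, shown with AvgPool1d since LPPool calls into AvgPool (sizes chosen to hit the edge case):

```python
import torch
from torch import nn

x = torch.randn(1, 1, 5)
pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=1, ceil_mode=True)
# The naive ceil formula gives ceil((5 + 2*1 - 2) / 2) + 1 = 4, but the window
# that would start entirely inside the right padding is ignored, so the output
# length is 3. LPPool inherits the same behavior.
print(pool(x).shape)  # torch.Size([1, 1, 3])
```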
edd9e07aff [BE] Remove not existing mnist mirror (#164238)
Looks like the original source is empty now:
http://yann.lecun.com/exdb/mnist/

The PyTorch-hosted mirror exists, hence leaving it as the only option.
https://ossci-datasets.s3.amazonaws.com/mnist/

Fixes these errors in pytorch/ci:
```
C:\actions-runner\_work\pytorch\pytorch>python tools\download_mnist.py --quiet -d C:\actions-runner\_work\pytorch\pytorch\test\cpp\api\mnist
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz ...
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz ...
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz ...
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz ...
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz ...
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz ...
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz ...
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz ...
```

Link to workflow with example:
https://github.com/pytorch/pytorch/actions/runs/18109150240/job/51542177282#step:15:2335
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164238
Approved by: https://github.com/jeanschmidt
2025-09-30 15:15:13 +00:00
0fb89b84b9 Revert "Consistently use c10_ovrsource in arvr mode everywhere (#164128)"
This reverts commit efd7fd5ed5ac7ec03201a546a09fb19ec59de431.

Reverted https://github.com/pytorch/pytorch/pull/164128 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164128#issuecomment-3352544006))
2025-09-30 14:43:52 +00:00
79fcfd49d6 Revert "[CI] Push viable/strict/${time} tags (#164183)"
This reverts commit 9f27b0c24515d9cf319d9a728d5009bf9ed035cf.

Reverted https://github.com/pytorch/pytorch/pull/164183 on behalf of https://github.com/malfet due to Hmm, didn't work that way ([comment](https://github.com/pytorch/pytorch/pull/164183#issuecomment-3352494098))
2025-09-30 14:32:46 +00:00
71b4fada57 Revert "Add less warps config to inner reductions (#162447)"
This reverts commit 84d673ef577d42d6ec20c6c9f09863583c3111f5.

Reverted https://github.com/pytorch/pytorch/pull/162447 on behalf of https://github.com/PaulZhang12 due to internal failure ([comment](https://github.com/pytorch/pytorch/pull/162447#issuecomment-3352474768))
2025-09-30 14:28:19 +00:00
46ec0664e3 Remove unused PyIntXXX, THPUtils_newReal_BOOL, THPQXXX macros (#164056)
The removed macros are not used in other places of the `pytorch` GitHub org.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164056
Approved by: https://github.com/albanD
2025-09-30 13:48:25 +00:00
410ed3006b Revert "Add functions to setup PrivateUse1 as a python backend device. (#157859)"
This reverts commit 1310d6a1f9194ddcf6753f7e12fb78f278451f8a.

Reverted https://github.com/pytorch/pytorch/pull/157859 on behalf of https://github.com/jeanschmidt due to introduce linting errors ([comment](https://github.com/pytorch/pytorch/pull/157859#issuecomment-3352140098))
2025-09-30 13:24:37 +00:00
77354e22e1 [OpenReg] Add AMP Integration guide for accelerators (#162050)
Fix part of #158917

Add AMP integration document and OpenReg code as example to explain steps of integration.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162050
Approved by: https://github.com/albanD

Co-authored-by: FFFrog <ljw1101.vip@gmail.com>
2025-09-30 12:27:11 +00:00
7f29c47a4f Fix cdist export compute mode validation (#161724)
Fixes #161089. Added '0' as the acceptable value for compute mode in _meta_registrations.py. Also, added a test case in test_export.py file.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161724
Approved by: https://github.com/albanD, https://github.com/angelayi
2025-09-30 12:23:20 +00:00
ace6c76103 [inductor] Small refactor of CachingAutotuner (#162406)
This is a simple refactor that just moves some logic in `_precompile_config` to two new functions for separation of concerns. This will allow subclasses e.g. out of tree to configure options and metadata for triton.compile.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162406
Approved by: https://github.com/exclamaforte
2025-09-30 11:29:15 +00:00
1310d6a1f9 Add functions to setup PrivateUse1 as a python backend device. (#157859)
Fixes #156052 and #156444.

This PR setup the privateuseone key in Python to be used as a python backend for pytorch.
Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass with that device to hold arbitrary Python data as "device data", and use `torch.library` to register ops that take that Tensor.

Changes done in this PR:

1. Register a vanilla Device Guard: I extended NoOpDeviceGuard to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without those changes, calling backward would raise errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. Tensor subclass allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment of the check suggests it was to avoid a segfault when calling into ops that expect a storage. Here we have a different device, so we will not call into those ops.
3. A Python function that invokes the other incantations to set up the PrivateUse1 backend.

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
2025-09-30 08:39:36 +00:00
7f4c3e7d2f distributed/serialization: support zero sized tensors (#164198)
Fixes
```
[4] ValueError: both buffer length (0) and count (-1) must not be 0
```

Test plan:

```
pytest test/distributed/test_serialization.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164198
Approved by: https://github.com/amirafzali
2025-09-30 08:11:29 +00:00
6e5b4249a5 [DTensor][Export] Supporting exporting a model with DTensor params/inputs (#163609)
I experimented with 3 paths to get joint graph for DTensorized module and input

1. strict_export + aot_export_joint_with_descriptors
2. graph_capture + aot_export_joint_with_descriptors
3. aot_export_joint_with_descriptors alone

Added test to guard them.

1 doesn't work, as the bw graph region is missing from the joint graph.
I am leaning towards making 2 the recommended path.
If 2 doesn't work going forward, we can fall back to 3.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163609
Approved by: https://github.com/tugsbayasgalan

Co-authored-by: suo <suo@fb.com>
2025-09-30 07:54:13 +00:00
5274753873 [dynamo][device_mesh] Support mesh_dim_names (#164200)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164200
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel
2025-09-30 07:16:28 +00:00
7afcb030d8 Back out "Revert D81959389" (#163905)
Summary:
Original commit changeset: 06888d7ebff0

Original Phabricator Diff: D82932788

Restricted the test to SM90 for scaled_grouped_mm

Test Plan: TBD (will share the linux CI results)

Differential Revision: D83283991

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163905
Approved by: https://github.com/angelayi
2025-09-30 07:05:13 +00:00
bbf6816f35 [dynamo] Special path for cloning of torch dispatch tensors (#164081)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164081
Approved by: https://github.com/tugsbayasgalan, https://github.com/mlazos
2025-09-30 05:15:56 +00:00
ace89350fc better error handling for rrelu when lower or upper range is infinite (#160965)
… - issue#153281

Fixes #153281
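
A hedged sketch of the kind of up-front validation the fix implies (not the exact check added in the PR):

```
import math

def check_rrelu_range(lower: float, upper: float) -> None:
    if not (math.isfinite(lower) and math.isfinite(upper)):
        raise ValueError(
            f"rrelu expects finite lower/upper bounds, got lower={lower}, upper={upper}"
        )
    if lower > upper:
        raise ValueError(f"rrelu expects lower <= upper, got {lower} > {upper}")

check_rrelu_range(0.125, 0.333)          # ok
# check_rrelu_range(float("inf"), 0.5)   # would raise a clear ValueError
```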

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160965
Approved by: https://github.com/janeyx99
2025-09-30 05:01:32 +00:00
7d59e37434 Add Comm-Compute Preserving Bucketer (#163960)
tl;dr performs bucketing while preserving comm-compute overlap.

In comm-compute overlap we will have a graph with:

```
def foo(...):
     ag = all_gather(...)
     hiding_compute = mm(...)
     wait(ag)
```

There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.

Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.

We perform bucketing while augmenting the graph with these relationships. This can be done separately from comm-compute overlap, so long as the hiding-compute relationships are passed in.
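
A purely illustrative sketch (stand-in functions only, not the actual Inductor passes) of the invariant the bucketer preserves: after merging the two starts and the two waits, the hiding compute still sits between the merged start and the merged wait:

```
import torch

def all_gather_start(t):       # stand-in for an async collective start
    return t.clone()

def all_gather_wait(handle):   # stand-in for the matching wait
    return handle

def unbucketed(x, w1, w2):
    h1 = all_gather_start(x)
    y1 = x @ w1                # hides h1
    g1 = all_gather_wait(h1)
    h2 = all_gather_start(w1)
    y2 = y1 @ w2               # hides h2
    g2 = all_gather_wait(h2)
    return y2.sum() + g1.sum() + g2.sum()

def bucketed(x, w1, w2):
    h = all_gather_start(torch.cat([x.flatten(), w1.flatten()]))  # merged start
    y1 = x @ w1
    y2 = y1 @ w2               # both hiding computes stay before the merged wait
    g = all_gather_wait(h)
    g1, g2 = g.split([x.numel(), w1.numel()])
    return y2.sum() + g1.sum() + g2.sum()

x, w1, w2 = (torch.randn(4, 4) for _ in range(3))
assert torch.allclose(unbucketed(x, w1, w2), bucketed(x, w1, w2))
```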

TODO:
- need to instrument fx graph so inductor respects these relationships.
- the compile time of the bucketing search can be sped up significantly by limiting what portion of the graph we traverse through
- more memory aware handling

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
2025-09-30 04:53:58 +00:00
92108f4abd Helper to augment graph with additional deps (#163959)
In comm-compute overlap we will have a graph with:

```
def foo(...):
     ag = all_gather(...)
     hiding_compute = mm(...)
     wait(ag)
```

There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.

Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.

This PR adds `AugmentedGraphHelper`, which provides those APIs and allows querying for dependencies on the augmented graph.
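
A minimal sketch of the underlying idea (generic names, not the helper's real API): keep extra edges and merged node-sets alongside the real graph, and answer dependency queries over the union:

```
from collections import defaultdict

class AugmentedDeps:
    def __init__(self):
        self.extra_deps = defaultdict(set)  # node -> nodes it must come after
        self.merge_parent = {}              # node -> parent in its merged set

    def _rep(self, n):
        # Representative of the merged set containing n.
        while n in self.merge_parent:
            n = self.merge_parent[n]
        return n

    def add_extra_dep(self, node, dep):
        self.extra_deps[self._rep(node)].add(self._rep(dep))

    def merge(self, a, b):
        # Treat a and b as one subgraph: union their deps under one representative.
        ra, rb = self._rep(a), self._rep(b)
        if ra != rb:
            self.merge_parent[rb] = ra
            self.extra_deps[ra] |= self.extra_deps.pop(rb, set())

    def depends_on(self, node, target, real_preds=lambda n: ()):
        # DFS over the real graph edges plus the augmented edges.
        seen, stack = set(), [self._rep(node)]
        goal = self._rep(target)
        while stack:
            cur = stack.pop()
            if cur == goal:
                return True
            if cur in seen:
                continue
            seen.add(cur)
            stack.extend(self._rep(p) for p in real_preds(cur))
            stack.extend(self.extra_deps[cur])
        return False

deps = AugmentedDeps()
deps.add_extra_dep("wait_ag", "mm")     # wait must stay after the hiding mm
deps.add_extra_dep("mm", "all_gather")  # mm must stay after the collective start
print(deps.depends_on("wait_ag", "all_gather"))  # True, via the augmented edges
```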

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163959
Approved by: https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754
2025-09-30 04:53:58 +00:00
0b2fdc30a2 refactor bucketing (#163754)
Preparatory refactor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163754
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #163215
2025-09-30 04:53:58 +00:00
0d7994ca97 [inductor] do comm compute overlap at aten fx level (#163215)
This is the first part of the stack that does comm/compute reordering and then uses the exposure analysis to do bucketing.

Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results

Other misc:
- Validate accuracy of NCCL estimations (use ruisi's profiling instead?)

For a llama 2D parallelism test, on forward we overlap all but 2 of the potentially hidden collectives. For backward, we overlap 217/269 of the potentially hidden collectives. If you increase `compute_overlap_multipler` (a fudge factor for inaccurate comms estimation), that goes down to all but 16 of the potentially hidden collectives.
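
As a back-of-the-envelope sketch of what "potentially hidden" and the fudge-factor multiplier mean here (illustrative names only, not the pass's real code):

```
def exposed_us(comm_est_us: float, hiding_compute_us: float,
               overlap_multiplier: float = 1.0) -> float:
    # A collective is fully hidden only if the independent compute between its
    # start and wait covers its (fudged) estimated runtime.
    required = comm_est_us * overlap_multiplier
    return max(0.0, required - hiding_compute_us)

# 120us of estimated comms hidden by a 100us matmul is 20us exposed at 1.0x,
# and 50us exposed with a 1.25x safety multiplier.
print(exposed_us(120, 100, 1.0))   # 20.0
print(exposed_us(120, 100, 1.25))  # 50.0
```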

fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3

bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
2025-09-30 04:53:58 +00:00
c39357bab6 [torchfuzz] Make scalar and tensor distribution configurable (#164034)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164034
Approved by: https://github.com/pianpwk
2025-09-30 04:50:54 +00:00
236 changed files with 9884 additions and 4401 deletions

View File

@ -13,49 +13,6 @@ def list_dir(path: str) -> list[str]:
return check_output(["ls", "-1", path]).decode().split("\n")
def build_ArmComputeLibrary() -> None:
"""
Using ArmComputeLibrary for aarch64 PyTorch
"""
print("Building Arm Compute Library")
acl_build_flags = [
"debug=0",
"neon=1",
"opencl=0",
"os=linux",
"openmp=1",
"cppthreads=0",
"arch=armv8a",
"multi_isa=1",
"fixed_format_kernels=1",
"build=native",
]
acl_install_dir = "/acl"
acl_checkout_dir = os.getenv("ACL_SOURCE_DIR", "ComputeLibrary")
if os.path.isdir(acl_install_dir):
shutil.rmtree(acl_install_dir)
if not os.path.isdir(acl_checkout_dir) or not len(os.listdir(acl_checkout_dir)):
check_call(
[
"git",
"clone",
"https://github.com/ARM-software/ComputeLibrary.git",
"-b",
"v25.02",
"--depth",
"1",
"--shallow-submodules",
]
)
check_call(
["scons", "Werror=1", f"-j{os.cpu_count()}"] + acl_build_flags,
cwd=acl_checkout_dir,
)
for d in ["arm_compute", "include", "utils", "support", "src", "build"]:
shutil.copytree(f"{acl_checkout_dir}/{d}", f"{acl_install_dir}/{d}")
def replace_tag(filename) -> None:
with open(filename) as f:
lines = f.readlines()
@ -356,19 +313,13 @@ if __name__ == "__main__":
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
if enable_mkldnn:
build_ArmComputeLibrary()
print("build pytorch with mkldnn+acl backend")
build_vars += (
"USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
"ACL_ROOT_DIR=/acl "
"LD_LIBRARY_PATH=/pytorch/build/lib:/acl/build:$LD_LIBRARY_PATH "
"ACL_INCLUDE_DIR=/acl/build "
"ACL_LIBRARY=/acl/build "
)
build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
build_vars += "ACL_ROOT_DIR=/acl "
if enable_cuda:
build_vars += "BLAS=NVPL "
else:
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/OpenBLAS "
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
else:
print("build pytorch without mkldnn backend")

View File

@ -299,40 +299,6 @@ def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
)
def build_OpenBLAS(host: RemoteHost, git_clone_flags: str = "") -> None:
print("Building OpenBLAS")
host.run_cmd(
f"git clone https://github.com/xianyi/OpenBLAS -b v0.3.28 {git_clone_flags}"
)
make_flags = "NUM_THREADS=64 USE_OPENMP=1 NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=ARMV8"
host.run_cmd(
f"pushd OpenBLAS && make {make_flags} -j8 && sudo make {make_flags} install && popd && rm -rf OpenBLAS"
)
def build_ArmComputeLibrary(host: RemoteHost, git_clone_flags: str = "") -> None:
print("Building Arm Compute Library")
acl_build_flags = " ".join(
[
"debug=0",
"neon=1",
"opencl=0",
"os=linux",
"openmp=1",
"cppthreads=0",
"arch=armv8a",
"multi_isa=1",
"fixed_format_kernels=1",
"build=native",
]
)
host.run_cmd(
f"git clone https://github.com/ARM-software/ComputeLibrary.git -b v25.02 {git_clone_flags}"
)
host.run_cmd(f"cd ComputeLibrary && scons Werror=1 -j8 {acl_build_flags}")
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
host.run_cmd("pip3 install auditwheel")
host.run_cmd(
@ -700,7 +666,6 @@ def start_build(
configure_system(
host, compiler=compiler, use_conda=use_conda, python_version=python_version
)
build_OpenBLAS(host, git_clone_flags)
if host.using_docker():
print("Move libgfortant.a into a standard location")
@ -723,6 +688,8 @@ def start_build(
f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
)
host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")
print("Building PyTorch wheel")
build_opts = ""
if pytorch_build_number is not None:
@ -743,16 +710,18 @@ def start_build(
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
if enable_mkldnn:
build_ArmComputeLibrary(host, git_clone_flags)
host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
build_vars += " BLAS=OpenBLAS"
build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
build_vars += " ACL_ROOT_DIR=/acl"
host.run_cmd(
f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && "
f"{build_vars} python3 -m build --wheel --no-isolation{build_opts}"
f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Repair the wheel")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
ld_library_path = "$HOME/acl/build:$HOME/pytorch/build/lib"
ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
host.run_cmd(
f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
)
@ -908,7 +877,7 @@ def terminate_instances(instance_type: str) -> None:
def parse_arguments():
from argparse import ArgumentParser
parser = ArgumentParser("Builid and test AARCH64 wheels using EC2")
parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
parser.add_argument("--key-name", type=str)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")

View File

@ -1 +1 @@
bbb06c0334a6772b92d24bde54956e675c8c6604
27664085f804afc83df26f740bb46c365854f2c4

27
.ci/docker/common/install_acl.sh Normal file → Executable file
View File

@ -1,16 +1,27 @@
set -euo pipefail
#!/bin/bash
# Script used only in CD pipeline
readonly version=v25.02
readonly src_host=https://github.com/ARM-software
readonly src_repo=ComputeLibrary
set -eux
ACL_VERSION=${ACL_VERSION:-"v25.02"}
ACL_INSTALL_DIR="/acl"
# Clone ACL
[[ ! -d ${src_repo} ]] && git clone ${src_host}/${src_repo}.git
cd ${src_repo}
git checkout $version
git clone https://github.com/ARM-software/ComputeLibrary.git -b "${ACL_VERSION}" --depth 1 --shallow-submodules
ACL_CHECKOUT_DIR="ComputeLibrary"
# Build with scons
pushd $ACL_CHECKOUT_DIR
scons -j8 Werror=0 debug=0 neon=1 opencl=0 embed_kernels=0 \
os=linux arch=armv8a build=native multi_isa=1 \
fixed_format_kernels=1 openmp=1 cppthreads=0
popd
# Install ACL
sudo mkdir -p ${ACL_INSTALL_DIR}
for d in arm_compute include utils support src build
do
sudo cp -r ${ACL_CHECKOUT_DIR}/${d} ${ACL_INSTALL_DIR}/${d}
done
rm -rf $ACL_CHECKOUT_DIR

12
.ci/docker/common/install_openblas.sh Normal file → Executable file
View File

@ -3,8 +3,10 @@
set -ex
cd /
git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION:-v0.3.30}" --depth 1 --shallow-submodules
OPENBLAS_VERSION=${OPENBLAS_VERSION:-"v0.3.30"}
# Clone OpenBLAS
git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" --depth 1 --shallow-submodules
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
@ -17,5 +19,7 @@ CFLAGS=-O3
BUILD_BFLOAT16=1
"
make -j8 ${OPENBLAS_BUILD_FLAGS} -C ${OPENBLAS_CHECKOUT_DIR}
make -j8 ${OPENBLAS_BUILD_FLAGS} install -C ${OPENBLAS_CHECKOUT_DIR}
make -j8 ${OPENBLAS_BUILD_FLAGS} -C $OPENBLAS_CHECKOUT_DIR
sudo make install -C $OPENBLAS_CHECKOUT_DIR
rm -rf $OPENBLAS_CHECKOUT_DIR

View File

@ -62,6 +62,13 @@ ARG OPENBLAS_VERSION
ADD ./common/install_openblas.sh install_openblas.sh
RUN bash ./install_openblas.sh && rm install_openblas.sh
# Install Arm Compute Library
FROM base as arm_compute
# use python3.9 to install scons
RUN python3.9 -m pip install scons==4.7.0
RUN ln -sf /opt/python/cp39-cp39/bin/scons /usr/local/bin
COPY ./common/install_acl.sh install_acl.sh
RUN bash ./install_acl.sh && rm install_acl.sh
FROM base as final
# remove unnecessary python versions
@ -70,4 +77,5 @@ RUN rm -rf /opt/python/cp26-cp26mu /opt/_internal/cpython-2.6.9-ucs4
RUN rm -rf /opt/python/cp33-cp33m /opt/_internal/cpython-3.3.6
RUN rm -rf /opt/python/cp34-cp34m /opt/_internal/cpython-3.4.6
COPY --from=openblas /opt/OpenBLAS/ /opt/OpenBLAS/
ENV LD_LIBRARY_PATH=/opt/OpenBLAS/lib:$LD_LIBRARY_PATH
COPY --from=arm_compute /acl /acl
ENV LD_LIBRARY_PATH=/opt/OpenBLAS/lib:/acl/build/:$LD_LIBRARY_PATH

View File

@ -86,6 +86,15 @@ FROM base as nvpl
ADD ./common/install_nvpl.sh install_nvpl.sh
RUN bash ./install_nvpl.sh && rm install_nvpl.sh
# Install Arm Compute Library
FROM base as arm_compute
# use python3.9 to install scons
RUN python3.9 -m pip install scons==4.7.0
RUN ln -sf /opt/python/cp39-cp39/bin/scons /usr/local/bin
COPY ./common/install_acl.sh install_acl.sh
RUN bash ./install_acl.sh && rm install_acl.sh
FROM base as final
FROM final as cuda_final
ARG BASE_CUDA_VERSION
RUN rm -rf /usr/local/cuda-${BASE_CUDA_VERSION}
@ -93,5 +102,7 @@ COPY --from=cuda /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BAS
COPY --from=magma /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BASE_CUDA_VERSION}
COPY --from=nvpl /opt/nvpl/lib/ /usr/local/lib/
COPY --from=nvpl /opt/nvpl/include/ /usr/local/include/
COPY --from=arm_compute /acl /acl
RUN ln -sf /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/acl/build/:$LD_LIBRARY_PATH

View File

@ -28,6 +28,7 @@ fi
MANY_LINUX_VERSION=${MANY_LINUX_VERSION:-}
DOCKERFILE_SUFFIX=${DOCKERFILE_SUFFIX:-}
OPENBLAS_VERSION=${OPENBLAS_VERSION:-}
ACL_VERSION=${ACL_VERSION:-}
case ${image} in
manylinux2_28-builder:cpu)
@ -41,7 +42,6 @@ case ${image} in
GPU_IMAGE=arm64v8/almalinux:8
DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13 --build-arg NINJA_VERSION=1.12.1"
MANY_LINUX_VERSION="2_28_aarch64"
OPENBLAS_VERSION="v0.3.30"
;;
manylinuxs390x-builder:cpu-s390x)
TARGET=final
@ -119,7 +119,8 @@ tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]')
DOCKER_BUILDKIT=1 docker build \
${DOCKER_GPU_BUILD_ARG} \
--build-arg "GPU_IMAGE=${GPU_IMAGE}" \
--build-arg "OPENBLAS_VERSION=${OPENBLAS_VERSION}" \
--build-arg "OPENBLAS_VERSION=${OPENBLAS_VERSION:-}" \
--build-arg "ACL_VERSION=${ACL_VERSION:-}" \
--target "${TARGET}" \
-t "${tmp_tag}" \
$@ \

View File

@ -52,10 +52,10 @@ flatbuffers==24.12.23
#Pinned versions: 24.12.23
#test that import:
hypothesis==5.35.1
hypothesis==6.56.4
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
#Description: advanced library for generating parametrized tests
#Pinned versions: 5.35.1
#Pinned versions: 6.56.4
#test that import: test_xnnpack_integration.py, test_pruning_op.py, test_nn.py
junitparser==2.1.1
@ -98,7 +98,7 @@ librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x"
#Pinned versions:
#test that import:
mypy==1.16.0 ; platform_system != "Windows"
mypy==1.16.0 ; platform_system == "Linux"
# Pin MyPy version because new errors are likely to appear with each release
# Skip on Windows as lots of type annotations are POSIX specific
#Description: linter
@ -169,7 +169,7 @@ optree==0.13.0
pillow==11.0.0
#Description: Python Imaging Library fork
#Pinned versions: 10.3.0
#Pinned versions: 11.0.0
#test that import:
protobuf==5.29.5
@ -217,7 +217,7 @@ pytest-subtests==0.13.1
#Pinned versions:
#test that import:
xdoctest==1.1.0
xdoctest==1.3.0
#Description: runs doctests in pytest
#Pinned versions: 1.1.0
#test that import:
@ -268,7 +268,7 @@ scipy==1.14.1 ; python_version >= "3.12"
#test that import:
# needed by torchgen utils
typing-extensions>=4.10.0
typing-extensions==4.12.2
#Description: type hints for python
#Pinned versions:
#test that import:
@ -361,9 +361,10 @@ pwlf==2.2.1
#test that import: test_sac_estimator.py
# To build PyTorch itself
pyyaml
pyyaml==6.0.2
pyzstd
setuptools>=70.1.0
setuptools==78.1.1
packaging==23.1
six
scons==4.5.2 ; platform_machine == "aarch64"
@ -384,7 +385,10 @@ cmake==3.31.6
tlparse==0.4.0
#Description: required for log parsing
cuda-bindings>=12.0,<13.0 ; platform_machine != "s390x"
filelock==3.18.0
#Description: required for inductor testing
cuda-bindings>=12.0,<13.0 ; platform_machine != "s390x" and platform_system != "Darwin"
#Description: required for testing CUDAGraph::raw_cuda_graph(). See https://nvidia.github.io/cuda-python/cuda-bindings/latest/support.html for how this version was chosen. Note "Any fix in the latest bindings would be backported to the prior major version" means that only the newest version of cuda-bindings will get fixes. Depending on the latest version of 12.x is okay because all 12.y versions will be supported via "CUDA minor version compatibility". Pytorch builds against 13.z versions of cuda toolkit work with 12.x versions of cuda-bindings as well because newer drivers work with old toolkits.
#test that import: test_cuda.py

View File

@ -9,7 +9,7 @@ standard-imghdr==3.13.0; python_version >= "3.13"
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@d53b0ffb9b1cda68260693ea98f3483823c88d8e#egg=pytorch_sphinx_theme2
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.

View File

@ -107,6 +107,10 @@ if [[ $ROCM_INT -ge 60200 ]]; then
ROCM_SO_FILES+=("librocm-core.so")
fi
if [[ $ROCM_INT -ge 70000 ]]; then
ROCM_SO_FILES+=("librocroller.so")
fi
OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release`
if [[ "$OS_NAME" == *"CentOS Linux"* || "$OS_NAME" == *"AlmaLinux"* ]]; then
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"

View File

@ -89,7 +89,7 @@ fi
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
export USE_MKLDNN=1
export USE_MKLDNN_ACL=1
export ACL_ROOT_DIR=/ComputeLibrary
export ACL_ROOT_DIR=/acl
fi
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then

View File

@ -26,6 +26,7 @@ if [[ "${SHARD_NUMBER:-2}" == "2" ]]; then
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
time python test/run_test.py --verbose -i distributed/test_compute_comm_reordering
time python test/run_test.py --verbose -i distributed/test_aten_comm_compute_reordering
time python test/run_test.py --verbose -i distributed/test_store
time python test/run_test.py --verbose -i distributed/test_symmetric_memory
time python test/run_test.py --verbose -i distributed/test_pg_wrapper

View File

@ -435,7 +435,7 @@ test_inductor_distributed() {
# this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
# with if required # gpus aren't available
python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_compute_comm_reordering --verbose
python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_aten_comm_compute_reordering distributed/test_compute_comm_reordering --verbose
assert_git_not_dirty
}

View File

@ -0,0 +1,32 @@
#!/bin/bash
set -ex -o pipefail
# Suppress ANSI color escape sequences
export TERM=vt100
# shellcheck source=./common.sh
source "$(dirname "${BASH_SOURCE[0]}")/common.sh"
# shellcheck source=./common-build.sh
source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh"
echo "Environment variables"
env
echo "Testing FA3 stable wheel still works with currently built torch"
echo "Installing ABI Stable FA3 wheel"
# The wheel was built on https://github.com/Dao-AILab/flash-attention/commit/b3846b059bf6b143d1cd56879933be30a9f78c81
# on torch nightly torch==2.9.0.dev20250830+cu129
$MAYBE_SUDO pip -q install https://s3.amazonaws.com/ossci-linux/wheels/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
pushd flash-attention/hopper
export PYTHONPATH=$PWD
pytest -v -s \
"test_flash_attn.py::test_flash_attn_output[1-1-192-False-False-False-0.0-False-False-mha-dtype0]" \
"test_flash_attn.py::test_flash_attn_varlen_output[511-1-64-True-False-False-0.0-False-False-gqa-dtype2]" \
"test_flash_attn.py::test_flash_attn_kvcache[1-128-128-False-False-True-None-0.0-False-False-True-False-True-False-gqa-dtype0]" \
"test_flash_attn.py::test_flash_attn_race_condition[97-97-192-True-dtype0]" \
"test_flash_attn.py::test_flash_attn_combine[2-3-64-dtype1]" \
"test_flash_attn.py::test_flash3_bw_compatibility"
popd

View File

@ -38,10 +38,12 @@ if errorlevel 1 goto fail
if not errorlevel 0 goto fail
:: Update CMake
:: TODO: Investigate why this helps MKL detection, even when CMake from choco is not used
call choco upgrade -y cmake --no-progress --installargs 'ADD_CMAKE_TO_PATH=System' --apply-install-arguments-to-dependencies --version=3.27.9
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
:: TODO: Move to .ci/docker/requirements-ci.txt
call pip install mkl==2024.2.0 mkl-static==2024.2.0 mkl-include==2024.2.0
if errorlevel 1 goto fail
if not errorlevel 0 goto fail

View File

@ -37,27 +37,8 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
export PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda"
fi
# TODO: Move both of them to Windows AMI
python -m pip install tensorboard==2.13.0 protobuf==5.29.4 pytest-subtests==0.13.1
# Copied from https://github.com/pytorch/test-infra/blob/be01a40157c36cd5a48391fdf44a7bc3ebd4c7e3/aws/ami/windows/scripts/Installers/Install-Pip-Dependencies.ps1#L16 with some adjustments
# pytest-rerunfailures==10.3 as 10.2 fails with INTERNALERROR> pluggy._manager.PluginValidationError: unknown hook 'pytest_configure_node'
# scipy from 1.6.3 to 1.10
# expecttest from 0.1.3 to 0.3.0
# xdoctest from 1.0.2 to 1.3.0
python -m pip install "future==0.18.2" "hypothesis==5.35.1" "expecttest==0.3.0" "librosa>=0.6.2" "scipy==1.10.1" "psutil==5.9.1" "pynvml==11.4.1" "pillow==9.2.0" "unittest-xml-reporting<=3.2.0,>=2.0.0" "pytest==7.1.3" "pytest-xdist==2.5.0" "pytest-flakefinder==1.1.0" "pytest-rerunfailures==10.3" "pytest-shard==0.1.2" "sympy==1.11.1" "xdoctest==1.3.0" "pygments==2.12.0" "opt-einsum>=3.3" "networkx==2.8.8" "mpmath==1.2.1" "pytest-cpp==2.3.0" "boto3==1.35.42"
# Install Z3 optional dependency for Windows builds.
python -m pip install z3-solver==4.15.1.0
# Install tlparse for test\dynamo\test_structured_trace.py UTs.
python -m pip install tlparse==0.4.0
# Install parameterized
python -m pip install parameterized==0.8.1
# Install pulp for testing ilps under torch\distributed\_tools
python -m pip install pulp==2.9.0
# TODO: Move this to .ci/docker/requirements-ci.txt
python -m pip install "psutil==5.9.1" "pynvml==11.4.1" "pytest-shard==0.1.2"
run_tests() {
# Run nvidia-smi if available

View File

@ -23,9 +23,6 @@ runs:
run: |
.github\scripts\kill_active_ssh_sessions.ps1
- name: Clean up leftover processes on non-ephemeral Windows runner
uses: pytorch/test-infra/.github/actions/cleanup-runner@main
# Cleaning up Windows workspace sometimes fails flakily with device or resource busy
# error, meaning one or more processes haven't stopped completely yet. So trying to
# retry this step several time similar to how checkout-pytorch GHA does

View File

@ -1 +1 @@
0307428d65acf5cf1a73a70a7722e076bbb83f22
78a47f87ce259a48f0391fa9ae15add05ea7432b

View File

@ -202,7 +202,7 @@ ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG nvcc_threads=4
ENV NVCC_THREADS=$nvcc_threads
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
ARG torch_cuda_arch_list='8.0 8.6 8.9 9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
ARG USE_SCCACHE
@ -297,16 +297,28 @@ RUN echo "[INFO] Listing current directory before torch install step:" && \
echo "[INFO] Showing torch_build_versions.txt content:" && \
cat torch_build_versions.txt
# Install build and runtime dependencies, this is needed for flashinfer install
COPY requirements/build.txt requirements/build.txt
COPY use_existing_torch.py use_existing_torch.py
RUN python3 use_existing_torch.py
RUN cat requirements/build.txt
# Install uv for faster pip installs if not existed
RUN --mount=type=cache,target=/root/.cache/uv \
if ! python3 -m uv --version > /dev/null 2>&1; then \
python3 -m pip install uv==0.8.4; \
fi
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/build.txt
# Default mount file as placeholder, this just avoid the mount error
ARG TORCH_WHEELS_PATH="./requirements"
# Install torch, torchaudio and torchvision
@ -332,13 +344,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Install xformers wheel from previous stage
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system /wheels/xformers/*.whl --verbose
# Build flashinfer from source.
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
# install package for build flashinfer
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738
RUN pip install build==1.3.0
RUN pip freeze | grep -E 'setuptools|packaging|build'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

View File

@ -1,9 +1,14 @@
import glob
import os
requires_files = glob.glob("requirements/*.txt")
requires_files += ["pyproject.toml"]
for file in requires_files:
if not os.path.exists(file):
print(f"!!! skipping missing {file}")
continue
print(f">>> cleaning {file}")
with open(file) as f:
lines = f.readlines()

View File

@ -1,37 +0,0 @@
boto3==1.35.42
build==1.2.2.post1
cmake==3.27.*
expecttest==0.3.0
fbscribelogger==0.1.7
filelock==3.18.0
hypothesis==6.56.4
librosa>=0.6.2
mpmath==1.3.0
networkx==2.8.7
ninja==1.10.2.4
numba==0.59.0
numpy==1.26.4
opt-einsum>=3.3
optree==0.13.0
packaging==23.1
parameterized==0.8.1
pillow==10.3.0
protobuf==5.29.5
psutil==5.9.8
pygments==2.15.0
pytest-cpp==2.3.0
pytest-flakefinder==1.1.0
pytest-rerunfailures==10.3
pytest-subtests==0.13.1
pytest-xdist==3.3.1
pytest==7.3.2
pyyaml==6.0.2
scipy==1.12.0
setuptools==78.1.1
sympy==1.13.3
tlparse==0.4.0
tensorboard==2.13.0
typing-extensions==4.12.2
unittest-xml-reporting<=3.2.0,>=2.0.0
xdoctest==1.1.0
z3-solver==4.15.1.0

View File

@ -127,53 +127,6 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
),
]
ROCM_SMOKE_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="manywheel",
build_variant="rocm",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
arches=["6.4"],
python_versions=["3.10"],
),
ciflow_config=CIFlowConfig(
labels={
LABEL_CIFLOW_BINARIES,
LABEL_CIFLOW_BINARIES_WHEEL,
LABEL_CIFLOW_ROCM,
},
isolated_workflow=True,
),
branches="main",
),
]
LINUX_BINARY_SMOKE_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="manywheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
arches=["13.0"],
python_versions=["3.12"],
),
branches="main",
),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="libtorch",
build_variant=generate_binary_build_matrix.RELEASE,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.LINUX,
generate_binary_build_matrix.RELEASE,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
branches="main",
),
]
WINDOWS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.WINDOWS,
@ -259,39 +212,6 @@ WINDOWS_BINARY_BUILD_WORKFLOWS = [
),
]
WINDOWS_BINARY_SMOKE_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.WINDOWS,
package_type="libtorch",
build_variant=generate_binary_build_matrix.RELEASE,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.WINDOWS,
generate_binary_build_matrix.RELEASE,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
branches="main",
ciflow_config=CIFlowConfig(
isolated_workflow=True,
),
),
BinaryBuildWorkflow(
os=OperatingSystem.WINDOWS,
package_type="libtorch",
build_variant=generate_binary_build_matrix.DEBUG,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.WINDOWS,
generate_binary_build_matrix.DEBUG,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
branches="main",
ciflow_config=CIFlowConfig(
isolated_workflow=True,
),
),
]
MACOS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.MACOS_ARM64,
@ -372,23 +292,10 @@ def main() -> None:
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
S390X_BINARY_BUILD_WORKFLOWS,
),
(
# Give rocm it's own workflow file
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
ROCM_SMOKE_WORKFLOWS,
),
(
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
LINUX_BINARY_SMOKE_WORKFLOWS,
),
(
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
WINDOWS_BINARY_BUILD_WORKFLOWS,
),
(
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
WINDOWS_BINARY_SMOKE_WORKFLOWS,
),
(
jinja_env.get_template("macos_binary_build_workflow.yml.j2"),
MACOS_BINARY_BUILD_WORKFLOWS,

View File

@ -0,0 +1,255 @@
# The point of this workflow is to test that a FA3 wheel that was built based off the
# stable ABI as of torch nightly 20250830 can still run on the newer torch.
#
# This workflow is very similar to the _linux-test.yml workflow, with the following
# differences:
# 1. It is simpler (there is no test matrix)
# 2. It pulls flash-attention as a secondary repository in order to access the tests.
# Note that it does not BUILD anything from flash-attention, as we have a prebuilt
# wheel. We pull flash-attention only to run a few tests.
# 3. It runs only FA3 tests. No PyTorch tests are run.
name: linux-test-stable-fa3
on:
workflow_call:
inputs:
build-environment:
required: true
type: string
description: Top-level label for what's being built/tested.
docker-image:
required: true
type: string
description: Docker image to run in.
timeout-minutes:
required: false
type: number
default: 30
description: |
Set the maximum (in minutes) how long the workflow should take to finish
s3-bucket:
description: S3 bucket to download artifact
required: false
type: string
default: "gha-artifacts"
secrets:
HUGGING_FACE_HUB_TOKEN:
required: false
description: |
HF Auth token to avoid rate limits when downloading models or datasets from hub
VLLM_TEST_HUGGING_FACE_TOKEN:
required: false
description: |
HF Auth token to test vllm
SCRIBE_GRAPHQL_ACCESS_TOKEN:
required: false
description: |
FB app token to write to scribe endpoint
env:
GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
jobs:
test:
# Don't run on forked repos
if: github.repository_owner == 'pytorch'
runs-on: linux.aws.h100
timeout-minutes: ${{ inputs.timeout-minutes || 30 }}
permissions:
id-token: write
contents: read
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
with:
no-sudo: true
- name: Checkout flash-attention as a secondary repository
uses: actions/checkout@v4
with:
repository: Dao-AILab/flash-attention
path: flash-attention
- name: Setup Linux
uses: ./.github/actions/setup-linux
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-image-name: ${{ inputs.docker-image }}
- name: Use following to pull public copy of the image
id: print-ghcr-mirror
env:
ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
shell: bash
run: |
tag=${ECR_DOCKER_IMAGE##*:}
echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}"
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Check if in a container runner
shell: bash
id: check_container_runner
run: echo "IN_CONTAINER_RUNNER=$(if [ -f /.inarc ] || [ -f /.incontainer ]; then echo true ; else echo false; fi)" >> "$GITHUB_OUTPUT"
- name: Setup GPU_FLAG for docker run
id: setup-gpu-flag
run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}"
- name: Setup SCCACHE_SERVER_PORT environment for docker run when on container
id: setup-sscache-port-flag
run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}"
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' }}
- name: Get workflow job id
id: get-job-id
uses: ./.github/actions/get-workflow-job-id
if: always()
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Download build artifacts
uses: ./.github/actions/download-build-artifacts
with:
name: ${{ inputs.build-environment }}
s3-bucket: ${{ inputs.s3-bucket }}
- name: Parse ref
id: parse-ref
run: .github/scripts/parse_ref.py
- name: Set Test step time
id: test-timeout
shell: bash
env:
JOB_TIMEOUT: ${{ inputs.timeout-minutes }}
run: |
echo "timeout=$((JOB_TIMEOUT-30))" >> "${GITHUB_OUTPUT}"
- name: Preserve github env variables for use in docker
shell: bash
run: |
env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}"
env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}"
- name: Test
id: test
timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
env:
BUILD_ENVIRONMENT: ${{ inputs.build-environment }}
PR_NUMBER: ${{ github.event.pull_request.number }}
GITHUB_REPOSITORY: ${{ github.repository }}
GITHUB_WORKFLOW: ${{ github.workflow }}
GITHUB_JOB: ${{ github.job }}
GITHUB_RUN_ID: ${{ github.run_id }}
GITHUB_RUN_NUMBER: ${{ github.run_number }}
GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }}
JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
BRANCH: ${{ steps.parse-ref.outputs.branch }}
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }}
SHM_SIZE: '2g'
DOCKER_IMAGE: ${{ inputs.docker-image }}
VLLM_TEST_HUGGING_FACE_TOKEN: ${{ secrets.VLLM_TEST_HUGGING_FACE_TOKEN }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }}
ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ steps.get-job-id.outputs.job-id }}
run: |
set -x
TEST_COMMAND=.ci/pytorch/test_fa3_abi_stable.sh
# Leaving 1GB for the runner and other things
TOTAL_AVAILABLE_MEMORY_IN_GB=$(awk '/MemTotal/ { printf "%.3f \n", $2/1024/1024 - 1 }' /proc/meminfo)
# https://docs.docker.com/engine/containers/resource_constraints/#--memory-swap-details, the 3GB swap
# comes from https://github.com/pytorch/test-infra/pull/6058
TOTAL_MEMORY_WITH_SWAP=$(("${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}" + 3))
SHM_OPTS="--shm-size=${SHM_SIZE}"
JENKINS_USER="--user jenkins"
DOCKER_SHELL_CMD=
# detached container should get cleaned up by teardown_ec2_linux
# TODO: Stop building test binaries as part of the build phase
# Used for GPU_FLAG, SHM_OPTS, JENKINS_USER and DOCKER_SHELL_CMD since that doesn't play nice
# shellcheck disable=SC2086,SC2090
container_name=$(docker run \
${GPU_FLAG:-} \
${SCCACHE_SERVER_PORT_DOCKER_FLAG:-} \
-e BUILD_ENVIRONMENT \
-e PR_NUMBER \
-e GITHUB_ACTIONS \
-e GITHUB_REPOSITORY \
-e GITHUB_WORKFLOW \
-e GITHUB_JOB \
-e GITHUB_RUN_ID \
-e GITHUB_RUN_NUMBER \
-e GITHUB_RUN_ATTEMPT \
-e JOB_ID \
-e JOB_NAME \
-e BASE_SHA \
-e BRANCH \
-e SHA1 \
-e MAX_JOBS="$(nproc --ignore=2)" \
-e HUGGING_FACE_HUB_TOKEN \
-e VLLM_TEST_HUGGING_FACE_TOKEN \
-e SCRIBE_GRAPHQL_ACCESS_TOKEN \
-e ARTIFACTS_FILE_SUFFIX \
--memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \
--memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
--security-opt seccomp=unconfined \
--cap-add=SYS_PTRACE \
--ipc=host \
${SHM_OPTS} \
--tty \
--detach \
--name="${container_name}" \
${JENKINS_USER} \
-v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
-w /var/lib/jenkins/workspace \
"${DOCKER_IMAGE}" \
${DOCKER_SHELL_CMD}
)
echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"
docker exec -t "${container_name}" sh -c "python3 -m pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}"
- name: Collect backtraces from coredumps (if any)
if: always()
run: |
# shellcheck disable=SC2156
find . -iname "core.[1-9]*" -exec docker exec "${DOCKER_CONTAINER_ID}" sh -c "gdb python {} -ex 'bt' -ex 'q'" \;
- name: Store Core dumps on S3
uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0
if: failure()
with:
name: coredumps-fa3-stable-abi-smoke-tests
retention-days: 14
if-no-files-found: ignore
path: ./**/core.[1-9]*
- name: Upload utilization stats
if: ${{ always() && steps.test.conclusion && steps.test.conclusion != 'skipped' }}
continue-on-error: true
uses: ./.github/actions/upload-utilization-stats
with:
job_id: ${{ steps.get-job-id.outputs.job-id }}
job_name: ${{ steps.get-job-id.outputs.job-name }}
workflow_name: ${{ github.workflow }}
workflow_run_id: ${{github.run_id}}
workflow_attempt: ${{github.run_attempt}}
- name: Teardown Linux
uses: pytorch/test-infra/.github/actions/teardown-linux@main
if: always() && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false'

View File

@ -85,7 +85,7 @@ jobs:
uses: pytorch/test-infra/.github/actions/setup-python@main
with:
python-version: ${{ inputs.python-version }}
pip-requirements-file: .github/requirements/pip-requirements-macOS.txt
pip-requirements-file: .ci/docker/requirements-ci.txt
- name: Install sccache (only for non-forked PRs, and pushes to trunk)
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0

View File

@ -122,7 +122,7 @@ jobs:
uses: pytorch/test-infra/.github/actions/setup-python@main
with:
python-version: ${{ inputs.python-version }}
pip-requirements-file: .github/requirements/pip-requirements-macOS.txt
pip-requirements-file: .ci/docker/requirements-ci.txt
- name: Start monitoring script
id: monitor-script

View File

@ -84,9 +84,6 @@ jobs:
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
- name: Clean up leftover processes on non-ephemeral Windows runner
uses: pytorch/test-infra/.github/actions/cleanup-runner@main
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
with:

View File

@ -77,9 +77,6 @@ jobs:
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
- name: Clean up leftover processes on non-ephemeral Windows runner
uses: pytorch/test-infra/.github/actions/cleanup-runner@main
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
with:
@ -106,18 +103,6 @@ jobs:
with:
cuda-version: ${{ inputs.cuda-version }}
# TODO: Move to a requirements.txt file for windows
- name: Install pip dependencies
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
shell: bash
timeout_minutes: 5
max_attempts: 5
retry_wait_seconds: 30
command: |
set -eu
python3 -m pip install 'xdoctest>=1.1.0'
- name: Get workflow job id
id: get-job-id
uses: ./.github/actions/get-workflow-job-id
@ -272,15 +257,6 @@ jobs:
shell: bash
run: python3 .github/scripts/parse_ref.py
- name: Uninstall PyTorch
if: always()
continue-on-error: true
shell: bash
run: |
# This step removes PyTorch installed by the test to give a clean slate
# to the next job
python3 -mpip uninstall -y torch
- name: Teardown Windows
uses: ./.github/actions/teardown-win
if: always()

View File

@ -36,7 +36,7 @@ jobs:
runs-on: linux.9xlarge.ephemeral
strategy:
matrix:
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.3", "rocm6.4", "rocm7.0", "cpu"]
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"]
steps:
- name: Build docker image
uses: pytorch/pytorch/.github/actions/binary-docker-build@main

View File

@ -1,87 +0,0 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-libtorch-release
on:
push:
branches:
- main
tags:
- 'ciflow/trunk/*'
workflow_dispatch:
permissions:
id-token: write
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BINARY_ENV_FILE: /tmp/env
BUILD_ENVIRONMENT: linux-binary-libtorch-release
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
concurrency:
group: linux-binary-libtorch-release-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: cpu
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: libtorch-cpu-shared-with-deps-release
build_environment: linux-binary-libtorch-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-cpu-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cpu-shared-with-deps-release-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: cpu
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
build_name: libtorch-cpu-shared-with-deps-release
build_environment: linux-binary-libtorch-release
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,88 +0,0 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-manywheel
on:
push:
branches:
- main
tags:
- 'ciflow/trunk/*'
workflow_dispatch:
permissions:
id-token: write
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BINARY_ENV_FILE: /tmp/env
BUILD_ENVIRONMENT: linux-binary-manywheel
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
concurrency:
group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
manywheel-py3_12-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu130
GPU_ARCH_VERSION: "13.0"
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda13.0
DESIRED_PYTHON: "3.12"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_12-cuda13_0
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_12-cuda13_0-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_12-cuda13_0-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu130
GPU_ARCH_VERSION: "13.0"
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda13.0
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cuda13_0
build_environment: linux-binary-manywheel
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,136 +0,0 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-manywheel-rocm
on:
push:
branches:
- main
tags:
- 'ciflow/binaries/*'
- 'ciflow/binaries_wheel/*'
- 'ciflow/rocm/*'
workflow_dispatch:
permissions:
id-token: write
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BINARY_ENV_FILE: /tmp/env
BUILD_ENVIRONMENT: linux-binary-manywheel-rocm
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
concurrency:
group: linux-binary-manywheel-rocm-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
manywheel-py3_10-rocm6_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.10"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
timeout-minutes: 300
build_name: manywheel-py3_10-rocm6_4
build_environment: linux-binary-manywheel-rocm
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_10-rocm6_4-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_10-rocm6_4-build
- get-label-type
runs-on: linux.rocm.gpu.mi250
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.10"
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: manywheel-py3_10-rocm6_4
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
docker-image-name: manylinux2_28-builder
custom-tag-prefix: rocm6.4
docker-build-dir: .ci/docker
working-directory: pytorch
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
env:
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm

View File

@ -1,261 +0,0 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/windows_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: windows-binary-libtorch-debug
on:
push:
branches:
- main
workflow_dispatch:
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BUILD_ENVIRONMENT: windows-binary-libtorch-debug
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 1
OS: windows
concurrency:
group: windows-binary-libtorch-debug-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
libtorch-cpu-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v4.4.0
if: always()
with:
name: libtorch-cpu-shared-with-deps-debug
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cpu-shared-with-deps-debug-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cpu-shared-with-deps-debug-build
- get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-cpu-shared-with-deps-debug
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1

View File

@ -1,261 +0,0 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/windows_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: windows-binary-libtorch-release
on:
push:
branches:
- main
workflow_dispatch:
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BUILD_ENVIRONMENT: windows-binary-libtorch-release
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 1
OS: windows
concurrency:
group: windows-binary-libtorch-release-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v4.4.0
if: always()
with:
name: libtorch-cpu-shared-with-deps-release
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cpu-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cpu-shared-with-deps-release-build
- get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-cpu-shared-with-deps-release
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1

View File

@ -61,3 +61,15 @@ jobs:
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-sm90-FA3-ABI-stable-test:
name: linux-jammy-cuda12_8-py3_10-gcc11-sm90-FA3-ABI-stable-test
uses: ./.github/workflows/_linux-test-stable-fa3.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-sm90-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.docker-image }}
timeout-minutes: 30
s3-bucket: gha-artifacts
secrets: inherit

View File

@ -49,5 +49,6 @@ jobs:
pip install awscli==1.29.40
aws s3 cp "/tmp/${LATEST_SHA}.json" "s3://ossci-raw-job-status/stable_pushes/pytorch/pytorch/${LATEST_SHA}.json"
# Push new viable/strict tag
cd pytorch/pytorch
git push origin "${LATEST_SHA}:refs/tags/viable/strict/${TIME}"
fi

View File

@ -42,7 +42,7 @@ jobs:
build-external-packages: "vllm"
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm
cuda-arch-list: '8.0;8.9;9.0'
cuda-arch-list: '8.0 8.9 9.0'
runner: linux.24xlarge.memory
test-matrix: |
{ include: [

View File

@ -1260,6 +1260,7 @@ exclude_patterns = [
'test/test_masked.py',
'test/test_maskedtensor.py',
'test/test_matmul_cuda.py',
'test/test_scaled_matmul_cuda.py',
'test/test_meta.py',
'test/test_metal.py',
'test/test_mkl_verbose.py',

View File

@ -81,7 +81,7 @@ git remote add upstream git@github.com:pytorch/pytorch.git
make setup-env
# Or run `make setup-env-cuda` for pre-built CUDA binaries
# Or run `make setup-env-rocm` for pre-built ROCm binaries
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```
### Tips and Debugging
@ -182,28 +182,36 @@ You can use this script to check out a new nightly branch with the following:
```bash
./tools/nightly.py checkout -b my-nightly-branch
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```
To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`:
```bash
./tools/nightly.py checkout -b my-nightly-branch --cuda
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```
To install the nightly binaries built with ROCm, you can pass in the flag `--rocm`:
```bash
./tools/nightly.py checkout -b my-nightly-branch --rocm
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```
You can also use this tool to pull the nightly commits into the current branch:
```bash
./tools/nightly.py pull -p my-env
source my-env/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
./tools/nightly.py pull
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```
To create the virtual environment with a specific Python interpreter, you can
pass in the `--python` argument:
```bash
./tools/nightly.py --python /path/to/python3.12
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```
Pulling will recreate a fresh virtual environment and reinstall the development

View File

@ -6,6 +6,7 @@
#include <c10/core/thread_pool.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/llvmMathExtras.h>
#include <iostream>
#include <optional>
#include <deque>
@ -75,6 +76,9 @@ struct TORCH_API HostStats {
// COUNT: number of times cudaHostFree/cudaHostUnregister was called.
int64_t num_host_free = 0; // This is derived from segment or timing
// Count of cudaHostFree/cudaHostUnregister per bucket
std::vector<int64_t> bucket_allocation = std::vector<int64_t>(MAX_SIZE_INDEX);
};
// Struct containing memory allocator summary statistics for host, as they
@ -196,27 +200,7 @@ struct CachingHostAllocatorImpl {
// background.
if (!pinned_use_background_threads()) {
process_events();
}
// Round up the allocation to the nearest power of two to improve reuse.
// These power of two sizes are also used to index into the free list.
size_t roundSize = c10::llvm::PowerOf2Ceil(size);
// First, try to allocate from the free list
auto* block = get_free_block(roundSize);
if (block) {
return {block->ptr_, reinterpret_cast<void*>(block)};
}
// Check in the recently freed blocks with pending events to see if we
// can reuse them. Call get_free_block again after processing events
if (pinned_use_background_threads()) {
process_events_for_specific_size(roundSize);
block = get_free_block(roundSize);
if (block) {
return {block->ptr_, reinterpret_cast<void*>(block)};
}
} else {
// Launch the background thread and process events in a loop.
static bool background_thread_flag [[maybe_unused]] = [this] {
getBackgroundThreadPool()->run([&]() {
@ -229,6 +213,16 @@ struct CachingHostAllocatorImpl {
}();
}
// Round up the allocation to the nearest power of two to improve reuse.
// These power of two sizes are also used to index into the free list.
size_t roundSize = c10::llvm::PowerOf2Ceil(size);
// First, try to allocate from the free list
auto* block = get_free_block(roundSize);
if (block) {
return {block->ptr_, reinterpret_cast<void*>(block)};
}
// Slow path: if we can't allocate from the cached free list, we need
// to create a new block.
void* ptr = nullptr;
@ -278,8 +272,6 @@ struct CachingHostAllocatorImpl {
auto index = size_index(block->size_);
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
free_list_[index].list_.push_back(block);
stats_.allocation_bucket_stats[index].decrease(1);
stats_.allocated_bytes_bucket_stats[index].decrease(block->size_);
} else {
// restore these events that were recorded by used streams.
std::lock_guard<std::mutex> g(events_mutex_);
@ -339,9 +331,12 @@ struct CachingHostAllocatorImpl {
for (auto* block : blocks_to_remove) {
blocks_.erase(block);
ptr_to_block_.erase(block->ptr_);
auto index = size_index(block->size_);
free_block(block);
stats_.allocation.decrease(1);
stats_.allocated_bytes.decrease(block->size_);
free_block(block);
stats_.allocation_bucket_stats[index].decrease(1);
stats_.allocated_bytes_bucket_stats[index].decrease(block->size_);
delete block;
}
}
@ -398,6 +393,7 @@ struct CachingHostAllocatorImpl {
// a best effort manner, since we can't really replay the cached events per bucket.
add_bucket_stats(stats.allocation, stats_.allocation_bucket_stats[i]);
add_bucket_stats(stats.allocated_bytes, stats_.allocated_bytes_bucket_stats[i]);
stats.bucket_allocation[i] = stats_.allocation_bucket_stats[i].allocated;
}
// Get the timing stats
@ -488,8 +484,6 @@ struct CachingHostAllocatorImpl {
B* block = free_list_[index].list_.back();
free_list_[index].list_.pop_back();
block->allocated_ = true;
stats_.allocation_bucket_stats[index].increase(1);
stats_.allocated_bytes_bucket_stats[index].increase(size);
return block;
}
return nullptr;
@ -583,8 +577,6 @@ struct CachingHostAllocatorImpl {
auto index = size_index(block->size_);
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
free_list_[index].list_.push_back(block);
stats_.allocation_bucket_stats[index].decrease(1);
stats_.allocated_bytes_bucket_stats[index].decrease(size);
if (size != -1) {
return;
}
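
The allocator rounds each request up to the next power of two and uses that rounded size to pick a free-list bucket before falling back to a fresh allocation. A minimal standalone sketch of that bucketing scheme, with hypothetical helpers standing in for `c10::llvm::PowerOf2Ceil` and `size_index` (C++20):

```cpp
#include <bit>
#include <cstddef>
#include <cstdio>

// Hypothetical stand-ins for c10::llvm::PowerOf2Ceil and size_index():
// round a request up to the next power of two, then use log2 of that
// size as the free-list bucket index.
size_t round_size(size_t n) {
  return std::bit_ceil(n);  // 100 -> 128, 4096 -> 4096, 5000 -> 8192
}

size_t bucket_index(size_t rounded) {
  return static_cast<size_t>(std::countr_zero(rounded));  // log2 of a power of two
}

int main() {
  for (size_t request : {size_t{100}, size_t{4096}, size_t{5000}}) {
    size_t r = round_size(request);
    std::printf("request=%zu rounded=%zu bucket=%zu\n", request, r, bucket_index(r));
  }
  return 0;
}
```

Collapsing many request sizes into one bucket trades a little memory for a higher hit rate in `get_free_block`.
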

View File

@ -2,6 +2,7 @@
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <ATen/core/PythonFallbackKernel.h>
#include <c10/core/SafePyObject.h>
#include <ATen/record_function.h>
namespace {
@ -53,20 +54,24 @@ void pythonFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
// c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
// StashTLSOnEntryGuard stash_guard;
c10::impl::ExcludeDispatchKeyGuard guard(after_Python_keyset);
c10::impl::ExcludeDispatchKeyGuard exclude_guard(after_Python_keyset);
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
// If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
if (mode_stack_len > 0) {
RECORD_FUNCTION("PythonDispatchMode", torch::jit::last(*stack, num_arguments));
const auto& cur_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
cur_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
return;
}
RECORD_FUNCTION("PythonSubclass", torch::jit::last(*stack, num_arguments));
// Otherwise, find a PyInterpreter on a Tensor
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
// It is safe to dispatch on the very first Tensor with a pyobj_interpreter
// without checking the interpreters of any of the arguments, because when
// we actually run dispatch(), we will take out PyObjects in the context
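
The added `RECORD_FUNCTION` calls open a profiler scope named "PythonDispatchMode" or "PythonSubclass" with the operator's inputs attached, so Python-fallback dispatches become visible in profiler traces. A small illustrative use of the same macro outside the dispatcher (function and region name are made up for the sketch):

```cpp
#include <ATen/ATen.h>
#include <ATen/record_function.h>
#include <vector>

// Everything after the macro in this function body shows up as one named
// region (with its inputs) in profiler traces, which is what the dispatcher
// change above does for Python fallbacks.
at::Tensor annotated_sum(const at::Tensor& x) {
  RECORD_FUNCTION("annotated_sum", std::vector<c10::IValue>({x}));
  return x.sum();
}
```
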

View File

@ -1,22 +1,32 @@
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <c10/core/impl/PyInterpreterHooks.h>
// TODO: delete this
namespace at::impl {
c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::interpreter_ = nullptr;
// The strategy is that all python interpreters attempt to register themselves
// as the main interpreter, but only one wins. Only that interpreter is
// allowed to interact with the C++ dispatcher. Furthermore, when we execute
// logic on that interpreter, we do so hermetically, never setting pyobj field
// on Tensor.
std::atomic<c10::impl::PyInterpreter*>
PythonOpRegistrationTrampoline::interpreter_{nullptr};
c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::getInterpreter() {
return c10::impl::getGlobalPyInterpreter();
return PythonOpRegistrationTrampoline::interpreter_.load();
}
bool PythonOpRegistrationTrampoline::registerInterpreter(
c10::impl::PyInterpreter* interp) {
if (interpreter_ != nullptr) {
c10::impl::PyInterpreter* expected = nullptr;
interpreter_.compare_exchange_strong(expected, interp);
if (expected != nullptr) {
// This is the second (or later) Python interpreter, which means we need
// non-trivial hermetic PyObject TLS
c10::impl::HermeticPyObjectTLS::init_state();
return false;
} else {
return true;
}
interpreter_ = interp;
return true;
}
} // namespace at::impl
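
The trampoline now keeps the interpreter in a `std::atomic` and registers it with `compare_exchange_strong`, so exactly one interpreter wins the race; a losing caller finds the winner in `expected` and enables hermetic PyObject TLS. A minimal sketch of that single-winner pattern with plain stand-in types (not the PyTorch classes):

```cpp
#include <atomic>
#include <cstdio>

// Stand-in type for the sketch; the real code stores a c10::impl::PyInterpreter*.
struct Interpreter { const char* name; };

std::atomic<Interpreter*> g_interpreter{nullptr};

// Returns true only for the first registrant. On failure, compare_exchange_strong
// writes the current winner into `expected`, which is how a second interpreter
// can be detected.
bool register_interpreter(Interpreter* interp) {
  Interpreter* expected = nullptr;
  g_interpreter.compare_exchange_strong(expected, interp);
  return expected == nullptr;
}

int main() {
  Interpreter a{"first"}, b{"second"};
  std::printf("%d %d\n", register_interpreter(&a), register_interpreter(&b));  // prints: 1 0
  return 0;
}
```

Checking `expected` after the CAS, as the diff does, is equivalent to checking the CAS return value: on failure the current winner is written back into `expected`.
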

View File

@ -2,21 +2,19 @@
#include <ATen/core/dispatch/Dispatcher.h>
// TODO: We can get rid of this
// TODO: this can probably live in c10
namespace at::impl {
// Manages the single Python interpreter instance for PyTorch.
class TORCH_API PythonOpRegistrationTrampoline final {
static c10::impl::PyInterpreter* interpreter_;
static std::atomic<c10::impl::PyInterpreter*> interpreter_;
public:
// Register the Python interpreter. Returns true on first registration,
// false if an interpreter was already registered.
// Returns true if you successfully registered yourself (that means
// you are in the hot seat for doing the operator registrations!)
static bool registerInterpreter(c10::impl::PyInterpreter*);
// Returns the registered interpreter via the global PyInterpreter hooks.
// Returns nullptr if no interpreter has been registered yet.
static c10::impl::PyInterpreter* getInterpreter();
};

View File

@ -151,11 +151,6 @@ struct CUDACachingHostAllocatorImpl
}
bool query_event(EventPool::Event& event) override {
// Do not call cudaEventQuery if capturing is underway
if (at::cuda::currentStreamCaptureStatusMayInitCtx() !=
at::cuda::CaptureStatus::None) {
return false;
}
cudaError_t err = cudaEventQuery(*event);
if (err == cudaErrorNotReady) {
(void)cudaGetLastError(); // clear CUDA error

View File

@ -90,6 +90,10 @@ public:
allocator_->setMemoryFraction(fraction, device);
}
std::vector<HIPCachingAllocator::StreamSegmentSize> getExpandableSegmentSizes(c10::DeviceIndex device) override {
return allocator_->getExpandableSegmentSizes(device);
}
void enable(bool value) override {
allocator_->enable(value);
}

View File

@ -670,6 +670,8 @@ Tensor rrelu_with_noise_backward(
}
Tensor rrelu(const Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
TORCH_CHECK(std::isfinite(lower.to<double>()), "rrelu: lower bound must be finite, got ", lower.to<double>());
TORCH_CHECK(std::isfinite(upper.to<double>()), "rrelu: upper bound must be finite, got ", upper.to<double>());
TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
return at::rrelu_with_noise(self, noise, lower, upper, training, std::move(generator));
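
The new `TORCH_CHECK`s reject non-finite bounds up front with a descriptive error. A small sketch of the call the checks now guard against (assumes a standard libtorch build; the exact error message may differ):

```cpp
#include <ATen/ATen.h>
#include <iostream>
#include <limits>

int main() {
  at::Tensor x = at::randn({4});
  try {
    // Lower bound is +inf, so the added finiteness check fires before any work is done.
    at::rrelu(x, std::numeric_limits<double>::infinity(), 1.0, /*training=*/true);
  } catch (const std::exception& e) {
    std::cout << "rrelu rejected non-finite lower bound" << std::endl;
  }
  return 0;
}
```
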

View File

@ -2801,6 +2801,7 @@ Tensor matrix_exp(const Tensor& a) {
// TODO This should be deprecated in favor of linalg_matrix_exp_differential
// in FunctionsManual.cpp
Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
squareCheckInputs(self, "matrix_exp_backward");
NoTF32Guard disable_tf32;
return backward_analytic_function_of_a_matrix(
self, grad,

View File

@ -1880,43 +1880,34 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) {
Tensor xtensor = self.expand(padded_size);
Tensor urtensor;
if (self.is_quantized()) {
urtensor = at::empty_quantized(target_size, self);
} else {
urtensor = at::empty(target_size, self.options());
}
// return an empty tensor if one of the repeat dimensions is zero
if (zero_tensor) {
return urtensor;
return self.is_quantized() ? at::empty_quantized(target_size, self)
: at::empty(target_size, self.options());
}
// Create view of shape [r0, s0, r1, s1, ...]
// where ri is repeat[i], si is self.size(i).
Tensor view = xtensor;
auto expand_shape = std::vector<int64_t>();
expand_shape.reserve(xtensor.dim() * 2);
for (const auto i : c10::irange(xtensor.dim())) {
// can't unfold with step 0, so make sure step is at least 1
// (it doesn't matter what it is in that case, because the size is 0).
auto size_i = xtensor.sizes()[i];
urtensor = urtensor.unfold(i, size_i, std::max<int64_t>(size_i, 1));
view = view.unsqueeze(2 * i);
expand_shape.push_back(repeats[i]);
expand_shape.push_back(xtensor.size(i));
}
// expanded_view is non-contiguous because .expand set stride to 0.
auto expanded_view = view.expand(expand_shape);
urtensor.copy_(xtensor.expand_as(urtensor));
// copy to contiguous tensor.
auto contiguous_copy = at::empty(
expanded_view.sizes(),
expanded_view.options(),
at::MemoryFormat::Contiguous);
contiguous_copy.copy_(expanded_view);
// Combine the dimensions to produce the target_size.
// xtensor dims: [a0, ..., ad-1]
// urtensor dims: [a0, ..., ad-1, b0, ..., bd-1]
// b dims are produced by unfold.
// Transform urtensor to [a0 * b0, ..., ad-1 * bd-1]
const int64_t n_dims = xtensor.dim();
auto range_a = at::arange(xtensor.dim(), at::TensorOptions(at::kLong));
auto range_b = range_a + n_dims;
auto stacked = stack({std::move(range_a), std::move(range_b)}, 1).flatten();
auto permutation = IntArrayRef(stacked.data_ptr<int64_t>(), n_dims * 2);
// Permute from [a0, ..., ad-1, b0, ..., bd-1] to [a0, b0, ..., ad-1, bd-1]
urtensor = urtensor.permute(permutation);
// Reshape from [a0, b0, ..., ad-1, bd-1] to [a0 * b0, ..., ad-1 * bd-1]
urtensor = urtensor.reshape(target_size);
return urtensor;
// Reshape to [s0 * r0, s1 * r1, ...].
// No extra copy of data during reshape for a contiguous tensor.
return contiguous_copy.view(target_size);
}
Tensor tile_symint(const Tensor& self, SymIntArrayRef reps) {
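
The rewritten `repeat` builds a `[r0, s0, r1, s1, ...]` view with `unsqueeze` + `expand`, materializes it with one contiguous copy, and reshapes to `[r0*s0, r1*s1, ...]`, replacing the old unfold/permute path. A compact ATen sketch of the same sequence of ops (assumes `repeats.size() == self.dim()` and skips the quantized and zero-repeat branches):

```cpp
#include <ATen/ATen.h>
#include <vector>

// Tile `self` by `repeats` via unsqueeze/expand/copy/view.
at::Tensor repeat_via_expand(const at::Tensor& self, at::IntArrayRef repeats) {
  at::Tensor view = self;
  std::vector<int64_t> expand_shape, target_size;
  for (int64_t i = 0; i < self.dim(); ++i) {
    view = view.unsqueeze(2 * i);                   // insert a size-1 dim before each si
    expand_shape.push_back(repeats[i]);             // ri
    expand_shape.push_back(self.size(i));           // si
    target_size.push_back(repeats[i] * self.size(i));
  }
  at::Tensor expanded = view.expand(expand_shape);  // broadcasts ri copies with stride 0
  return expanded.contiguous().view(target_size);   // one data copy, then a free reshape
}

// e.g. repeat_via_expand(at::arange(6).view({2, 3}), {2, 3}) matches
//      at::arange(6).view({2, 3}).repeat({2, 3}).
```
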

View File

@ -1831,6 +1831,37 @@ std::optional<c10::ScalarType> out_dtype) {
return out;
}
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
const auto batch1_sizes = batch1.sizes();
const auto batch2_sizes = batch2.sizes();
int64_t bs = batch1_sizes[0];
int64_t contraction_size = batch1_sizes[2];
int64_t res_rows = batch1_sizes[1];
int64_t res_cols = batch2_sizes[2];
std::vector<int64_t> output_size {bs, res_rows, res_cols};
TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size,
"Expected size for first two dimensions of batch2 tensor to be: [",
bs, ", ", contraction_size, "] but got: [", batch2_sizes[0], ", ", batch2_sizes[1], "].");
TORCH_CHECK(batch1.scalar_type() == batch2.scalar_type(), "batch1 and batch2 must have the same dtype");
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
if (!is_bmm && self_baddbmm.has_value()) {
const auto& self = self_baddbmm.value();
TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output");
}
}
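
The factored-out check enforces the usual batched-matmul shape rule: `batch1` is `[b, n, m]`, `batch2` must be `[b, m, p]`, and the result is `[b, n, p]`. A tiny CPU illustration of that rule using plain `at::bmm` (the dtype-specific CUDA entry points above validate the same shapes):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor batch1 = at::randn({4, 2, 3});  // [b, n, m]
  at::Tensor batch2 = at::randn({4, 3, 5});  // [b, m, p]
  std::cout << at::bmm(batch1, batch2).sizes() << std::endl;  // [4, 2, 5]
  return 0;
}
```
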
Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
IntArrayRef batch1_sizes = batch1.sizes();
IntArrayRef batch2_sizes = batch2.sizes();
@ -1840,12 +1871,7 @@ Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::Sca
}
Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true);
Scalar beta(0.0);
Scalar alpha(1.0);
{
@ -1864,12 +1890,7 @@ Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tenso
}
Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self);
{
NoNamesGuard guard;
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
@ -1884,6 +1905,12 @@ Tensor _mm_dtype_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarTy
}
Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarType out_dtype, Tensor &out) {
TORCH_CHECK(self.dim() == 2, "self must be a matrix, got ", self.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
self.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
self.sizes()[0], "x", self.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "input dtypes must be the same");
TORCH_CHECK(out_dtype == self.scalar_type() ||
@ -1903,6 +1930,14 @@ Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& m
}
Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == self.scalar_type() ||
(out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),

View File

@ -10256,6 +10256,7 @@
structured: True
dispatch:
CPU, CUDA: all_all_out
MTIA: all_all_out_mtia
MPS: all_all_out_mps
- func: any(Tensor self) -> Tensor

View File

@ -101,6 +101,9 @@ __device__ inline bool isinf_device(float v) {
__device__ inline bool isinf_device(c10::BFloat16 v) {
return ::isinf(static_cast<float>(v));
}
__device__ inline bool isinf_device(at::Half v) {
return ::isinf(static_cast<float>(v));
}
// CUDA kernel to compute Moving Average Min/Max of the tensor.
// It uses the running_min and running_max along with averaging const, c.
@ -160,8 +163,8 @@ void _calculate_moving_average(
std::tie(x_min, x_max) = at::aminmax(x, 1);
int num_threads = std::min(size, (int64_t)512);
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
@ -181,8 +184,8 @@ void _calculate_moving_average(
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
std::tie(x_min, x_max) = at::aminmax(x);
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
@ -221,8 +224,8 @@ void _calc_moving_avg_qparams_helper(
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>();
if (per_row_fq) {
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
int num_threads = std::min(size, (int64_t)512);
@ -244,8 +247,8 @@ void _calc_moving_avg_qparams_helper(
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
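
Switching to `AT_DISPATCH_FLOATING_TYPES_AND2` adds `at::kHalf` to the dtypes the dispatch accepts alongside `at::kBFloat16`. A short host-side sketch of the dispatch pattern itself (the function is illustrative, not part of the kernel above):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <iostream>

// The macro instantiates the lambda for float/double plus the two extra
// dtypes (BFloat16 and, newly, Half), binding the concrete type to scalar_t.
void print_first(const at::Tensor& x) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kBFloat16, at::kHalf, x.scalar_type(), "print_first", [&] {
        const scalar_t* data = x.data_ptr<scalar_t>();
        std::cout << static_cast<double>(data[0]) << std::endl;
      });
}

int main() {
  print_first(at::ones({2}, at::kHalf));  // dispatches to the Half instantiation
  return 0;
}
```
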

View File

@ -316,6 +316,12 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
return false;
#endif
#else
if (!at::cuda::is_available()) {
if (debug) {
TORCH_WARN("flash attention requires a CUDA device, which is not available.");
}
return false;
}
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm121>(dprops)) {
if (debug) {
@ -367,6 +373,12 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
return false;
#endif
#else
if (!at::cuda::is_available()) {
if (debug) {
TORCH_WARN("Mem Efficient attention requires a CUDA device, which is not available.");
}
return false;
}
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm50, sm121>(dprops)) {
if (debug) {
@ -597,6 +609,12 @@ bool check_cudnn_layout(sdp_params const& params, bool debug) {
bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
using sm80 = SMVersion<8, 0>;
using sm121 = SMVersion<12, 1>;
if (!at::cuda::is_available()) {
if (debug) {
TORCH_WARN("cuDNN SDPA requires a CUDA device, which is not available.");
}
return false;
}
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm121>(dprops)) {
if (debug) {
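
Each hardware-support check now returns early when no CUDA device is available, so `getCurrentDeviceProperties` is never touched on CPU-only setups. A sketch of that guard pattern (illustrative helper; it assumes the usual ATen CUDA headers and omits the `check_sm_version` upper bound):

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <iostream>

// Guard device-property queries behind an availability check, mirroring the
// pattern added to the SDPA hardware-support checks above.
bool has_at_least_sm80() {
  if (!at::cuda::is_available()) {
    std::cerr << "no CUDA device available" << std::endl;
    return false;
  }
  const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  return prop->major >= 8;
}

int main() {
  std::cout << std::boolalpha << has_at_least_sm80() << std::endl;
  return 0;
}
```
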

View File

@ -5,6 +5,7 @@
#include <torch/csrc/profiler/orchestration/vulkan.h>
#endif // USE_KINETO
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>

View File

@ -1,10 +1,83 @@
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/test/allocator_clone_test.h>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
std::unordered_map<void*, size_t> allocation_sizes;
void* logging_malloc(size_t size, int device, cudaStream_t stream) {
void* ptr;
cudaMalloc(&ptr, size);
allocation_sizes[ptr] = size;
return ptr;
}
void logging_free(void* ptr, size_t size, int device, cudaStream_t stream) {
if (allocation_sizes.find(ptr) != allocation_sizes.end()) {
if (allocation_sizes[ptr] != size) {
throw std::runtime_error("free mismatch");
}
} else {
throw std::runtime_error("free of unknown ptr");
}
cudaFree(ptr);
allocation_sizes.erase(ptr);
}
TEST(TestTorchUnique, UniqueComparisonTest) {
if (!at::cuda::is_available()) return;
auto custom_allocator =
torch::cuda::CUDAPluggableAllocator::createCustomAllocator(logging_malloc, logging_free);
torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(custom_allocator);
// Run the command 3 times; the first 2 will pass and the third invocation will have
// different sizes in alloc and free if the test fails.
for (int i = 0; i < 3; ++i) {
// Initialize simple sorted tensor with repeats
at::Tensor sorted_tensor =
at::tensor({0, 0, 0, 1, 1, 2, 3, 3, 3, 3, 5},
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));
// This operation will call malloc/free with different sizes on the same pointer
auto unique_dim_result = at::unique_consecutive(sorted_tensor, false, true, 0);
// Everything below is only there to validate correct results
auto unique_dim_values = std::get<0>(unique_dim_result);
auto unique_dim_counts = std::get<2>(unique_dim_result);
// Check tensor sizes
EXPECT_EQ(unique_dim_values.size(0), 5);
EXPECT_EQ(unique_dim_counts.size(0), 5);
// Copy to CPU before accessing elements
at::Tensor cpu_values = unique_dim_values.cpu();
at::Tensor cpu_counts = unique_dim_counts.cpu();
// Use accessors on the CPU tensors
auto values_accessor = cpu_values.accessor<float, 1>();
auto counts_accessor = cpu_counts.accessor<int64_t, 1>();
// Check individual values using accessors
EXPECT_EQ(values_accessor[0], 0.0f);
EXPECT_EQ(values_accessor[1], 1.0f);
EXPECT_EQ(values_accessor[2], 2.0f);
EXPECT_EQ(values_accessor[3], 3.0f);
EXPECT_EQ(values_accessor[4], 5.0f);
// Check count values using accessors
EXPECT_EQ(counts_accessor[0], 3);
EXPECT_EQ(counts_accessor[1], 2);
EXPECT_EQ(counts_accessor[2], 1);
EXPECT_EQ(counts_accessor[3], 4);
EXPECT_EQ(counts_accessor[4], 1);
}
}
TEST(AllocatorTestCUDA, test_clone) {
if (!at::cuda::is_available()) return;
test_allocator_clone(c10::cuda::CUDACachingAllocator::get());
}

View File

@ -50,6 +50,7 @@ run_if_exists cuda_complex_test
run_if_exists cuda_complex_math_test
run_if_exists cuda_cub_test
run_if_exists cuda_atomic_ops_test
run_if_exists cuda_allocator_test
if [ "$VALGRIND" == "ON" ]; then
# NB: As these tests are invoked by valgrind, let's leave them for now as it's

View File

@ -156,7 +156,7 @@ ROOT = "//" if IS_OSS else "//xplat/caffe2"
# for targets in subfolders
ROOT_PATH = "//" if IS_OSS else "//xplat/caffe2/"
C10 = "//c10:c10" if IS_OSS else ("//xplat/caffe2/c10:c10_ovrsource" if is_arvr_mode() else "//xplat/caffe2/c10:c10")
C10 = "//c10:c10" if IS_OSS else "//xplat/caffe2/c10:c10"
# a dictionary maps third party library name to fbsource and oss target
THIRD_PARTY_LIBS = {

View File

@ -1,100 +1,16 @@
#pragma once
// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
// If you modify me, keep me synchronized with that file.
#include <c10/macros/Export.h>
#include <cstddef>
#include <cstdint>
#include <functional>
// If you modified DeviceType in caffe2/proto/caffe2.proto, please also sync
// your changes into torch/headeronly/core/DeviceType.h.
#include <torch/headeronly/core/DeviceType.h>
#include <ostream>
#include <string>
namespace c10 {
// These contains all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends except PrivateUse2 and PrivateUse3
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
_(CPU, extra) \
_(CUDA, extra) \
_(HIP, extra) \
_(XLA, extra) \
_(MPS, extra) \
_(IPU, extra) \
_(XPU, extra) \
_(HPU, extra) \
_(VE, extra) \
_(Lazy, extra) \
_(Meta, extra) \
_(MTIA, extra) \
_(PrivateUse1, extra)
enum class DeviceType : int8_t {
CPU = 0,
CUDA = 1, // CUDA.
MKLDNN = 2, // Reserved for explicit MKLDNN
OPENGL = 3, // OpenGL
OPENCL = 4, // OpenCL
IDEEP = 5, // IDEEP.
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
MAIA = 8, // ONNX Runtime / Microsoft
XLA = 9, // XLA / TPU
Vulkan = 10, // Vulkan
Metal = 11, // Metal
XPU = 12, // XPU
MPS = 13, // MPS
Meta = 14, // Meta (tensors with no data)
HPU = 15, // HPU / HABANA
VE = 16, // SX-Aurora / NEC
Lazy = 17, // Lazy Tensors
IPU = 18, // Graphcore IPU
MTIA = 19, // Meta training and inference devices
PrivateUse1 = 20, // PrivateUse1 device
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in DeviceType.cpp
// - Change the number below
COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMAIA = DeviceType::MAIA;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMTIA = DeviceType::MTIA;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
static_assert(
COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
"for this constant to reflect the actual number of DeviceTypes we support "
"in PyTorch; it's important that this number is not too large as we "
"use this to allocate stack arrays in some places in our code. If you "
"are indeed just adding the 20th device type, feel free to change "
"the check to 32; but if you are adding some sort of extensible device "
"types registration, please be aware that you are affecting code that "
"this number is small. Try auditing uses of this constant.");
C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);
C10_API bool isValidDeviceType(DeviceType d);
@ -108,15 +24,6 @@ C10_API bool is_privateuse1_backend_registered();
} // namespace c10
namespace std {
template <>
struct hash<c10::DeviceType> {
std::size_t operator()(c10::DeviceType k) const {
return std::hash<int>()(static_cast<int>(k));
}
};
} // namespace std
namespace torch {
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::DeviceType;

View File

@ -0,0 +1,21 @@
#include <c10/core/impl/HermeticPyObjectTLS.h>
namespace c10::impl {
thread_local static std::atomic<bool> hermeticPyObjectState{false};
std::atomic<bool> HermeticPyObjectTLS::haveState_{false};
void HermeticPyObjectTLS::set_state(bool state) {
hermeticPyObjectState = state;
}
bool HermeticPyObjectTLS::get_tls_state() {
return hermeticPyObjectState;
}
void HermeticPyObjectTLS::init_state() {
haveState_ = true;
}
} // namespace c10::impl

View File

@ -0,0 +1,62 @@
#pragma once
#include <c10/macros/Export.h>
#include <atomic>
namespace c10::impl {
// This TLS controls whether or not we permanently associate PyObject
// with Tensor the first time it is allocated. When hermetic PyObject
// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor,
// meaning you get a distinct PyObject whenever you execute the code in
// question.
struct C10_API HermeticPyObjectTLS {
static void set_state(bool state);
static bool get_state() {
// Hypothetical fastpath if torchdeploy/multipy // codespell:ignore multipy
// isn't used. Per
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// this qualifies relaxed access because it is a single-location data
// structure (only the boolean here).
//
// Forgetting about data races for a moment, is there a logical race?
//
// - Boolean only ever transitions from false to true. So the
// critical situation is when one interpreter is already running
// when a second interpreter switches haveState from false to true.
//
// - The first interpreter is indifferent whether or not it sees
// hasState true/false; obviously false works (this is what the
// interpreter was previously using; more directly, the interpreter
// calls into itself as the handler, so being hermetic is not
// required), and true simply means serviced python operator calls will
// be hermetic; in these cases it is expected to be functionally
// equivalent.
//
// - The second interpreter MUST see hasState true (as its requests will
// be forwarded to the first interpreter), but it is assumed that there
// is a synchronization between the interpreter initialization, and
// when we actually perform operations, so it is guaranteed to see
// hasState true.
//
// QED.
//
// This fastpath is currently disabled so that we can more easily test that
// hermetic mode works correctly even on stock build of PyTorch.
if (false && !haveState_.load(std::memory_order_relaxed))
return false;
return get_tls_state();
}
// Call this from the multipy/torchdeploy // codespell:ignore multipy
// top level
static void init_state();
private:
// This is only flipped once from false to true during
// torchdeploy/multipy initialization, // codespell:ignore multipy
// and never again.
static std::atomic<bool> haveState_;
static bool get_tls_state();
};
} // namespace c10::impl
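
The header splits the state into a process-wide `haveState_` flag that only ever flips from false to true and a per-thread boolean consulted on the slow path; the fast path is currently disabled by the `if (false && ...)` guard. A compact sketch of that split with stand-in names:

```cpp
#include <atomic>
#include <cstdio>

// Stand-ins for HermeticPyObjectTLS::haveState_ and the thread_local state.
std::atomic<bool> g_have_state{false};   // flips false -> true once (init_state)
thread_local bool t_hermetic = false;    // per-thread hermetic flag (set_state)

bool hermetic_enabled() {
  // Relaxed load is sufficient: a single location that only moves false -> true,
  // and a stale false just means "treat this call as non-hermetic", which is
  // safe for the first interpreter (see the reasoning in the header above).
  if (!g_have_state.load(std::memory_order_relaxed)) {
    return false;                        // fast path: no second interpreter yet
  }
  return t_hermetic;                     // slow path: consult the TLS flag
}

int main() {
  std::printf("%d\n", hermetic_enabled());               // 0
  g_have_state.store(true, std::memory_order_relaxed);   // init_state()
  t_hermetic = true;                                     // set_state(true)
  std::printf("%d\n", hermetic_enabled());               // 1
  return 0;
}
```
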

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/core/impl/PyInterpreterHooks.h>
#include <c10/util/python_stub.h>
@ -41,15 +42,32 @@ struct C10_API PyObjectSlot {
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter / PyObj as they may be null
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj() const {
impl::PyInterpreter* interpreter = getGlobalPyInterpreter();
if (interpreter == nullptr || pyobj_ == nullptr) {
return std::nullopt;
}
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
}
return _unchecked_untagged_pyobj();
}

View File

@ -382,6 +382,7 @@ struct ExpandableSegment {
peers_(std::move(peers)) {
cudaDeviceProp prop{};
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
mapped_size_ = 0;
// we allocate enough address space for 1 1/8 the total memory on the GPU.
// This allows for some cases where we have to unmap pages earlier in the
// segment to put them at the end.
@ -493,6 +494,7 @@ struct ExpandableSegment {
return SegmentRange{range.ptr, 0};
}
unmapHandles(begin, end);
mapped_size_ -= (end - begin) * segment_size_;
return rangeFromHandles(begin, end);
}
@ -632,6 +634,18 @@ struct ExpandableSegment {
return max_handles_ * segment_size_;
}
cudaStream_t getStream() {
return *stream_;
}
size_t getMappedSize() {
return mapped_size_;
}
size_t getSegmentSize() {
return segment_size_;
}
void addPeer(c10::DeviceIndex device) {
peers_.push_back(device);
forEachAllocatedRange(
@ -666,6 +680,7 @@ struct ExpandableSegment {
handles_.at(i).value().handle,
0ULL));
}
mapped_size_ += (end - begin) * segment_size_;
setAccess(device_, begin, end);
for (auto p : peers_) {
setAccess(p, begin, end);
@ -734,6 +749,7 @@ struct ExpandableSegment {
std::optional<cudaStream_t> stream_;
CUdeviceptr ptr_{};
size_t segment_size_;
size_t mapped_size_;
size_t max_handles_;
struct Handle {
CUmemGenericAllocationHandle handle;
@ -779,6 +795,17 @@ struct ExpandableSegment {
size_t size() const {
return 0;
}
cudaStream_t getStream() {
return nullptr;
}
size_t getMappedSize() {
return 0;
}
size_t getSegmentSize() {
return 0;
}
void addPeer(c10::DeviceIndex device) {}
};
#endif
@ -1183,6 +1210,16 @@ class DeviceCachingAllocator {
// ends.
ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;
// Incremental reverse-traversal state cached per graph.
// We never re-traverse nodes we've already seen
struct GraphReuseContext {
ska::flat_hash_map<cudaStream_t, ska::flat_hash_set<cudaGraphNode_t>>
visited;
};
ska::flat_hash_map<MempoolId_t, CaptureId_t, MempoolIdHash>
mempool_to_capture_id;
ska::flat_hash_map<CaptureId_t, GraphReuseContext> graph_reuse_context;
// outstanding cuda events
ska::flat_hash_map<
cuda::CUDAStream,
@ -1638,44 +1675,70 @@ class DeviceCachingAllocator {
return block;
}
// Insert "free marker" (empty nodes) into the CUDA graph for all streams that
struct CaptureInfo {
cudaGraph_t graph{};
CaptureId_t capture_id{0};
const cudaGraphNode_t* terminals{nullptr};
size_t num_terminals{0};
cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone};
};
inline CaptureInfo stream_get_capture_info(cudaStream_t stream) {
CaptureInfo info{};
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream,
&info.status,
&info.capture_id,
&info.graph,
&info.terminals,
nullptr,
&info.num_terminals));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream,
&info.status,
&info.capture_id,
&info.graph,
&info.terminals,
&info.num_terminals));
#endif
TORCH_INTERNAL_ASSERT(
info.status != cudaStreamCaptureStatusInvalidated,
"Invalid stream capture status");
return info;
}
// Record "free marker" of the CUDA graph for all streams that
// have used the block, including the allocation stream. These nodes mark the
// last use of the block in the capture graph. Returns a vector of the
// inserted nodes, or an empty vector if any stream is not capturing.
std::vector<cudaGraphNode_t> insert_free_marker(Block* block) {
std::vector<cudaGraphNode_t> empty_nodes;
std::vector<cudaGraphNode_t> record_free_markers(Block* block) {
// It is possible to have the same marker recorded multiple times, so we use
// a set to avoid duplicates
ska::flat_hash_set<cudaGraphNode_t> markers;
cudaGraph_t owning_graph = nullptr;
auto try_add_empty_node = [&](cudaStream_t stream) -> bool {
cudaStreamCaptureStatus status{};
cudaGraph_t graph{};
const cudaGraphNode_t* deps = nullptr;
size_t num_deps = 0;
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream, &status, nullptr, &graph, &deps, nullptr, &num_deps));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream, &status, nullptr, &graph, &deps, &num_deps));
#endif
TORCH_INTERNAL_ASSERT(
status != cudaStreamCaptureStatusInvalidated,
"Invalid stream capture status");
if (status == cudaStreamCaptureStatusNone) {
return false;
auto try_record = [&](cudaStream_t s) -> bool {
auto info = stream_get_capture_info(s);
if (info.status == cudaStreamCaptureStatusNone) {
return false; // not capturing on this stream -> must defer
}
cudaGraphNode_t node{};
C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps));
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
stream, &node, nullptr, 1, cudaStreamSetCaptureDependencies));
#else
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
stream, &node, 1, cudaStreamSetCaptureDependencies));
#endif
empty_nodes.push_back(node);
if (owning_graph == nullptr) {
owning_graph = info.graph;
}
TORCH_INTERNAL_ASSERT(
info.graph == owning_graph,
"All streams in the same capture should agree on the graph");
// Use current terminals as the free markers for the stream
for (size_t i = 0; i < info.num_terminals; ++i) {
auto terminal = info.terminals[i];
markers.insert(terminal);
}
owning_graph = info.graph; // all streams in the same capture should agree
return true;
};
@ -1683,81 +1746,34 @@ class DeviceCachingAllocator {
// An empty vector indicates that the block should be deferred for freeing
// until after capture.
// Attempt to add an empty node for the allocation stream.
if (!try_add_empty_node(block->stream)) {
// Allocation stream
if (!try_record(block->stream)) {
return {};
}
// Attempt to add empty nodes for all streams that have used the block.
// Any extra streams that used this block
for (const auto& s : block->stream_uses) {
if (!try_add_empty_node(s.stream())) {
if (!try_record(s.stream())) {
return {};
}
}
return empty_nodes;
return std::vector<cudaGraphNode_t>(markers.begin(), markers.end());
}
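
`record_free_markers` stores the capture's current terminal nodes as the block's free markers, and a marker is later considered reusable only if it is a predecessor of every terminal, i.e. all future captured work is ordered after it. A self-contained sketch of that reachability test over a plain parent-edge map (no CUDA graph API; node handles are just ints):

```cpp
#include <cstdio>
#include <set>
#include <unordered_map>
#include <vector>

// Reverse reachability: a free marker is reusable iff every terminal node can
// reach it by walking parent (dependency) edges, mirroring the reverse DFS the
// allocator performs over cudaGraphNodeGetDependencies.
using Node = int;
using Parents = std::unordered_map<Node, std::vector<Node>>;

bool reaches(const Parents& parents, Node from, Node target) {
  std::set<Node> visited;
  std::vector<Node> stack{from};
  while (!stack.empty()) {
    Node n = stack.back();
    stack.pop_back();
    if (!visited.insert(n).second) continue;
    if (n == target) return true;
    auto it = parents.find(n);
    if (it == parents.end()) continue;
    for (Node p : it->second) stack.push_back(p);
  }
  return false;
}

bool marker_is_reusable(const Parents& parents,
                        const std::vector<Node>& terminals,
                        Node marker) {
  for (Node t : terminals) {
    if (!reaches(parents, t, marker)) return false;  // some future work may not be ordered after it
  }
  return true;
}

int main() {
  // Edges 1 -> 2 -> 3 and 1 -> 4: node 1 precedes both terminals {3, 4}, node 2 does not.
  Parents parents{{2, {1}}, {3, {2}}, {4, {1}}};
  std::vector<Node> terminals{3, 4};
  std::printf("%d %d\n", marker_is_reusable(parents, terminals, 1),
              marker_is_reusable(parents, terminals, 2));  // prints: 1 0
  return 0;
}
```
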
// Returns the current set of "terminal" nodes in the CUDA graph for a given
// stream. These represent the current endpoints of the stream, and may
// include additional nodes if the graph branches. Any new work captured will
// be attached after one or more of these terminals.
std::vector<cudaGraphNode_t> get_terminals(cudaStream_t stream) {
std::vector<cudaGraphNode_t> result;
cudaStreamCaptureStatus status{};
cudaGraph_t graph{};
const cudaGraphNode_t* dependencies = nullptr;
size_t num_dependencies = 0;
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream,
&status,
nullptr,
&graph,
&dependencies,
nullptr,
&num_dependencies));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream, &status, nullptr, &graph, &dependencies, &num_dependencies));
#endif
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatusActive,
"Invalid stream capture status");
for (size_t i = 0; i < num_dependencies; i++) {
auto node = dependencies[i];
if (node != nullptr) {
result.push_back(node);
}
}
return result;
}
// Returns the set of "reusable" free markers (empty nodes) in the current
// Returns the set of "reusable" free markers in the current
// CUDA graph capture. A free marker is considered reusable if it is a
// predecessor of every terminal node.
// This ensures that all future captured work will occur after the free
// marker, making it safe to reuse.
ska::flat_hash_set<cudaGraphNode_t> get_reusable_empty_nodes(
cudaStream_t stream) {
auto terminals = get_terminals(stream);
if (terminals.empty()) {
// No terminal nodes found; nothing to free.
return {};
}
auto get_dependencies = [](cudaGraphNode_t node,
cudaGraphNode_t* pDependencies,
size_t* pNumDependencies) -> void {
void update_visited(
const CaptureInfo& info,
ska::flat_hash_set<cudaGraphNode_t>& visited) {
// This is the versioned cudaGraphNodeGetDependencies helper function.
auto node_get_dependencies =
[](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(
node, pDependencies, nullptr, pNumDependencies));
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count));
#else
C10_CUDA_CHECK(
cudaGraphNodeGetDependencies(node, pDependencies, pNumDependencies));
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count));
#endif
};
@ -1765,62 +1781,43 @@ class DeviceCachingAllocator {
auto get_parents =
[&](cudaGraphNode_t node) -> std::vector<cudaGraphNode_t> {
size_t count = 0;
get_dependencies(node, nullptr, &count);
node_get_dependencies(node, nullptr, &count);
std::vector<cudaGraphNode_t> out(count);
if (count) {
get_dependencies(node, out.data(), &count);
node_get_dependencies(node, out.data(), &count);
out.resize(count);
}
return out;
};
// Helper to determine if a node is an empty node (used as a free marker).
auto is_empty_node = [](cudaGraphNode_t n) -> bool {
cudaGraphNodeType type{};
C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type));
return type == cudaGraphNodeTypeEmpty;
};
// For each terminal node, perform a reverse DFS to count, for each empty
// node, how many terminals it can reach (i.e., for how many terminals it is
// a predecessor). An empty node is reusable if it is a predecessor of all
// terminal nodes.
ska::flat_hash_map<cudaGraphNode_t, size_t> num_terminals_reachable;
for (auto terminal : terminals) {
ska::flat_hash_set<cudaGraphNode_t> visited;
ska::flat_hash_set<cudaGraphNode_t> empty_nodes;
std::function<void(cudaGraphNode_t)> reverse_dfs =
[&](cudaGraphNode_t node) {
if (!visited.insert(node).second)
return;
if (is_empty_node(node)) {
num_terminals_reachable[node]++;
empty_nodes.insert(node);
}
auto parents = get_parents(node);
for (auto p : parents) {
reverse_dfs(p);
}
};
reverse_dfs(terminal);
// For each terminal node, perform a reverse DFS to count, for each free
// marker, how many terminals it can reach (i.e., for how many terminals it
// is a predecessor). A free marker is reusable if it is a predecessor of
// all terminal nodes.
std::deque<cudaGraphNode_t> dfs;
for (size_t i = 0; i < info.num_terminals; ++i) {
dfs.push_back(info.terminals[i]);
}
ska::flat_hash_set<cudaGraphNode_t> reusable_empty_nodes;
for (auto [node, count] : num_terminals_reachable) {
if (count == terminals.size()) {
reusable_empty_nodes.insert(node);
while (!dfs.empty()) {
auto v = dfs.back();
dfs.pop_back();
if (visited.count(v)) {
continue;
}
visited.insert(v);
auto parents = get_parents(v);
for (auto p : parents) {
dfs.push_back(p);
}
}
return reusable_empty_nodes;
}
// A block is considered reusable during CUDA graph capture if every free
// marker (empty node) associated with the block is a predecessor of every
// marker associated with the block is a predecessor of every
// terminal node.
//
// This ensures that any new operation added to the graph will be attached
@ -1829,36 +1826,52 @@ class DeviceCachingAllocator {
// on every stream, so the block's previous lifetime ends before any new
// lifetime begins. This check relies solely on the DAG topology and does not
// require event queries, making it safe to use during capture.
//
// This function iterates over all deferred blocks, determines if their empty
// nodes are reusable according to the above criteria, and frees the block if
// so.
void free_safe_blocks_in_capture(
const std::shared_ptr<GatheredContext>& context,
cudaStream_t stream) {
auto reusable_empty_nodes = get_reusable_empty_nodes(stream);
auto info = stream_get_capture_info(stream);
// If there are no reusable empty nodes (e.g., not currently capturing),
// there is nothing to do.
if (reusable_empty_nodes.empty()) {
if (info.status == cudaStreamCaptureStatusNone || info.num_terminals == 0) {
return;
}
if (graph_reuse_context.find(info.capture_id) ==
graph_reuse_context.end()) {
bool found = false;
for (auto& entry : captures_underway) {
if (entry.second(stream)) {
auto graph_pool = graph_pools.find(entry.first);
TORCH_INTERNAL_ASSERT(
graph_pool != graph_pools.end(),
"Could not find graph pool for capture.");
auto mempool_id = graph_pool->first;
graph_reuse_context[info.capture_id] = GraphReuseContext{};
mempool_to_capture_id[mempool_id] = info.capture_id;
found = true;
break;
}
}
TORCH_INTERNAL_ASSERT(
found, "Could not find memory pool id for capture.");
}
auto& graph_context = graph_reuse_context[info.capture_id];
auto& visited = graph_context.visited[stream];
update_visited(info, visited);
std::vector<Block*> blocks_to_erase;
for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
// Skip this block if it has no empty nodes, as we defer its freeing until
for (auto& [block, markers] : deferred_blocks) {
// Skip this block if it has no markers, as we defer its freeing until
// after graph capture. Also skip if the block was not allocated on the
// current stream; such blocks will be freed when
// free_safe_blocks_in_capture is attempted on that stream.
if (inserted_empty_nodes.empty() || block->stream != stream) {
if (markers.empty() || block->stream != stream) {
continue;
}
bool is_reusable = true;
for (const auto& node : inserted_empty_nodes) {
if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
for (auto m : markers) {
if (!visited.count(m)) {
is_reusable = false;
break;
}
@ -1919,11 +1932,11 @@ class DeviceCachingAllocator {
if (!block->stream_uses.empty()) {
if (C10_UNLIKELY(!captures_underway.empty())) {
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
// insert_free_marker returns a vector of free markers,
// record_free_markers returns a vector of free markers,
// or an empty vector if any associated stream is not currently
// capturing. The empty vector means that we will defer the free until
// capture is finished.
deferred_blocks.emplace(block, insert_free_marker(block));
deferred_blocks.emplace(block, record_free_markers(block));
} else {
// If graph_capture_record_stream_reuse is not enabled, always defer
// the free until capture is finished.
@ -2025,6 +2038,22 @@ class DeviceCachingAllocator {
set_fraction = true;
}
/** get expandable segment size for all the streams on device **/
std::vector<StreamSegmentSize> getExpandableSegmentSizes() {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<StreamSegmentSize> sizes;
for (auto& segment : expandable_segments_) {
if (!segment->getStream()) {
continue;
}
sizes.emplace_back(
segment->getStream(),
segment->getSegmentSize() == kSmallBuffer,
segment->getMappedSize());
}
return sizes;
}
/** returns cached blocks to the system allocator **/
void emptyCache(MempoolId_t mempool_id) {
auto context = maybeGatherContext(RecordContext::ALL);
@ -2511,6 +2540,21 @@ class DeviceCachingAllocator {
// Called by CUDAGraph::capture_end
void endAllocateToPool(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse() &&
!graph_reuse_context.empty()) {
auto capture_id = mempool_to_capture_id[mempool_id];
auto graph_context = graph_reuse_context[capture_id];
for (auto& [stream, _] : graph_context.visited) {
TORCH_INTERNAL_ASSERT(
stream_get_capture_info(stream).status ==
cudaStreamCaptureStatusNone,
"This stream should not be capturing when the capture is ended");
}
graph_reuse_context.erase(capture_id);
mempool_to_capture_id.erase(mempool_id);
}
for (auto it = captures_underway.begin(); it != captures_underway.end();
++it) {
if (it->first == mempool_id) {
@ -3837,6 +3881,16 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[device]->setMemoryFraction(fraction);
}
std::vector<StreamSegmentSize> getExpandableSegmentSizes(
c10::DeviceIndex device) override {
TORCH_INTERNAL_ASSERT(
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
"Allocator not initialized for device ",
device,
": did you call init?");
return device_allocator[device]->getExpandableSegmentSizes();
}
void recordHistory(
bool enabled,
CreateContextFn context_recorder,

View File

@ -203,6 +203,14 @@ struct ShareableHandle {
std::string handle;
};
struct StreamSegmentSize {
StreamSegmentSize(cudaStream_t s, bool small, size_t sz)
: stream(s), is_small_pool(small), total_size(sz) {}
cudaStream_t stream;
bool is_small_pool;
size_t total_size;
};
class CUDAAllocator : public DeviceAllocator {
public:
virtual void* raw_alloc(size_t nbytes) = 0;
@ -211,6 +219,8 @@ class CUDAAllocator : public DeviceAllocator {
virtual void init(int device_count) = 0;
virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
virtual std::vector<StreamSegmentSize> getExpandableSegmentSizes(
c10::DeviceIndex device) = 0;
virtual void enable(bool value) = 0;
virtual bool isEnabled() const = 0;
virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
@ -365,6 +375,11 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
return get()->setMemoryFraction(fraction, device);
}
inline std::vector<StreamSegmentSize> getExpandableSegmentSizes(
c10::DeviceIndex device) {
return get()->getExpandableSegmentSizes(device);
}
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
return get()->emptyCache(mempool_id);
}

View File

@ -495,6 +495,13 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// introduces performance nondeterminism.
}
std::vector<StreamSegmentSize> getExpandableSegmentSizes(
c10::DeviceIndex device) override {
TORCH_CHECK(
false,
"CUDAMallocAsyncAllocator does not yet support getExpandableSegmentSizes.");
}
void emptyCache(/*unused*/ MempoolId_t mempool_id) override {
std::lock_guard<std::mutex> lk(general_mutex);

View File

@ -16,21 +16,11 @@ cuda_supported_platforms = [
"ovr_config//os:windows-cuda",
]
# rocktenn apparently has its own copy of glog that comes with libmp.dll, so we
# had better not try to use glog from c10 lest the glog symbols not be eliminated.
C10_USE_GLOG = native.read_config("c10", "use_glog", "1") == "1"
# If you don't use any functionality that relies on static initializer in c10 (the
# most notable ones are the allocators), you can turn off link_whole this way.
# In practice, this is only used by rocktenn as well.
C10_LINK_WHOLE = native.read_config("c10", "link_whole", "1") == "1"
def define_c10_ovrsource(name, is_mobile):
pp_flags = []
if is_mobile:
pp_flags.append("-DC10_MOBILE=1")
if C10_USE_GLOG:
pp_flags.append("-DC10_USE_GLOG")
pp_flags = ["-DC10_MOBILE=1"]
else:
pp_flags = []
oxx_static_library(
name = name,
@ -41,7 +31,6 @@ def define_c10_ovrsource(name, is_mobile):
"util/*.cpp",
]),
compatible_with = cpu_supported_platforms,
link_whole = C10_LINK_WHOLE,
compiler_flags = select({
"DEFAULT": [],
"ovr_config//compiler:cl": [
@ -88,7 +77,6 @@ def define_c10_ovrsource(name, is_mobile):
"//arvr/third-party/gflags:gflags",
"//third-party/cpuinfo:cpuinfo",
"//third-party/fmt:fmt",
# For some godforsaken reason, this is always required even when not C10_USE_GLOG
"//third-party/glog:glog",
],
)

View File

@ -2,175 +2,126 @@
#include <arm_neon.h>
#include <arm_neon_sve_bridge.h>
#include <arm_sve.h>
#include <cfloat>
#include <cmath>
#include "c10/macros/Macros.h"
// Log and exp approximations inspired by the ACL implementation
/// Select `svlog` accuracy:
/// - 0: original.
/// - 1: more accurate, similar performance.
/// - 2: very high accuracy, a bit lower speed.
#define SVLOG_ACCURACY 2
inline float32x4_t vtaylor_polyq_for_log_f32(float32x4_t x) {
const float32x4_t log_tab_1 = vdupq_n_f32(-2.29561495781f);
const float32x4_t log_tab_2 = vdupq_n_f32(-2.47071170807f);
const float32x4_t log_tab_3 = vdupq_n_f32(-5.68692588806f);
const float32x4_t log_tab_4 = vdupq_n_f32(-0.165253549814f);
const float32x4_t log_tab_5 = vdupq_n_f32(5.17591238022f);
const float32x4_t log_tab_6 = vdupq_n_f32(0.844007015228f);
const float32x4_t log_tab_7 = vdupq_n_f32(4.58445882797f);
const float32x4_t log_tab_8 = vdupq_n_f32(0.0141278216615f);
/// Handle special cases in `svexp`:
/// - 0: original.
/// - 1: use clamp, better performance.
/// - 2: no special case handling.
#define SVEXP_SPECIAL_CLAMP 1
float32x4_t A = vmlaq_f32(log_tab_1, log_tab_5, x);
float32x4_t B = vmlaq_f32(log_tab_3, log_tab_7, x);
float32x4_t C = vmlaq_f32(log_tab_2, log_tab_6, x);
float32x4_t x2 = vmulq_f32(x, x);
float32x4_t D = svget_neonq(svmad_f32_x(
svptrue_b8(),
svset_neonq(svundef_f32(), x),
svset_neonq(svundef_f32(), log_tab_8),
svset_neonq(svundef_f32(), log_tab_4)));
float32x4_t x4 = vmulq_f32(x2, x2);
float32x4_t res = vmlaq_f32(vmlaq_f32(A, B, x2), vmlaq_f32(C, D, x2), x4);
return res;
#if SVLOG_ACCURACY == 2
static inline svfloat32_t svlog(svfloat32_t x) {
const svbool_t ptrue = svptrue_b8();
svint32_t u = svreinterpret_s32(x) - 0x3F2AAAAB;
svfloat32_t r = svreinterpret_f32((u & 0x007FFFFF) + 0x3F2AAAAB) - 1.0f;
svfloat32_t n = svcvt_f32_x(ptrue, u >> 23);
asm("" : "+w"(r)); // NOTE: can improve instruction scheduling.
svfloat32_t r2 = r * r;
svfloat32_t p = -0x1.4F9934p-3f + r * 0x1.5A9AA2p-3f;
svfloat32_t q = -0x1.00187Cp-2f + r * 0x1.961348p-3f;
svfloat32_t y = -0x1.FFFFC8p-2f + r * 0x1.555D7Cp-2f;
return (r + n * 0x1.62E43p-1f) +
(y + (q + (p + -0x1.3E737Cp-3f * r2) * r2) * r2) * r2;
}
#elif SVLOG_ACCURACY == 1
static inline svfloat32_t svlog(svfloat32_t x) {
const svbool_t ptrue = svptrue_b8();
inline float32x4_t vlogq_f32(float32x4_t x) {
const float32x4_t CONST_LN2 = vdupq_n_f32(0.6931471805f); // ln(2)
svint32_t u = svreinterpret_s32(x) - 0x3F2AAAAB;
// Extract exponent
int32x4_t m = svget_neonq(svsub_n_s32_x(
svptrue_b8(),
svset_neonq(
svundef_s32(),
vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_f32(x), 23))),
127));
float32x4_t val = vreinterpretq_f32_s32(
vsubq_s32(vreinterpretq_s32_f32(x), vshlq_n_s32(m, 23)));
svfloat32_t r = svreinterpret_f32((u & 0x007FFFFF) + 0x3F2AAAAB) - 1.0f;
svfloat32_t n = svcvt_f32_x(ptrue, u >> 23);
asm("" : "+w"(r)); // NOTE: can improve instruction scheduling.
// Polynomial Approximation
float32x4_t poly = vtaylor_polyq_for_log_f32(val);
svfloat32_t r2 = r * r;
svfloat32_t A = -0x1.923814p-3f + r * 0x1.689E5Ep-3f;
svfloat32_t B = -0x1.FC0968p-3f + r * 0x1.93BF0Cp-3f;
svfloat32_t C = -0x1.000478p-1f + r * 0x1.556906p-2f;
// Reconstruct
poly = vmlaq_f32(poly, vcvtq_f32_s32(m), CONST_LN2);
return (r + n * 0x1.62E43p-1f) + (C + (B + A * r2) * r2) * r2;
}
#elif SVLOG_ACCURACY == 0
static inline svfloat32_t svlog(svfloat32_t x) {
const svbool_t ptrue = svptrue_b8();
svint32_t u = svsra_n_s32(svdup_n_s32(-127), svreinterpret_s32(x), 23);
svfloat32_t n = svcvt_f32_x(ptrue, u);
svfloat32_t r = svreinterpret_f32(svreinterpret_s32(x) - (u << 23));
svfloat32_t D = -0.165253549814f + r * 0.0141278216615f;
svfloat32_t C = -2.47071170807f + r * 0.844007015228f;
svfloat32_t B = -5.68692588806f + r * 4.58445882797f;
svfloat32_t A = -2.29561495781f + r * 5.17591238022f;
svfloat32_t r2 = r * r;
return (A + n * 0.6931471805f) + (B + (C + D * r2) * r2) * r2;
}
#endif
static inline svfloat32_t svexp(svfloat32_t x) {
// Clamp interval set to prevent denormals!
const svfloat32_t max_input = svdup_n_f32(88.722839f);
const svfloat32_t min_input = svdup_n_f32(-87.33654f);
const svfloat32_t shift = svdup_n_f32(0x1.0000FEp+23f);
const svbool_t ptrue = svptrue_b8();
#if SVEXP_SPECIAL_CLAMP == 1
x = svmax_x(ptrue, svmin_x(ptrue, x, max_input), min_input);
#endif
svfloat32_t z = svmla_n_f32_x(ptrue, shift, x, 0x1.715476p+0f);
svfloat32_t n = z - shift;
svfloat32_t scale = svreinterpret_f32(svreinterpret_u32(z) << 23);
svfloat32_t r_hi = x - n * 0x1.62E400p-1f;
svfloat32_t r = r_hi - n * 0x1.7F7D1Cp-20f;
svfloat32_t r2 = r * r;
svfloat32_t C = 0x1.573E2Ep-5f + r * 0x1.0E4020p-7f;
svfloat32_t B = 0x1.FFFDB6p-2f + r * 0x1.555E66p-3f;
svfloat32_t A = r * 0x1.FFFFECp-1f;
svfloat32_t poly = scale + (A + (B + C * r2) * r2) * scale;
#if SVEXP_SPECIAL_CLAMP == 0
const svfloat32_t inf = svdup_n_f32(std::numeric_limits<float>::infinity());
poly = svsel_f32(svcmplt_f32(ptrue, x, min_input), svdup_n_f32(0.0f), poly);
poly = svsel_f32(svcmpgt_f32(ptrue, x, max_input), inf, poly);
#endif
return poly;
}
inline float32x4_t vexpq_f32(float32x4_t x) {
const auto c1 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3f7ffff6)));
const auto c2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3efffedb)));
const auto c3 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3e2aaf33)));
const auto c4 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3d2b9f17)));
const auto c5 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3c072010)));
const auto shift = vreinterpretq_f32_u32(
svget_neonq(svdup_n_u32(0x4b00007f))); // 2^23 + 127 = 0x1.0000fep23f
const auto inv_ln2 = vreinterpretq_f32_u32(
svget_neonq(svdup_n_u32(0x3fb8aa3b))); // 1 / ln(2) = 0x1.715476p+0f
const auto neg_ln2_hi = vreinterpretq_f32_u32(svget_neonq(
svdup_n_u32(0xbf317200))); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
const auto neg_ln2_lo = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(
0xb5bfbe8e))); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
const auto zero = svdup_n_f32(0.f);
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
// Range reduction:
// e^x = 2^n * e^r
// where:
// n = floor(x / ln(2))
// r = x - n * ln(2)
//
// By adding x / ln(2) with 2^23 + 127 (shift):
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
// forces decimal part
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n)
// + 127 will occupy the whole fraction part of z in FP32 format.
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
// of x / ln(2) (i.e. n) because the decimal part has been pushed out and
// lost.
// * The addition of 127 makes the FP32 fraction part of z ready to be used
// as the exponent
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
const auto z = vfmaq_f32(shift, x, inv_ln2);
const auto n = z - shift;
const auto scale =
vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in terms
// of accuracy and performance.
const auto r_hi = vfmaq_f32(x, n, neg_ln2_hi);
const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo);
// Compute the truncated Taylor series of e^r.
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
const auto r2 = r * r;
const auto p1 = c1 * r;
const auto p23 = vfmaq_f32(c2, c3, r);
const auto p45 = vfmaq_f32(c4, c5, r);
const auto p2345 = vfmaq_f32(p23, p45, r2);
const auto p12345 = vfmaq_f32(p1, p2345, r2);
auto poly = svset_neonq(svundef_f32(), vfmaq_f32(scale, p12345, scale));
auto pHigh = svcmpgt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), max_input);
auto pLow = svcmplt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), min_input);
auto bound = svsel_f32(
pHigh,
inf,
zero);
auto pCombined = svorr_b_z(svptrue_b8(), pLow, pHigh);
// Handle underflow and overflow.
poly = svsel_f32(
pCombined,
bound,
poly);
return svget_neonq(poly);
}
// ln(x) = log2(x) * ln(2)
// pow(x, n) = exp(n * ln(x))
inline float32x4_t compute_batch_box_cox_vec_sve128_float(
static inline svfloat32_t compute_batch_box_cox_vec_sve128_float(
svfloat32_t lambda1_v,
svfloat32_t lambda2_v,
svfloat32_t data_v,
svfloat32_t k_eps) {
// sum_v = lambda2_v + data_v
float32x4_t sum_v = vaddq_f32(svget_neonq(data_v), svget_neonq(lambda2_v));
const svbool_t ptrue = svptrue_b8();
// test lambda1_v: predNZ == 1 iff lambda1_v != 0
svbool_t predNZ = svcmpne_n_f32(svptrue_b8(), lambda1_v, 0.0f);
// clamp sum_v: sum_v = max(sum_v, k_eps)
sum_v = vmaxq_f32(sum_v, svget_neonq(k_eps));
// lnData = log(sum_v)
svfloat32_t lnData = svset_neonq(svundef_f32(), vlogq_f32(sum_v));
// if any lambda1 != 0, compute pow(sum_v, lambda1) using lnData
// pow(sum_v, lambda1) == exp(lambda1 * ln(sum_v))
svfloat32_t lnData = svlog(svmax_x(ptrue, data_v + lambda2_v, k_eps));
svbool_t predNZ = svcmpne_n_f32(ptrue, lambda1_v, 0.0f);
if (C10_LIKELY(svptest_any(predNZ, predNZ))) {
// mult = lambda1 * ln(sum_v)
float32x4_t mult = vmulq_f32(svget_neonq(lnData), svget_neonq(lambda1_v));
// lambda1_r = 1 / lambda1
svfloat32_t lambda1_r = svdivr_f32_m(predNZ, lambda1_v, svdup_n_f32(1.0f));
// pow = exp(mult)
float32x4_t pow = vexpq_f32(mult);
// merge results
// lnData if lambda1 == 0, (lambda1_r * pow - lambda1_r) if lambda1 != 0
svfloat32_t pow = svexp(lnData * lambda1_v);
lnData = svsel_f32(predNZ, lambda1_r, lnData);
lnData =
svnmsb_f32_m(predNZ, lnData, svset_neonq(svundef_f32(), pow), lnData);
lnData = svnmsb_f32_m(predNZ, lnData, pow, lnData);
}
return svget_neonq(lnData);
return lnData;
}
template <typename T>
@ -186,11 +137,11 @@ template <>
void compute_batch_box_cox_vec_sve128(
std::size_t N,
std::size_t D,
const float* data_ptr,
const float* __restrict lambda1_ptr,
const float* __restrict lambda2_ptr,
float* output_ptr) {
svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));
const float *data_ptr,
const float *__restrict lambda1_ptr,
const float *__restrict lambda2_ptr,
float *output_ptr) {
const svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));
std::size_t remainder = D % 4;
std::size_t loopBound = D - remainder;
@ -204,17 +155,17 @@ void compute_batch_box_cox_vec_sve128(
svfloat32_t lambda2_v =
svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j));
svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr));
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
svfloat32_t result = compute_batch_box_cox_vec_sve128_float(
lambda1_v, lambda2_v, data_v, k_eps);
vst1q_f32(output_ptr, result);
vst1q_f32(output_ptr, svget_neonq(result));
}
if (C10_LIKELY(remainder > 0)) {
svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound);
svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound);
svfloat32_t data_v = svld1_f32(remainderPred, data_ptr);
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
svfloat32_t result = compute_batch_box_cox_vec_sve128_float(
lambda1_v, lambda2_v, data_v, k_eps);
svst1_f32(remainderPred, output_ptr, svset_neonq(svundef_f32(), result));
svst1_f32(remainderPred, output_ptr, result);
data_ptr += remainder;
output_ptr += remainder;
}

View File

@ -153,6 +153,7 @@ _ZN3c104impl12PyObjectSlot10owns_pyobjEv
_ZN3c104impl12PyObjectSlot19maybe_destroy_pyobjEv
_ZN3c104impl12PyObjectSlotC1Ev
_ZN3c104impl12PyObjectSlotD2Ev
_ZN3c104impl19HermeticPyObjectTLS13get_tls_stateEv
_ZN3c104impl20TorchDispatchModeTLS13any_modes_setEb
_ZN3c104impl23ExcludeDispatchKeyGuardC1ENS_14DispatchKeySetE
_ZN3c104impl23ExcludeDispatchKeyGuardD2Ev

View File

@ -0,0 +1,72 @@
# Automatic Mixed Precision
## Background
Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference.
Key components include:
- [**Autocast**](https://docs.pytorch.org/docs/stable/amp.html#autocasting): Automatically casts operations to lower-precision (e.g., float16 or bfloat16) to improve performance while maintaining accuracy.
- [**Gradient Scaling**](https://docs.pytorch.org/docs/stable/amp.html#gradient-scaling): Dynamically scales gradients during backpropagation to prevent underflow when training with mixed precision.
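The two components above combine as in the following minimal sketch. It is illustrative only (not specific to the backend integration described below) and assumes a CUDA device plus a recent PyTorch where `torch.amp.GradScaler` accepts a device string:
```python
import torch

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler("cuda")

for _ in range(3):
    inp = torch.randn(4, 16, device="cuda")
    optimizer.zero_grad()
    # Autocast: ops inside the region run in float16 where they are registered to do so.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(inp).sum()
    # Gradient scaling: scale the loss so small fp16 gradients do not underflow,
    # then unscale before the optimizer step (the step is skipped on inf/NaN).
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```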
## Design
### Casting Strategy
The [`CastPolicy`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L416-L438) is used to define type conversion rules. Each enum value represents a set of type conversion requirements for a group of operators, ensuring consistent handling of operations that prioritize either precision or performance.
| Policy | Explanation |
| :--- | :--- |
| **`lower_precision_fp`** | Cast all inputs to `lower_precision_fp` before running the op. |
| **`fp32`** | Cast all inputs to `at::kFloat` before running the op. |
| **`fp32_set_opt_dtype`** | Execute in `at::kFloat`, respecting a user-specified output dtype if provided. |
| **`fp32_append_dtype`** | Append `at::kFloat` to the args and redispatch to the type-aware overload. |
| **`promote`** | Promote all inputs to the “widest” dtype before execution. |
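As a rough illustration of the first two policies: under autocast, an op registered with `lower_precision_fp` (such as `mm`) returns lower-precision results, while an op registered with `fp32` (such as `asin`) is computed in `torch.float32`. The sketch below mirrors the OpenReg autocast tests later in this repository and assumes the `openreg` backend has already been installed and loaded:
```python
import torch  # assumes the openreg backend is installed and autoloaded

with torch.autocast(device_type="openreg", dtype=torch.float16):
    x = torch.randn(2, 3, device="openreg")
    y = torch.randn(3, 3, device="openreg")
    # mm is registered under the lower_precision_fp policy
    assert torch.mm(x, y).dtype == torch.float16
    # asin is registered under the fp32 policy
    assert torch.asin(x).dtype == torch.float32
```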
### Operators Lists
PyTorch defines a general list of operators for each of the casting strategies mentioned above, as a reference for developers of new accelerators.
| Policy | Operators List |
| :--- | :--- |
| **`lower_precision_fp`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L819-L852) |
| **`fp32`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L854-L912) |
| **`fp32_set_opt_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L914-L931) |
| **`fp32_append_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L933-L958) |
| **`promote`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L960-L971) |
## Implementation
### Python Integration
Implement the `get_amp_supported_dtype` method to return the data types supported by the new accelerator in the AMP context.
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
:language: python
:start-after: LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
:end-before: LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
:linenos:
```
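Once this hook is in place, autocast on the new device only accepts the dtypes it reports; requesting another dtype emits a warning. A short sketch based on the OpenReg tests in this repository (the default autocast dtype shown is the one OpenReg configures):
```python
import torch  # assumes the openreg backend is installed and autoloaded

# float16 and bfloat16 are the dtypes returned by get_amp_supported_dtype above
print(torch.get_autocast_dtype(device_type="openreg"))  # torch.float16 by default

# torch.float32 is not in the supported list, so the following emits a UserWarning:
# "In openreg autocast, but the target dtype torch.float32 is not supported."
with torch.autocast(device_type="openreg", dtype=torch.float32):
    _ = torch.ones(10)
```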
### C++ Integration
This section shows how AMP registers autocast kernels for the `AutocastPrivateUse1` dispatch key.
- Register a fallback that makes unhandled ops fall through to their normal implementations.
- Register specific aten kernels under `AutocastPrivateUse1` using the `KERNEL_PRIVATEUSEONE` helper macro, which maps an op to the desired precision implementation (selected via the `at::autocast::CastPolicy` enum).
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
:language: c++
:start-after: LITERALINCLUDE START: AMP FALLTHROUTH
:end-before: LITERALINCLUDE END: AMP FALLTHROUTH
:linenos:
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
:language: c++
:start-after: LITERALINCLUDE START: AMP IMPL
:end-before: LITERALINCLUDE END: AMP IMPL
:emphasize-lines: 3,6,8-10
:linenos:
```

View File

@ -44,6 +44,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
autoload
operators
amp
```
[OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL"

View File

@ -339,13 +339,16 @@ XLA
~~~
- Jack Cao (`JackCaoG <https://github.com/JackCaoG>`__)
- Daniel Sohn (`jysohn23 <https://github.com/jysohn23>`__)
- Zach Cain (`zcain117 <https://github.com/zcain117>`__)
- Han Qi (`qihqi <https://github.com/qihqi>`__)
- Yifei Teng (`tengyifei <https://github.com/tengyifei>`__)
- Siyuan Liu (`lsy323 <https://github.com/lsy323>`__)
- Brian Hirsh (`bdhirsh <https://github.com/bdhirsh>`__)
- Gregory Chanan (`gchanan <https://github.com/gchanan>`__)
- (emeritus) Gregory Chanan (`gchanan <https://github.com/gchanan>`__)
- (emeritus) Ailing Zhang (`ailzhang <https://github.com/ailzhang>`__)
- (emeritus) Davide Libenzi (`dlibenzi <https://github.com/dlibenzi>`__)
- (emeritus) Alex Suhan (`asuhan <https://github.com/asuhan>`__)
- (emeritus) Daniel Sohn (`jysohn23 <https://github.com/jysohn23>`__)
- (emeritus) Zach Cain (`zcain117 <https://github.com/zcain117>`__)
TorchServe
~~~~~~~~~~

View File

@ -613,8 +613,7 @@ Available options:
CUDA Graph capture by using the graph topology (instead of CUDA events) to determine
when a freed block is safe to reuse. This can reduce peak memory during long captures that free
and reallocate buffers across multiple streams, especially when the capture DAG frequently
reaches joined frontiers. Note: Enabling this option can significantly increase the time spent
capturing the graph.
reaches joined frontiers.
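For example, a minimal sketch of enabling it before any CUDA work, assuming the
option is spelled ``graph_capture_record_stream_reuse`` when passed through
``PYTORCH_CUDA_ALLOC_CONF``::

    import os
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "graph_capture_record_stream_reuse:True"

    import torch  # the allocator reads the setting on first CUDA use

    g = torch.cuda.CUDAGraph()
    s = torch.cuda.Stream()
    with torch.cuda.graph(g, stream=s):
        x = torch.ones(1 << 20, device="cuda")
        del x  # freed during capture; reuse is decided from the capture DAG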
.. note::

View File

@ -4,6 +4,7 @@ set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
@ -27,7 +28,7 @@ add_executable(test_aoti_abi_check
target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)
# WARNING: DO NOT LINK torch!!!
# The purpose is to check if the used aten/c10 headers are writtern in a header-only way
# The purpose is to check if the used aten/c10 headers are written in a header-only way
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})

View File

@ -0,0 +1,35 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/DeviceType.h>
TEST(TestDeviceType, TestDeviceType) {
using torch::headeronly::DeviceType;
constexpr DeviceType expected_device_types[] = {
torch::headeronly::kCPU,
torch::headeronly::kCUDA,
DeviceType::MKLDNN,
DeviceType::OPENGL,
DeviceType::OPENCL,
DeviceType::IDEEP,
torch::headeronly::kHIP,
torch::headeronly::kFPGA,
torch::headeronly::kMAIA,
torch::headeronly::kXLA,
torch::headeronly::kVulkan,
torch::headeronly::kMetal,
torch::headeronly::kXPU,
torch::headeronly::kMPS,
torch::headeronly::kMeta,
torch::headeronly::kHPU,
torch::headeronly::kVE,
torch::headeronly::kLazy,
torch::headeronly::kIPU,
torch::headeronly::kMTIA,
torch::headeronly::kPrivateUse1,
};
for (int8_t i = 0; i <
static_cast<int8_t>(torch::headeronly::COMPILE_TIME_MAX_DEVICE_TYPES);
i++) {
EXPECT_EQ(static_cast<DeviceType>(i), expected_device_types[i]);
}
}

View File

@ -25,6 +25,8 @@ The goal of `torch_openreg` is **not to implement a fully functional, high-perfo
torch_openreg/
├── CMakeLists.txt
├── csrc
│ ├── amp
│ │ └── autocast_mode.cpp
│ ├── aten
│ │ ├── native
│ │ │ ├── Extra.cpp
@ -59,6 +61,8 @@ torch_openreg/
│ └── stub.c
├── __init__.py
└── openreg
├── amp
│ └── __init__.py
├── __init__.py
├── meta.py
└── random.py
@ -95,11 +99,12 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
**Key Directories**:
- `csrc/`: Core device implementation, including operator registration, runtime, etc.
- `csrc/amp/`: AMP (Automatic Mixed Precision)
- `csrc/aten/`: Operator registration
- `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
- `csrc/aten/native/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
- `csrc/aten/native/OpenRegExtra.cpp`: Implementations for other types of operators.
- `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
- `torch_openreg/csrc/`: Python C++ binding code.
@ -126,13 +131,18 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
### Autoload
- Autoload Mechanism
When `import torch` is executed, installed accelerators (such as `torch_openreg`) are loaded automatically, achieving the same experience as the built-in backends.
- Register the backend with Python `entry points`: See `setup` in `setup.py`
- Add a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
- Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
- Registering the backend with Python `entry points`: See `setup` in `setup.py`
- Adding a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
- Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
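A minimal sketch of the entry-point registration, assuming the `torch.backends` entry-point group is the one the autoload mechanism scans (see `setup.py` for the authoritative definition):
```python
# setup.py (illustrative excerpt only)
from setuptools import setup

setup(
    name="torch_openreg",
    entry_points={
        # group name assumed; torch's autoload looks up this group at import time
        "torch.backends": [
            "torch_openreg = torch_openreg:_autoload",
        ],
    },
)
```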
### AMP (Automatic Mixed Precision)
`AMP` provides convenience methods for mixed precision, where some operations use the `torch.float32` datatype and others use a lower-precision floating point datatype: `torch.float16` or `torch.bfloat16`.
- Register operator-specific conversion rules: See `autocast_mode.cpp` in `csrc/amp`.
- Add support for the accelerator's data types: See `get_amp_supported_dtype` in `torch_openreg/openreg/amp/__init__.py`.
## Installation and Usage
@ -168,11 +178,13 @@ print("Result z:\n", z)
print(f"Device of z: {z.device}")
```
## Documentation
Please refer to [this](https://docs.pytorch.org/docs/main/accelerator/index.html) for a series of documents on integrating new accelerators into PyTorch; these documents are kept in sync with the `OpenReg` codebase.
## Future Plans
- **Enhance Features**:
- Autoload
- AMP
- Device-agnostic APIs
- Memory Management
- Generator
@ -180,5 +192,3 @@ print(f"Device of z: {z.device}")
- Custom Tensor&Storage
- ...
- **Improve Tests**: Add more test cases related to the integration mechanism.
- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.

View File

@ -0,0 +1,37 @@
#include <ATen/autocast_mode.h>
using at::Tensor;
Tensor binary_cross_entropy_banned(
const Tensor&,
const Tensor&,
const std::optional<Tensor>&,
int64_t) {
TORCH_CHECK(
false,
"torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
"Many models use a sigmoid layer right before the binary cross entropy layer.\n"
"In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
"or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
"safe to autocast.");
}
// LITERALINCLUDE START: AMP FALLTHROUTH
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
// LITERALINCLUDE END: AMP FALLTHROUTH
// LITERALINCLUDE START: AMP IMPL
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
// lower_precision_fp
KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)
// fp32
KERNEL_PRIVATEUSEONE(asin, fp32)
m.impl(
TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&binary_cross_entropy_banned)));
}
// LITERALINCLUDE END: AMP IMPL

View File

@ -0,0 +1,50 @@
# Owner(s): ["module: PrivateUse1"]
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
class TestAutocast(TestCase):
def test_autocast_with_unsupported_type(self):
with self.assertWarnsRegex(
UserWarning,
"In openreg autocast, but the target dtype torch.float32 is not supported.",
):
with torch.autocast(device_type="openreg", dtype=torch.float32):
_ = torch.ones(10)
def test_autocast_operator_not_supported(self):
with self.assertRaisesRegex(
RuntimeError,
"torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.",
):
x = torch.randn(2, 3, device="openreg")
y = torch.randn(2, 3, device="openreg")
with torch.autocast(device_type="openreg", dtype=torch.float16):
_ = torch.nn.functional.binary_cross_entropy(x, y)
def test_autocast_low_precision(self):
with torch.amp.autocast(device_type="openreg", dtype=torch.float16):
x = torch.randn(2, 3, device="openreg")
y = torch.randn(3, 3, device="openreg")
result = torch.mm(x, y)
self.assertEqual(result.dtype, torch.float16)
def test_autocast_fp32(self):
with torch.amp.autocast(device_type="openreg"):
x = torch.randn(2, device="openreg", dtype=torch.float16)
result = torch.asin(x)
self.assertEqual(result.dtype, torch.float32)
def test_autocast_default_dtype(self):
openreg_fast_dtype = torch.get_autocast_dtype(device_type="openreg")
self.assertEqual(openreg_fast_dtype, torch.half)
def test_autocast_set_dtype(self):
for dtype in [torch.float16, torch.bfloat16]:
torch.set_autocast_dtype("openreg", dtype)
self.assertEqual(torch.get_autocast_dtype("openreg"), dtype)
if __name__ == "__main__":
run_tests()

View File

@ -3,6 +3,7 @@ import torch
import torch_openreg._C # type: ignore[misc]
from . import meta # noqa: F401
from .amp import get_amp_supported_dtype # noqa: F401
_initialized = False

View File

@ -0,0 +1,9 @@
import torch
# LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
def get_amp_supported_dtype():
return [torch.float16, torch.bfloat16]
# LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE

View File

@ -15,6 +15,9 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint, replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import (
FSDPModule,
@ -58,6 +61,7 @@ from torch.testing._internal.common_fsdp import (
)
from torch.testing._internal.common_utils import run_tests, TEST_XPU, xfailIf
from torch.testing._internal.distributed._tensor.common_dtensor import (
FeedForward,
ModelArgs,
Transformer,
TransformerBlock,
@ -1010,6 +1014,222 @@ class TestFullyShardPrefetch(FSDPTest):
self.assertEqual(events, expected_backward_events)
events.clear()
@skip_if_lt_x_gpu(2)
def test_set_modules_to_backward_prefetch_inside_ac(self):
n_layers = 3
reshard_after_forward = True
# use checkpoint wrapper instead of torch.utils
model_args = ModelArgs(n_layers=n_layers, checkpoint_activations=False)
model = Transformer(model_args)
apply_activation_checkpointing(
model, check_fn=lambda m: isinstance(m, TransformerBlock)
)
apply_activation_checkpointing(
model, check_fn=lambda m: isinstance(m, FeedForward)
)
fully_shard([model.tok_embeddings, model.pos_embeddings])
for layer in model.layers:
# mimic fully_shard(layer.moe.experts)
fully_shard(
layer.feed_forward.w1, reshard_after_forward=reshard_after_forward
)
fully_shard(layer, reshard_after_forward=reshard_after_forward)
fully_shard(
[model.norm, model.output], reshard_after_forward=reshard_after_forward
)
fully_shard(model, reshard_after_forward=reshard_after_forward)
inp = torch.randint(
0,
model_args.vocab_size,
(2, model_args.max_seq_len),
device=device_type.type,
)
def set_backward_prefetch(model: Transformer) -> None:
# tell pyre model.set_modules_to_backward_prefetch is available
assert isinstance(model, FSDPModule)
assert isinstance(model.output, FSDPModule)
# mimic deepseek MOE
# prefetch layer - 1 and its feedforward before cpu sync during a2a
reversed_transformer_blocks = list(reversed(model.layers))
prev_transformer_blocks = reversed_transformer_blocks[1:] + [None]
if (
model.norm is not None
and model.output is not None
and len(model.layers) > 0
):
assert isinstance(reversed_transformer_blocks[0], FSDPModule)
model.output.set_modules_to_backward_prefetch(
[reversed_transformer_blocks[0]]
)
for transformer_block, prev_transformer_block in zip(
reversed_transformer_blocks, prev_transformer_blocks
):
assert isinstance(transformer_block, FSDPModule)
if prev_transformer_block is not None:
assert isinstance(prev_transformer_block, FSDPModule)
assert hasattr(prev_transformer_block.feed_forward, "w1")
assert isinstance(
prev_transformer_block.feed_forward.w1, FSDPModule
)
transformer_block.set_modules_to_backward_prefetch(
[
prev_transformer_block,
prev_transformer_block.feed_forward.w1,
]
)
elif model.tok_embeddings is not None:
assert isinstance(model.tok_embeddings, FSDPModule)
transformer_block.set_modules_to_backward_prefetch(
[model.tok_embeddings]
)
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
reshard_with_record = self._get_reshard_with_record(
FSDPParamGroup.reshard, events
)
with (
patch_unshard(unshard_with_record),
patch_reshard(reshard_with_record),
):
loss = model(inp)
events.clear()
loss.sum().backward()
expected_backward_events = [
("unshard", "norm, output", TrainingState.PRE_BACKWARD),
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "norm, output", TrainingState.POST_BACKWARD),
# layers.2 prefetch w1
(
"unshard",
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
# layers.2.w1 prefetch layers.1
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
(
"reshard",
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
(
"unshard",
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
(
"reshard",
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
(
"unshard",
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
(
"unshard",
"tok_embeddings, pos_embeddings",
TrainingState.PRE_BACKWARD,
),
(
"reshard",
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
(
"reshard",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
(
"reshard",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
("reshard", "norm, output", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()
set_backward_prefetch(model)
loss = model(inp)
events.clear()
loss.sum().backward()
expected_backward_events = [
("unshard", "norm, output", TrainingState.PRE_BACKWARD),
# root explicit prefetch layers.2
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "norm, output", TrainingState.POST_BACKWARD),
# layers.2 prefetch layers.1 and feed_forward
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
(
"unshard",
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
# AC recompute_fn
(
"unshard",
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.FORWARD,
),
(
"reshard",
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
# layers.1 prefetch layers.0
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
(
"unshard",
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
(
"reshard",
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
# layers.0 prefetch embeddings
(
"unshard",
"tok_embeddings, pos_embeddings",
TrainingState.PRE_BACKWARD,
),
(
"reshard",
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
(
"reshard",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
(
"reshard",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
("reshard", "norm, output", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()
@skip_if_lt_x_gpu(2)
def test_fully_shard_multi_module_backward_prefetch(self):
n_layers = 5

View File

@ -0,0 +1,626 @@
# Owner(s): ["oncall: distributed"]
import copy
import dataclasses
import functools
from typing import Optional, Union
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._composable.replicate_with_fsdp import replicate
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
_get_gradient_divide_factors,
)
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
requires_nccl_version,
SaveForwardInputsModel,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
FSDPTest,
FSDPTestMultiThread,
get_devtype,
MLP,
patch_reduce_scatter,
reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocmVersionLessThan,
TEST_HPU,
)
device_type = torch.device(get_devtype())
class TestReplicateMixedPrecisionTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(2, torch.get_device_module(device_type).device_count())
def _init_models_and_optims(
self,
reshard_after_forward: Union[bool, int],
param_dtype: Optional[torch.dtype],
reduce_dtype: Optional[torch.dtype],
use_shard_placement_fn,
):
torch.manual_seed(42)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
largest_dim = -1
largest_dim_size = -1
for dim, dim_size in enumerate(param.shape):
if dim_size > largest_dim_size:
largest_dim = dim
largest_dim_size = dim_size
assert largest_dim >= 0, f"{param.shape}"
return Shard(largest_dim)
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype
)
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
replicate_fn = functools.partial(
replicate,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
shard_placement_fn=shard_placement_fn,
)
for mlp in model:
replicate_fn(mlp)
replicate_fn(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
return ref_model, ref_optim, model, optim
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
use_shard_placement_fn_vals = [False]
if self.world_size == 2:
# For world size >2, gradient elements get reduced in different
# orders for the baseline vs. dim-1 sharding, leading to numeric
# differences for bf16 reduction, so only test world size 2.
use_shard_placement_fn_vals.append(True)
return use_shard_placement_fn_vals
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_compute_dtype(self):
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"param_dtype": [torch.bfloat16, torch.float16],
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_compute_dtype,
)
def _test_compute_dtype(
self,
param_dtype: torch.dtype,
reshard_after_forward: Union[bool, int],
use_shard_placement_fn: bool,
):
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=None,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, param_dtype)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
predivide_factor, postdivide_factor, _, _ = _get_gradient_divide_factors(
self.process_group, all_reduce_group=None, reduce_dtype=param_dtype
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
with patch_reduce_scatter(reduce_scatter):
fsdp_loss.backward()
optim.step()
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
ref_loss.backward()
for param in ref_model_bf16.parameters():
# Use reduce-scatter -> all-gather as all-reduce because for
# world size >=4, NCCL all-reduce shows numeric differences
# compared with NCCL reduce-scatter
if predivide_factor is not None and predivide_factor > 1:
param.grad.div_(predivide_factor)
elif predivide_factor is None:
param.grad.div_(self.world_size)
output = torch.zeros_like(torch.chunk(param.grad, self.world_size)[0])
dist.reduce_scatter_tensor(output, param.grad)
dist.all_gather_into_tensor(param.grad, output)
if postdivide_factor is not None and postdivide_factor > 1:
param.grad.div_(postdivide_factor)
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_fp32.grad = param_bf16.grad.to(param_fp32.dtype)
param_bf16.grad = None
ref_optim.step() # fp32 optimizer step
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_bf16.detach().copy_(param_fp32)
self.assertEqual(fsdp_loss, ref_loss)
check_sharded_parity(self, ref_model, model)
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_reduce_dtype(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": [False, True],
},
self._test_reduce_dtype_fp32_reduce,
)
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_reduce_dtype_bf16_reduce,
)
def _test_reduce_dtype_fp32_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
if (
self.world_size > 2
and isinstance(reshard_after_forward, int)
and use_shard_placement_fn
):
return
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, reduce_dtype)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
with patch_reduce_scatter(reduce_scatter):
fsdp_loss.backward()
optim.step()
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
ref_loss.backward()
for param in ref_model_bf16.parameters():
param.grad.data = param.grad.to(torch.float32)
dist.all_reduce(param.grad) # fp32 reduction
param.grad.div_(self.world_size)
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_fp32.grad = param_bf16.grad
param_bf16.grad = None
ref_optim.step() # fp32 optimizer step
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_bf16.detach().copy_(param_fp32)
self.assertEqual(fsdp_loss, ref_loss)
check_sharded_parity(self, ref_model, model)
def _test_reduce_dtype_bf16_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
group = dist.distributed_c10d._get_default_group()
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, reduce_dtype)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
with patch_reduce_scatter(reduce_scatter):
fsdp_loss.backward()
optim.step()
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
ref_loss = ref_model(inp).sum()
ref_loss.backward()
for param in ref_model.parameters():
param_grad = param.grad.to(reduce_dtype)
# Use reduce-scatter -> all-gather to implement all-reduce
# since for world size >2, bf16 all-reduce and reduce-scatter
# have numeric differences
sharded_grad = funcol.reduce_scatter_tensor(
param_grad, scatter_dim=0, reduceOp="avg", group=group
) # bf16 reduction
param.grad = funcol.all_gather_tensor(
sharded_grad, gather_dim=0, group=group
).to(param.dtype) # upcast to fp32
ref_optim.step() # fp32 optimizer step
self.assertEqual(fsdp_loss, ref_loss)
check_sharded_parity(self, ref_model, model)
@skip_if_lt_x_gpu(2)
def test_grad_acc_with_reduce_dtype(self):
"""
Tests that gradient accumulation without reduce-scatter when using
bf16 compute and fp32 reduction accumulates the unsharded gradients in
fp32.
"""
self.run_subtests(
{"reshard_after_forward": [True, False]},
self._test_grad_acc_with_reduce_dtype,
)
def _test_grad_acc_with_reduce_dtype(self, reshard_after_forward: bool):
torch.manual_seed(42)
param_dtype, reduce_dtype = (torch.bfloat16, torch.float32)
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype
)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
# To emulate the mixed precision implementation where forward/backward
# compute use bf16 and optimizer uses fp32, we maintain both an fp32
# and a bf16 copy of the reference model
ref_model = copy.deepcopy(model).to(device_type)
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
replicate(
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, reduce_dtype)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
device = device_type
# Train on the same input to avoid loss explosion
num_microbatches = 4
inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
for iter_idx in range(10):
microbatch_inps = torch.chunk(inp, 4)
for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
model.set_requires_gradient_sync(is_last_microbatch)
model.set_reshard_after_backward(
is_last_microbatch or reshard_after_forward
)
losses: list[torch.Tensor] = []
for _model in (ref_model_compute, model):
losses.append(
_model(microbatch_inps[microbatch_idx].detach()).sum()
)
self.assertEqual(losses[-1].dtype, param_dtype)
with patch_reduce_scatter(reduce_scatter):
losses[-1].backward()
self.assertEqual(losses[0], losses[1])
# Manually accumulate gradients into the base reference model
# from the compute reference model in fp32
for ref_param, ref_param_compute in zip(
ref_model.parameters(), ref_model_compute.parameters()
):
self.assertTrue(ref_param_compute.grad is not None)
self.assertEqual(ref_param.dtype, torch.float32)
if ref_param.grad is not None:
ref_param.grad += ref_param_compute.grad
else:
ref_param.grad = ref_param_compute.grad.to(ref_param.dtype)
ref_param_compute.grad = None
# Manually reduce gradients for the reference model on the last
# microbatch to implement data parallelism
if is_last_microbatch:
for ref_param in ref_model.parameters():
self.assertTrue(ref_param.grad is not None)
dist.all_reduce(ref_param.grad)
ref_param.grad /= self.world_size
check_sharded_parity(self, ref_model, model)
ref_optim.step()
optim.step()
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
# Manually copy parameters from the base reference model to the
# compute reference model to run the optimizer step for the latter
for ref_param, ref_param_compute in zip(
ref_model.parameters(), ref_model_compute.parameters()
):
ref_param_compute.detach().copy_(ref_param)
class TestReplicateMixedPrecisionCasts(FSDPTestMultiThread):
@property
def world_size(self) -> int:
return 2
@skip_if_lt_x_gpu(1)
def test_float16_on_one_submodule(self):
x = torch.zeros(2, 100, device=device_type)
# Subtest 1: use fp16 on the second child submodule -- does not require
# any additional casting logic
forward_inputs: dict[str, nn.Module] = {}
model = SaveForwardInputsModel(
forward_inputs,
cast_forward_inputs=False,
).to(device_type)
replicate(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
replicate(model)
model(x).sum().backward()
self.assertEqual(forward_inputs[model].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)
# Subtest 2: use fp16 on the second child module, where the user module
# owns the cast
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=True
).to(device_type)
replicate(
model.c2,
mp_policy=MixedPrecisionPolicy(
param_dtype=torch.float16, cast_forward_inputs=False
),
)
replicate(model)
model(x).sum().backward()
self.assertEqual(forward_inputs[model].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)
# Subtest 3: use fp16 on the first child module and specify its output
# dtype so that the second child module does not need to cast
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=False
).to(device_type)
replicate(
model.c1,
mp_policy=MixedPrecisionPolicy(
param_dtype=torch.float16, output_dtype=torch.float32
),
)
replicate(model)
model(x).sum().backward()
self.assertEqual(forward_inputs[model].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c1].dtype, torch.float16)
self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)
@skip_if_lt_x_gpu(1)
def test_submodules_with_external_inputs(self):
self.run_subtests(
{"enable_submodule_cast": [False, True]},
self._test_submodules_with_external_inputs,
)
def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
class ToyModule(nn.Module):
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l = nn.Linear(100, 100)
self.forward_inputs = forward_inputs
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
self.forward_inputs["l2_input_x"] = x
self.forward_inputs["l2_input_y"] = y
return self.l(x)
class ToyModel(nn.Module):
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l1 = nn.Linear(100, 100)
self.l2 = ToyModule(forward_inputs)
self.forward_inputs = forward_inputs
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.forward_inputs["model_input_x"] = x
y = torch.ones(
2, 100, device=device_type.type, dtype=torch.float32
) # external input
return self.l2(self.l1(x), y)
forward_inputs: dict[str, torch.Tensor] = {}
model = ToyModel(forward_inputs).to(device_type)
x = torch.zeros(2, 100, device=device_type.type, dtype=torch.float32)
replicate(
model.l2,
mp_policy=MixedPrecisionPolicy(
param_dtype=torch.float16, cast_forward_inputs=enable_submodule_cast
),
)
replicate(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
model(x).sum().backward()
# If we enable `model.l2` to cast (the default), then `l2_input_y` gets
# cast to fp16; if we disable it, then it stays as fp32.
self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
self.assertEqual(
forward_inputs["l2_input_y"].dtype,
torch.float16 if enable_submodule_cast else torch.float32,
)
@skip_if_lt_x_gpu(1)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_norm_modules_bf16(self):
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
self._test_norm_modules(mp_policy)
@skip_if_lt_x_gpu(1)
def test_norm_modules_fp16(self):
mp_policy = MixedPrecisionPolicy(param_dtype=torch.float16)
self._test_norm_modules(mp_policy)
def _test_norm_modules(self, mp_policy: MixedPrecisionPolicy):
def inner(model: nn.Module, x: torch.Tensor):
# Run forward and backward to check for no type mismatch errors
z = model(x)
self.assertEqual(z.dtype, mp_policy.param_dtype)
z.sum().backward()
# Layer norm
model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 32))
for module in (model[0], model[1], model[2], model):
replicate(module, mp_policy=mp_policy)
inner(model, torch.randn((4, 32)))
# Batch norm 1D
model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 32))
for module in (model[0], model[1], model[2], model):
replicate(module, mp_policy=mp_policy)
inner(model, torch.randn((4, 32)))
# Batch norm 2D: error in backward from buffer dtype mismatch
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
for module in (model[0], model[1], model[2], model):
replicate(module, mp_policy=mp_policy)
if TEST_HPU:
inner(model, torch.randn((3, 1, 9, 9)))
else:
with self.assertRaisesRegex(
RuntimeError,
"Expected running_mean to have type", # Error not seen on HPUs and hence it can be skipped
):
# Errors in batch norm 2D backward
inner(model, torch.randn((3, 1, 9, 9)))
# Batch norm 2D: cast buffers down to lower precision
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
for module in (model[0], model[1], model[2], model):
replicate(module, mp_policy=mp_policy)
# Casting batch norm buffers to the lower precision allows backward
model[1].running_mean = model[1].running_mean.to(mp_policy.param_dtype)
model[1].running_var = model[1].running_var.to(mp_policy.param_dtype)
inner(model, torch.randn((3, 1, 9, 9)))
# Batch norm 2D: use special mixed precision policy
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
bn_mp_policy = MixedPrecisionPolicy(output_dtype=mp_policy.param_dtype)
replicate(model[1], mp_policy=bn_mp_policy)
for module in (model[0], model[2], model):
replicate(module, mp_policy=mp_policy)
inner(model, torch.randn((3, 1, 9, 9)))
@skip_if_lt_x_gpu(1)
def test_clamp_reduce_dtype(self):
# Initialize the model directly in bf16
init_dtype = torch.bfloat16
model = nn.Sequential(
nn.Linear(32, 32, dtype=init_dtype),
nn.Linear(32, 32, dtype=init_dtype),
).to(device_type.type)
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
)
# Check that we did not clamp the reduce dtype
self.assertEqual(mp_policy.reduce_dtype, torch.bfloat16)
for module in model:
replicate((module), mp_policy=mp_policy)
replicate(model, mp_policy=mp_policy)
# Check that the reduce-scatter runs in bf16 even after we change the
# model from bf16 to fp32
model.to(torch.float32)
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, torch.bfloat16)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
with patch_reduce_scatter(reduce_scatter):
inp = torch.randn((4, 32), device=device_type.type)
loss = model(inp).sum()
loss.backward()
@skip_if_lt_x_gpu(1)
def test_dataclass_input(self):
@dataclasses.dataclass
class Input:
x: torch.Tensor
class Model(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._layer = nn.Linear(10, 10)
def forward(self, input: Input):
return self._layer(input.x)
mp_policy = MixedPrecisionPolicy(
torch.bfloat16, torch.bfloat16, torch.bfloat16, True
)
model = Model()
inp = Input(torch.randn(2, 10).cuda())
replicate(model, mp_policy=mp_policy)
loss = model(inp).sum()
loss.backward()
if __name__ == "__main__":
run_tests()
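For reference, a minimal sketch of the bf16-compute / fp32-reduce setup these mixed-precision tests exercise, assuming an already-initialized process group (e.g. via torchrun); it is shown with fully_shard, which accepts the same mp_policy argument, and the toy module sizes are made up:

import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Assumes torch.distributed is already initialized; with a CPU/gloo group this
# runs as-is, on GPU move the model to the device first.
# bf16 parameters for forward/backward compute, fp32 gradient reduction.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
for layer in model:
    fully_shard(layer, mp_policy=mp_policy)  # wrap each child first
fully_shard(model, mp_policy=mp_policy)      # root wraps whatever is left
optim = torch.optim.Adam(model.parameters(), lr=1e-2)  # optimizer state stays fp32
loss = model(torch.randn(4, 16)).sum()  # inputs are cast to bf16 for compute
loss.backward()                         # gradients are reduced in fp32 per the policy
optim.step()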


@@ -5,6 +5,7 @@ import copy
import functools
import itertools
import unittest
from collections import defaultdict
from collections.abc import Iterable
from typing import Union
@@ -17,8 +18,20 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
apply_activation_checkpointing,
)
from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import (
CPUOffloadPolicy,
FSDPModule,
OffloadPolicy,
register_fsdp_forward_method,
)
from torch.distributed.tensor import DTensor, init_device_mesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
@@ -26,6 +39,7 @@ from torch.testing._internal.common_fsdp import (
FSDPTest,
FSDPTestMultiThread,
MLP,
MLPStack,
patch_all_gather,
patch_reduce_scatter,
)
@@ -842,5 +856,385 @@ class TestReplicateSharedParams(FSDPTest):
self.assertEqual(losses[0], losses[1])
class TestReplicateGradientAccumulation(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_gradient_accumulation(self):
"""
Tests gradient accumulation with/without gradient reduction and
with/without resharding after backward.
"""
shard_size, replicate_size = 1, self.world_size
meshes = init_device_mesh(
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
)
self.run_subtests(
{
"mesh": [meshes],
"reshard_after_forward": [True, False],
# "all": disable reduce-scatter for all modules
# "root_only": disable reduce-scatter for root's linear only
# "some_mlps": disable reduce-scatter for some MLPs
"mode": ["all", "root_only", "some_mlps"],
"reshard_after_backward": [False, True],
"offload_policy": [OffloadPolicy(), CPUOffloadPolicy()],
# For HSDP only:
# `True`: reduce-scatter only (no all-reduce) each microbatch
# until the last microbatch
# `False`: neither reduce-scatter nor all-reduce each
# microbatch until the last microbatch
"reduce_scatter_only": [False, True],
},
self._test_gradient_accumulation,
)
def _test_gradient_accumulation(
self,
mesh: DeviceMesh,
reshard_after_forward: Union[bool, int],
mode: str,
reshard_after_backward: bool,
offload_policy: OffloadPolicy,
reduce_scatter_only: bool, # for HSDP
):
if (
(
not reshard_after_backward
and (reshard_after_forward is not False or mode == "some_mlps")
)
or (
isinstance(offload_policy, CPUOffloadPolicy)
and reshard_after_forward is not True
)
or (
mesh.ndim != 2
) # may eventually need to change once a decision on the device mesh is made
):
return # skip since not common or applicable
torch.manual_seed(42)
batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
if mode == "some_mlps":
num_mlps_to_disable_reduce_scatter = 2
modules = [nn.Linear(lin_dim, lin_dim)]
modules.extend(MLP(lin_dim) for _ in range(num_mlps))
model = nn.Sequential(*modules)
ref_model = copy.deepcopy(model).to(device_type)
replicate_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
offload_policy=offload_policy,
)
for mlp in model[1:]:
replicate_fn(mlp)
replicate_fn(model) # root gets the 1st linear
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
def set_grad_sync_flag(
module: nn.Module, is_last_microbatch: bool, recurse: bool = True
):
if reduce_scatter_only:
module.set_requires_all_reduce(is_last_microbatch, recurse=recurse)
else:
module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse)
def set_backward_flags(_model: nn.Module, is_last_microbatch: bool):
if mode == "all":
set_grad_sync_flag(_model, is_last_microbatch)
if not reshard_after_backward:
_model.set_reshard_after_backward(is_last_microbatch)
elif mode == "some_mlps":
for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
set_grad_sync_flag(mlp, is_last_microbatch)
if not reshard_after_backward:
mlp.set_reshard_after_backward(is_last_microbatch)
elif mode == "root_only":
set_grad_sync_flag(model, is_last_microbatch, recurse=False)
if not reshard_after_backward:
model.set_reshard_after_backward(is_last_microbatch, recurse=False)
torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(5):
comm_count_list = []
for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
set_backward_flags(model, is_last_microbatch)
inp = torch.randn(batch_size, lin_dim, device=device_type.type)
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
with CommDebugMode() as comm_mode:
losses.append(_model(inp).sum())
losses[-1].backward()
comm_count_list.append(comm_mode.get_comm_counts())
self.assertEqual(losses[0], losses[1])
comm_counts = defaultdict(int)
for comm_count_dict in comm_count_list:
for collective, count in comm_count_dict.items():
comm_counts[collective] += count
all_gather_count = comm_counts[c10d_ops._allgather_base_]
# reduce_scatter_count = comm_counts[c10d_ops._reduce_scatter_base_]
all_reduce_count = comm_counts[c10d_ops.allreduce_]
# Expect one reduce-scatter per MLP plus one for the root's linear
# on the last microbatch
# expected_reduce_scatter_count = 0
expected_all_reduce_count = num_mlps + 1
if mode == "some_mlps":
# Expect additional reduce-scatters for non-disabled MLPs and
# the root's linear
expected_all_reduce_count += (
num_mlps - num_mlps_to_disable_reduce_scatter + 1
) * (num_microbatches - 1)
elif mode == "root_only":
# Expect additional reduce-scatters for all MLPs
expected_all_reduce_count += (num_mlps) * (num_microbatches - 1)
# self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
self.assertEqual(all_reduce_count, expected_all_reduce_count)
# Expect one all-gather per MLP plus one for the root's linear in
# the first microbatch's forward
expected_all_gather_count = 0
self.assertEqual(all_gather_count, expected_all_gather_count)
for param in ref_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
check_sharded_parity(self, ref_model, model)
for _optim in (optim, ref_optim):
_optim.step()
# When `set_to_none=False`, we are exercising mixing
# gradient accumulation with and without communication
_optim.zero_grad(set_to_none=(iter_idx % 2))
@skip_if_lt_x_gpu(2)
def test_1f1b_microbatching(self):
self.run_subtests(
{
"use_explicit_unshard": [False, True],
"reshard_after_backward": [False, True],
},
self._test_1f1b_microbatching,
)
def _test_1f1b_microbatching(
self, use_explicit_unshard: bool, reshard_after_backward: bool
):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
replicate(module, reshard_after_forward=False)
replicate(model, reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
num_microbatches = 3
local_batch_size = 2
torch.manual_seed(42 + self.rank + 1)
inps = [
torch.randint(
0,
model_args.vocab_size,
(local_batch_size, 16),
device=device_type.type,
)
for _ in range(num_microbatches)
]
# Before pipelining, we may prefer to issue all all-gathers ahead of
# time to increase the overlap opportunity, at no difference in parameter
# memory usage, since we do not reshard after forward
if use_explicit_unshard:
for module in model.modules():
if isinstance(module, FSDPModule):
module.unshard(async_op=True)
# Emulate the 1f1b pipeline schedule and only reduce gradients on the
# last microbatch
losses: list[torch.Tensor] = []
ref_losses: list[torch.Tensor] = []
for inp_idx, inp in enumerate(inps):
is_last_microbatch = inp_idx == num_microbatches - 1
model.set_requires_gradient_sync(is_last_microbatch)
model.set_is_last_backward(is_last_microbatch)
if not reshard_after_backward:
model.set_reshard_after_backward(is_last_microbatch)
losses.append(model(inp).sum())
losses[-1].backward()
ref_losses.append(ref_model(inp).sum())
ref_losses[-1].backward()
for param in ref_model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
for loss, ref_loss in zip(losses, ref_losses):
self.assertEqual(loss, ref_loss)
optim.step()
ref_optim.step()
check_sharded_parity(self, ref_model, model)
class TestReplicateCustomForwardMethod(FSDPTest):
@property
def world_size(self) -> int:
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
def test_register_fsdp_forward_method(self):
class VisionTransformer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)
def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
return self.patch_proj(imgs).flatten(2).transpose(1, 2)
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
return self.forward_features(imgs).sum(dim=1)
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
# Run `vit.forward_features`, which is not `forward`!
patch_embeddings = self.vit.forward_features(imgs)
return self.projector(patch_embeddings)
torch.manual_seed(42)
model = Model()
ref_model = copy.deepcopy(model).to(device_type)
replicate(model.vit)
replicate(model.projector)
replicate(model)
register_fsdp_forward_method(model.vit, "forward_features")
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn(4, 3, 224, 224, device=device_type.type)
ref_loss = ref_model(inp).sum()
loss = model(inp).sum()
self.assertEqual(ref_loss, loss)
ref_loss.backward()
loss.backward()
for param in ref_model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
check_sharded_parity(self, ref_model, model)
class TestReplicateTPTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.get_device_module(device_type).device_count())
def init_global_mesh(self) -> DeviceMesh:
return init_device_mesh(
device_type.type,
(2, 1, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)
@skip_if_lt_x_gpu(8)
def test_replicate_tp(self):
global_mesh = self.init_global_mesh()
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
},
functools.partial(self._test_replicate_tp, global_mesh),
)
def _test_replicate_tp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
foreach: bool,
):
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
parallelize_plan = {
# Pass `use_local_output=False` to keep as DTensor to preserve
# uneven activation dims
"0.in_proj": ColwiseParallel(use_local_output=False),
"0.out_proj": RowwiseParallel(use_local_output=False),
"1.in_proj": ColwiseParallel(use_local_output=False),
"1.out_proj": RowwiseParallel(use_local_output=False),
"2.in_proj": ColwiseParallel(use_local_output=False),
"2.out_proj": (RowwiseParallel()),
}
model = parallelize_module(model, tp_mesh, parallelize_plan)
for module in model:
if isinstance(module, nn.LayerNorm):
continue
if use_activation_checkpointing:
checkpoint(module)
replicate(module, device_mesh=dp_mesh)
replicate(model, device_mesh=dp_mesh)
# Checking that parameters match the original model is critical to validate that
# .full_tensor() correctly replicates the strided-sharded layers.
for ref_p, p in zip(ref_model.parameters(), model.parameters()):
self.assertIsInstance(p, DTensor)
self.assertEqual(ref_p, p.full_tensor())
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
torch.manual_seed(42 + dp_pg.rank() + 1)
device = device_type
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
losses.append(_model(inp).sum())
losses[-1].backward()
for param in ref_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
for _optim in (ref_optim, optim):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
_optim.step()
self.assertEqual(losses[0], losses[1])
check_sharded_parity(self, ref_model, model)
for _, p in model.named_parameters():
self.assertIsInstance(p, DTensor)
self.assertEqual(p.device_mesh.ndim, 3)
self.assertEqual(len(p.placements), 3)
self.assertEqual(
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
)
if __name__ == "__main__":
run_tests()
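For reference, a minimal sketch of the gradient-accumulation pattern these tests drive, assuming an already-initialized process group; reduction is skipped on all but the last microbatch via set_requires_gradient_sync, so gradients accumulate locally in unsharded form. The toy model, batch sizes, and use of fully_shard (which exposes the same method) are illustrative assumptions:

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

# Assumes torch.distributed is already initialized (CPU/gloo keeps this runnable anywhere).
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
for layer in model:
    fully_shard(layer)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

num_microbatches = 4
inp = torch.randn(8, 16)
for idx, microbatch in enumerate(torch.chunk(inp, num_microbatches)):
    is_last = idx == num_microbatches - 1
    # Reduce gradients across ranks only on the last microbatch; earlier
    # backwards just accumulate into the local unsharded gradients.
    model.set_requires_gradient_sync(is_last)
    model(microbatch).sum().backward()
optim.step()
optim.zero_grad()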


@@ -0,0 +1,158 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest
import torch
import torch.distributed as dist
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
requires_cuda,
run_tests,
TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore
class SimpleModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.mlp_0 = MLPModule(device)
self.mlp_1 = MLPModule(device)
def forward(self, input):
return self.mlp_1(self.mlp_0(input))
def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
# needed for strict export
torch.utils._pytree.register_constant(DTensorSpec)
# install_free_tensors is required for dynamo to work
with torch._dynamo.config.patch(
install_free_tensors=True, inline_inbuilt_nn_modules=True
):
with torch._export.utils._disable_aten_to_metadata_assertions():
ep = torch.export.export(model, (inputs,), strict=True)
# joint_gm produced here is missing the backward region, due to incompatibility
# between ep.module() and aot_export_joint_with_descriptors.
# Keeping this here to show the issue.
return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(inputs)
return aot_export_joint_with_descriptors_alone(gm, inputs)
def aot_export_joint_with_descriptors_alone(model, inputs):
with contextlib.ExitStack() as stack:
joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
model,
(inputs,),
)
return joint_with_descriptors.graph_module
def _count_op(gm, target):
return sum(1 for node in gm.graph.nodes if node.target == target)
@requires_cuda
class DTensorExportTest(TestCase):
def tearDown(self):
super().tearDown()
dist.destroy_process_group()
def setUp(self):
super().setUp()
self.world_size = 8
store = FakeStore()
dist.init_process_group(
backend="fake", rank=0, world_size=self.world_size, store=store
)
self.device_type = "cuda"
def _run_test(self, export_fn):
dp_degree = 2
tp_degree = self.world_size // dp_degree
# 2-D mesh is [dp, tp]
mesh_2d = init_device_mesh(
self.device_type,
mesh_shape=(dp_degree, tp_degree),
mesh_dim_names=["dp", "tp"],
)
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
"mlp_1.net1": ColwiseParallel(),
"mlp_1.net2": RowwiseParallel(),
}
tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
inputs = torch.rand(20, 10, device=self.device_type)
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
joint_gm = export_fn(tp_model, inputs)
fw_gm, bw_gm = min_cut_rematerialization_partition(
joint_gm, None, num_fwd_outputs=1
)
self.assertTrue(
_count_op(joint_gm, torch.ops._c10d_functional.all_reduce.default),
3,
)
self.assertTrue(
_count_op(fw_gm, torch.ops._c10d_functional.all_reduce.default),
2,
)
self.assertTrue(
_count_op(bw_gm, torch.ops._c10d_functional.all_reduce.default),
1,
)
@parametrize(
"export_fn",
[
graph_capture_and_aot_export_joint_with_descriptors,
aot_export_joint_with_descriptors_alone,
],
)
def test_export_parallelize_module_with_dtensor_input(
self,
export_fn,
):
self._run_test(export_fn)
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
# is producing a joint graph with backward region missing
@unittest.expectedFailure
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)
instantiate_parametrized_tests(DTensorExportTest)
if __name__ == "__main__":
run_tests()
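For reference, a minimal sketch of the tensor-parallel setup the export tests build: a 2-D device mesh, a Colwise/Rowwise parallelize plan keyed by submodule path, and a replicated DTensor input. It assumes an already-initialized 8-rank process group so the (2, 4) mesh is valid; the layer sizes are made up:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)

# Assumes an already-initialized 8-rank process group.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 10)).cuda()
# Shard the first linear column-wise and the last row-wise along the "tp" dim.
tp_model = parallelize_module(
    model, mesh_2d["tp"], {"0": ColwiseParallel(), "2": RowwiseParallel()}
)
# Replicate the input over the tp sub-mesh so every tp rank sees the full batch.
inp = distribute_tensor(torch.rand(20, 10, device="cuda"), mesh_2d["tp"], [Replicate()])
out = tp_model(inp)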


@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import contextlib
import itertools
import torch
@@ -355,7 +356,7 @@ class RedistributeTest(DTensorTestBase):
replica_spec = Replicate()
# 1) test replicate -> partial forward
replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec])
with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
with self.assertRaisesRegex(RuntimeError, "Can not redistribute"):
partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])
from torch.distributed.tensor._redistribute import Redistribute
@@ -619,6 +620,38 @@ class RedistributeTest(DTensorTestBase):
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(out.placements, [Shard(0), dst])
@with_comms
def test_redistribute_to_partial(self):
mesh = init_device_mesh(self.device_type, (2, 2))
tensor = torch.randn(12, 8, device=self.device_type)
test_cases = [
# Partial to Partial is allowed
([Partial(), Shard(0)], [Partial(), Shard(0)], True),
([Partial(), Shard(0)], [Partial(), Shard(1)], True),
([Shard(0), Partial()], [Replicate(), Partial()], True),
([Shard(0), Partial("prod")], [Replicate(), Partial("prod")], True),
# Non-Partial to Partial is NOT allowed
([Shard(0), Replicate()], [Shard(0), Partial()], False),
([Shard(0), Replicate()], [Replicate(), Partial()], False),
([Shard(0), Shard(1)], [Replicate(), Partial()], False),
# Partial to Partial is allowed only if the reduction op is the same
([Shard(0), Partial("prod")], [Replicate(), Partial("sum")], False),
]
for src, dst, allow in test_cases:
dt = DTensor.from_local(tensor, mesh, src)
raise_context = (
self.assertRaisesRegex(RuntimeError, "Can not redistribute")
if not allow
else contextlib.nullcontext()
)
with raise_context:
out = dt.redistribute(mesh, dst)
self.assertEqual(out.placements, dst)
instantiate_parametrized_tests(RedistributeTest)
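For reference, a minimal sketch of the redistribute rule the new test encodes: Partial to Partial with the same reduce op is fine, while redistributing any non-Partial placement into Partial raises. It assumes a 4-rank process group backing the 2x2 mesh; the tensor shape is arbitrary:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard

# Assumes an already-initialized 4-rank process group.
mesh = init_device_mesh("cuda", (2, 2))
local = torch.randn(12, 8, device="cuda")

# Partial stays Partial (same reduce op): allowed.
dt = DTensor.from_local(local, mesh, [Partial(), Shard(0)])
dt = dt.redistribute(mesh, [Partial(), Shard(1)])

# Turning a non-Partial placement into Partial: rejected.
dt = DTensor.from_local(local, mesh, [Shard(0), Replicate()])
try:
    dt.redistribute(mesh, [Shard(0), Partial()])
except RuntimeError as err:
    print(err)  # "Can not redistribute ..."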


@@ -0,0 +1,757 @@
# flake8: noqa: B950
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import counters, same
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
at_least_x_gpu,
DynamoDistributedMultiProcTestCase,
requires_accelerator_dist_backend,
)
aten = torch.ops.aten
import functools
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
def estimate_aten_runtime(fx_node, compute_multiplier=1.0):
# for tests, assume a matmul can hide a single collective
if "c10" in str(fx_node.target):
return 1.0
elif fx_node.target == aten.mm.default:
return compute_multiplier
else:
return None
device_type = str(get_devtype())
def apply_reordering_and_get_graph(graph, out_li) -> None:
gm = graph.owning_module
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
schedule_overlap_bucketing(gm)
gm.graph.lint()
out_li.append(str(gm.graph))
def run_and_get_aten_graph(fn, *inputs):
li = []
apply = functools.partial(apply_reordering_and_get_graph, out_li=li)
with torch._inductor.config.patch(post_grad_custom_post_pass=apply):
out = fn(*inputs)
return out, li[0]
def get_patches():
return {
"test_configs.estimate_aten_runtime": estimate_aten_runtime,
"reorder_for_locality": False,
"reorder_for_compute_comm_overlap_passes": [],
"compile_threads": 1,
"force_disable_caches": True,
}
@requires_accelerator_dist_backend()
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
"""
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
Note: these tests are a fork of test/distributed/test_compute_comm_reordering.py
"""
def setUp(self):
super().setUp()
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
def get_world_trs(self):
return {
"tag": "",
"ranks": list(range(self.world_size)),
"group_size": self.world_size,
}
@property
def world_size(self) -> int:
# hack: no matter whether we have 2, 3, or 4 GPUs, just run on 2;
# works around an issue with the skip-if-fewer-than-2-GPUs decorator and workers with an unpredictable number of GPUs
return 2
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_patches())
def test_sink_waits(self):
def func(a):
ar = _functional_collectives.all_reduce(a, "sum", "0")
b = torch.matmul(a, a)
return torch.matmul(ar, b)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
out, aten_graph_str = run_and_get_aten_graph(torch.compile(func), inputs)
# Verify that the wait_tensor is sunk below the 1st matmul but
# above the 2nd matmul.
(
FileCheck()
.check("all_reduce.default")
.check("aten.mm.default")
.check("wait_tensor.default")
.check("aten.mm.default")
.run(aten_graph_str)
)
correct = func(inputs)
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
@torch._inductor.config.patch(get_patches())
def test_raise_comms(self):
def func(a):
b = torch.matmul(a, a)
c = torch.relu(b)
d = torch.matmul(c, c)
e = _functional_collectives.all_reduce((b + 1), "sum", "0")
return torch.matmul(d, e)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
out, aten_graph_str = run_and_get_aten_graph(torch.compile(func), inputs)
# Verify that the all_reduce_ has been raised above the 2nd matmul
# but below the 1st matmul. Note that the all_reduce_ directly
# writes to the output buffer of the 1st matmul, which is an input
# to the first relu. Therefore, the all_reduce_ should be scheduled
# after the first relu.
(
FileCheck()
.check("aten.mm")
.check("all_reduce.default")
.check("aten.mm")
.check("wait_tensor.default")
.check("aten.mm")
.run(aten_graph_str)
)
out = compiled(inputs)
correct = func(inputs)
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
@torch._inductor.config.patch(get_patches())
def test_sink_waits_raise_comms(self):
def func(a, *, tag, ranks, group_size):
b = torch.matmul(a, a)
c = torch.relu(b)
d = torch.matmul(c, c)
e = _functional_collectives.all_reduce(b, "sum", "0")
f = torch.relu(d)
g = torch.matmul(f, f)
return torch.mm(e, g)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(
4, 4, dtype=torch.float, device=device_type
) # + self.rank
kwargs = self.get_world_trs()
func = functools.partial(func, **kwargs)
compiled = torch.compile(func)
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
# Things to verify:
# - The all_reduce_ and its prologue should be raised above the 2nd
# matmul but below the 1st matmul.
# - The wait_tensor should be sunk below the 3rd matmul but above
# the 4th matmul.
self.assertExpectedInline(
aten_graph_str,
"""\
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%mm : [num_users=2] = call_function[target=torch.ops.aten.mm.default](args = (%arg0_1, %arg0_1), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mm,), kwargs = {})
%all_reduce : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%mm, sum, 0), kwargs = {})
%mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu, %relu), kwargs = {})
%relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mm_1,), kwargs = {})
%mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu_1, %relu_1), kwargs = {})
%wait_tensor : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce,), kwargs = {})
%mm_3 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%wait_tensor, %mm_2), kwargs = {})
return (mm_3,)""",
)
# Note: this triggered an all_reduce_ bug
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
@torch._inductor.config.patch(get_patches())
def test_reorder_compute_for_overlap_mul(self):
def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
g = torch.matmul(a, a)
c = torch.relu(a)
d = torch.matmul(c, c)
f = d * c * ar
fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
e = torch.matmul(d + ar + fr, g)
return (e,)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
func_c = functools.partial(func, **self.get_world_trs())
compiled = torch.compile(func_c)
out_c, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
# Note: because we have given collectives and mms equal estimation,
# we overlap each collective with a single mm.
# Same schedule as in test_reorder_compute_for_overlap_custom_runtime_estimation
# although there is an exposed collective
(
FileCheck()
.check("all_reduce.default")
.check("aten.mm")
.check("aten.mm")
.check("wait_tensor.default")
.check("aten.mul")
.check("all_reduce.default")
.check("wait_tensor.default")
.check("aten.mm")
.run(aten_graph_str)
)
correct = func(inputs, **self.get_world_trs())
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 1)
self.assertTrue(same(out_c, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skipIfRocm
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@unittest.skipIf(True, "Logic not yet implemented")
@torch._inductor.config.patch(get_patches())
def test_grouped_scheduler_node(self):
def func(a, *, tag, ranks, group_size):
add = a + a
div = add / a
ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
# Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
# but here in this unit test, we intentionally put `add`, `div` and `ar` computation
# into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
mul = a * a
mm = torch.matmul(mul, ar)
return (mm,)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Expectations:
# 1. `add = a + a` and `div = add / a` are still fused, which means fusion
# still happens among nodes within a GroupedSchedulerNode.
# 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
# GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check(
"_c10d_functional.all_reduce_."
).check("triton_poi_fused_mul_1.").run(code)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_patches())
def test_inductor_default_comms_ordering(self):
pg_info = self.get_world_trs()
tag = pg_info["tag"]
ranks = pg_info["ranks"]
group_size = pg_info["group_size"]
g1 = torch.ones(10, 10, device=device_type)
g2 = torch.ones(11, 11, device=device_type)
g3 = torch.ones(12, 12, device=device_type)
@torch.compile
def fn(g1, g2, g3):
handle1 = torch.ops.c10d_functional.all_reduce(
g1, "avg", tag, ranks, group_size
)
handle2 = torch.ops.c10d_functional.all_reduce(
g2, "avg", tag, ranks, group_size
)
handle3 = torch.ops.c10d_functional.all_reduce(
g3, "avg", tag, ranks, group_size
)
# wait on them in a different order
grad3 = torch.ops._c10d_functional.wait_tensor.default(handle3)
grad2 = torch.ops._c10d_functional.wait_tensor.default(handle2)
grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
return grad3, grad2, grad1
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, self.backend(device_type), fake_pg=True
):
# all_reduces remain in order!
# note: this isn't actually an invariant of the pass currently,
# but we should keep collectives stable when there are no reordering opportunities
_, code = run_and_get_aten_graph(fn, g1, g2, g3)
FileCheck().check("all_reduce").check_same("arg0_1").check(
"all_reduce"
).check_same("arg1_1").check("all_reduce").check_same("arg2_1").run(code)
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 3)
# these have no overlap opportunities
self.assertEqual(counters["inductor"]["overlap_scheduling_bad_exposed"], 0)
def get_bucket_patches(compute_multiplier=1.0):
estimate_aten_runtime_part = functools.partial(
estimate_aten_runtime, compute_multiplier=compute_multiplier
)
return {
"test_configs.estimate_aten_runtime": estimate_aten_runtime_part,
"test_configs.aten_fx_overlap_preserving_bucketing": True,
"reorder_for_locality": False,
"reorder_for_compute_comm_overlap_passes": [],
"compile_threads": 1,
"force_disable_caches": True,
}
class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_basic_all_gather_bucketing(self):
"""Test that independent all_gather operations get bucketed together."""
def func(a, b, c, *, ranks):
# Three independent all_gathers that should be bucketed
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) + 3
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) + 4
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) + 5
return ag1 + ag2 + ag3
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs_a = (
torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
)
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
inputs_c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(
compiled, inputs_a, inputs_b, inputs_c
)
# Should see a single bucketed all_gather
FileCheck().check_count(
"torch.ops._c10d_functional.all_gather_into_tensor", 1, exactly=True
).run(aten_graph_str)
correct = func(inputs_a, inputs_b, inputs_c, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_reduce_scatter_bucketing(self):
"""Test bucketing of reduce_scatter operations."""
def func(a, b, c):
rs1 = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0")
rs2 = _functional_collectives.reduce_scatter_tensor(b, "sum", 0, "0")
rs3 = _functional_collectives.reduce_scatter_tensor(c, "sum", 0, "0")
return torch.cat([rs1, rs2, rs3])
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs_a = torch.ones(8, 4, dtype=torch.float, device=device_type)
inputs_b = torch.ones(8, 4, dtype=torch.float, device=device_type) * 2
inputs_c = torch.ones(8, 4, dtype=torch.float, device=device_type) * 3
out, aten_graph_str = run_and_get_aten_graph(
torch.compile(func), inputs_a, inputs_b, inputs_c
)
# Should bucket reduce_scatter ops
FileCheck().check_count(
"torch.ops._c10d_functional.reduce_scatter_tensor", 1, exactly=True
).run(aten_graph_str)
# TODO: debug - on ci this fails.
# correct = func(inputs_a, inputs_b, inputs_c)
# self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_no_bucketing_with_dependent_hiding_nodes(self):
"""Test that collectives with dependent hiding nodes don't get bucketed."""
def func(a, b, *, ranks):
# ag1 could be hidden by mm1
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
mm1 = torch.matmul(a, a)
# ag2 can be hidden by mm2, but mm2 depends on ag1's result
# ag2 start
mm2 = torch.matmul(ag1[:4], b)
# ag2 end
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
return ag1.sum() * ag2.sum() * mm1 * mm2
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type)
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type)
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs_a, inputs_b)
# mm2 depends on ag1, so if mm2 is to hide ag2, we can't bucket ag1 and ag2:
# the bucketed wait would need to land both before mm2 (for ag1) and after mm2
# (to hide ag2), even though the two all-gathers are otherwise independent
FileCheck().check_count(
"torch.ops._c10d_functional.all_gather_into_tensor", 2, exactly=True
).run(aten_graph_str)
correct = func(inputs_a, inputs_b, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_no_bucketing_when_collective_depends_on_hiding_node(self):
"""Test that collectives don't get bucketed when one depends on another's hiding node."""
def func(a, *, ranks):
# ag1 hidden by mm1
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
mm1 = torch.matmul(a, a)
# ag2 depends on mm1 (which hides ag1)
b = mm1 * 2
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
return ag1.sum() * ag2.sum() * mm1
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type)
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
# ag2 depends on mm1 (ag1's hiding node), so they can't be bucketed
FileCheck().check_count(
"_c10d_functional.all_gather_into_tensor", 2, exactly=True
).run(aten_graph_str)
correct = func(inputs, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches(2.0))
def test_bucketing_wait_sink(self):
"""Test that 4 independent all-gathers split bucketed."""
def func(a, b, c, d, *, ranks):
# All 4 all-gathers are independent - COULD be bucketed together
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
# First compute - can hide ag1 and ag2
e = a * 5
mm1 = torch.matmul(e, e.T)
# Second compute - can hide ag3 and ag4
f = b * 6
mm2 = torch.matmul(f, f.T)
# Use all collective results
result = (
ag1.sum() * 1.1
+ ag2.sum() * 1.2
+ ag3.sum() * 1.3
+ ag4.sum() * 1.4
+ mm1.sum()
+ mm2.sum()
)
return result
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
# The 4 all gathers can be bucketed, and their waits should be sunk below the mms
FileCheck().check_count(
"_c10d_functional.all_gather_into_tensor", 1, exactly=True
).check_count("ops.aten.mm", 2, exactly=True).check(
"_c10d_functional.wait_tensor"
).run(aten_graph_str)
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches(2.0))
def test_bucketing_split_for_overlap_blocking(self):
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
def func(a, b, c, d, *, ranks):
# All 4 all-gathers are independent - COULD be bucketed together
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
# First compute - can hide ag1 and ag2
e = a * 5 # Use a to avoid fusion
mm1 = torch.matmul(e, e.T)
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
# Use first 8x8 elements to match mm1's shape
intermediate = ag1[:8, :8] + ag2[:8, :8]
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
mm2 = torch.matmul(mm1 + intermediate, c[:8])
# Use all results
result = (
ag1.sum() * 1.1
+ ag2.sum() * 1.2
+ ag3.sum() * 1.3
+ ag4.sum() * 1.4
+ mm1.sum()
+ mm2.sum()
)
return result
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
# The 4 all gathers can be bucketed, and the wait should be sunk below the mms
FileCheck().check_count(
"_c10d_functional.all_gather_into_tensor", 1, exactly=True
).check_count("ops.aten.mm", 2, exactly=True).check_count(
"_c10d_functional.wait_tensor", 1, exactly=True
).run(aten_graph_str)
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches(2.0))
def test_bucketing_split_for_overlap(self):
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
def func(a, b, c, d, *, ranks):
# All 4 all-gathers are independent - COULD be bucketed together
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
# First compute - can hide ag1 and ag2
e = a * 5 # Use a to avoid fusion
mm1 = torch.matmul(e, e.T)
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
intermediate = ag1[:2, :2] + ag2[:2, :2] # Small slice to minimize compute
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
f = b * 6
# Expand intermediate to match mm1's shape for broadcasting
intermediate_expanded = torch.nn.functional.pad(intermediate, (0, 6, 0, 6))
mm2 = torch.matmul(mm1 + intermediate_expanded, f.T)
# Use all results
result = (
ag1.sum() * 1.1
+ ag2.sum() * 1.2
+ ag3.sum() * 1.3
+ ag4.sum() * 1.4
+ mm1.sum()
+ mm2.sum()
)
return result
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
# Should have 2 bucketed all-gathers (one for ag1+ag2, one for ag3+ag4)
FileCheck().check_count(
"_c10d_functional.all_gather_into_tensor_out", 2, exactly=True
).run(aten_graph_str)
# Verify the ordering - first bucket, then mm1, then second bucket, then mm2
FileCheck().check("_c10d_functional.all_gather_into_tensor_out").check(
"ops.aten.mm"
).check("_c10d_functional.all_gather_into_tensor_out").check(
"ops.aten.mm"
).run(aten_graph_str)
# Verify correctness
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_bucket_exposed_with_hidden_single_overlap(self):
"""Test that exposed and hidden collectives bucket together when overlap is preserved."""
def func(a, b, c, *, ranks):
# ag1 will be hidden by mm1
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
# ag2 and ag3 are exposed (no compute to hide them)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks)
# can only hide one collective
mm1 = torch.matmul(a[:2], a[:2].T) # 2x2 matmul, hides only ag1
# All three can bucket together because:
# bucketing ag1, ag2, ag3 together does not prevent ag1 being hidden by mm1.
return ag1.sum() + ag2.sum() + ag3.sum() + mm1.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
# Should have 1 bucketed operation containing all 3 all-gathers
FileCheck().check_count("wait_tensor.default", 1, exactly=True).run(
aten_graph_str
)
# Verify bucketed collective overlaps with mm1
FileCheck().check("functional.all_gather_into_tensor").check(
"aten.mm"
).check("wait_tensor").run(aten_graph_str)
# Verify correctness
correct = func(a, b, c, ranks=ranks)
self.assertTrue(same(out, correct))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()
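For reference, a minimal sketch of the issue-compute-wait pattern the scheduling and bucketing passes above try to recover: launch the collective, run independent compute while it is in flight, and only wait where the result is needed. It assumes an already-initialized process group with a GPU backend; the tensor sizes are made up:

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Assumes torch.distributed is already initialized (e.g. via torchrun with nccl).
ranks = list(range(dist.get_world_size()))
a = torch.ones(4, 4, device="cuda")

ar = funcol.all_reduce(a, "sum", ranks)        # returns without blocking
b = torch.matmul(a, a)                         # independent compute hides the collective
out = torch.matmul(funcol.wait_tensor(ar), b)  # wait only where the result is consumed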


@@ -95,6 +95,18 @@ class TestSerialization(TestCase):
result = _streaming_load(file)
torch.testing.assert_close(result, state_dict)
def test_empty_tensor(self) -> None:
state_dict = {
"empty": torch.zeros(0, 10),
}
file = BytesIO()
_streaming_save(state_dict, file)
file.seek(0)
result = _streaming_load(file, weights_only=False)
self.assertEqual(result, state_dict)
def test_dtensor(self) -> None:
dist.init_process_group(
backend="gloo", rank=0, world_size=1, store=dist.HashStore()


@@ -4,7 +4,7 @@ import itertools
import os
import random
from contextlib import nullcontext
from unittest import skip, skipIf
from unittest import skip, skipIf, skipUnless
import torch
import torch.distributed as dist
@@ -25,6 +25,7 @@ from torch.distributed._symmetric_memory import (
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
SM100OrLater,
SM89OrLater,
SM90OrLater,
xfailIfSM100OrLater,
)
@@ -51,10 +52,6 @@ from torch.testing._internal.common_utils import (
test_contexts = [nullcontext, _test_mode]
# Set environment variable to disable multicast for all tests in this module
# Workaround https://github.com/pytorch/pytorch/issues/162429
os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1"
# So that tests are written in device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)
@@ -430,6 +427,7 @@ class AsyncTPTest(MultiProcContinuousTest):
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
@parametrize("gather_dim", [0, 1])
@parametrize(
"scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
@@ -545,6 +543,7 @@ class AsyncTPTest(MultiProcContinuousTest):
@skip_if_rocm_multiprocess # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
@skip_if_lt_x_gpu(2)
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
@parametrize("scatter_dim", [0, 1])
@parametrize("rowwise", [True, False])
def test_fused_scaled_matmul_reduce_scatter(


@@ -759,6 +759,38 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
),
)
def test_sac_with_partial_context_fn(self):
class CustomPolicy:
def __init__(self):
super().__init__()
def __call__(self, ctx, out, func, *args, **kwargs):
return CheckpointPolicy.MUST_SAVE
def f(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
context_fn1 = functools.partial(
create_selective_checkpoint_contexts, CustomPolicy()
)
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
f,
x,
y,
use_reentrant=False,
context_fn=context_fn1,
)
opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True)
a = torch.randn(4, 4, requires_grad=True, device="cpu")
b = torch.randn(4, 4, requires_grad=True, device="cpu")
expected = fn(a, b)
result = opt_fn(a, b)
self.assertEqual(result, expected)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):

View File

@ -203,6 +203,22 @@ class TestAOTCompile(torch._inductor.test_case.TestCase):
actual = compiled_fn(*example_inputs)
self.assertEqual(expected, actual)
def test_aot_compile_disable_guard_check(self):
def fn(x, y):
return x + y
with torch.no_grad():
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
((torch.randn(3, 4), torch.randn(3, 4)), {})
)
inputs = (torch.randn(3, 4), torch.randn(3, 4))
expected = fn(*inputs)
with self.assertRaisesRegex(RuntimeError, "GuardManager check failed"):
compiled_fn(*inputs)
compiled_fn.disable_guard_check()
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
def test_aot_compile_source_info(self):
from torch._dynamo.package import SourceInfo

View File

@ -30,7 +30,7 @@ class CallbackTests(TestCase):
def test_callbacks_with_duplicate_prevention(self) -> None:
trigger = CallbackTrigger.DYNAMO
compile_id = CompileId(0, 0)
compile_id = CompileId(frame_id=0, frame_compile_id=0)
with (
callback_handler.install_callbacks(trigger, compile_id),
callback_handler.install_callbacks(trigger, compile_id),
@ -40,7 +40,7 @@ class CallbackTests(TestCase):
def test_counter(self) -> None:
trigger = CallbackTrigger.DYNAMO
compile_id = CompileId(0, 0)
compile_id = CompileId(frame_id=0, frame_compile_id=0)
with callback_handler.install_callbacks(trigger, compile_id):
self.assertEqual(
callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
@ -56,7 +56,7 @@ class CallbackTests(TestCase):
AssertionError, "Pending callbacks counter cannot become negative."
):
trigger = CallbackTrigger.DYNAMO
compile_id = CompileId(0, 0)
compile_id = CompileId(frame_id=0, frame_compile_id=0)
with callback_handler.install_callbacks(trigger, str(compile_id)):
pass
self.assertEqual(

View File

@ -216,7 +216,7 @@ Unsupported context manager
Hint: If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then it may be the case that it was created outside the compiled region, which Dynamo does not support. Supported context managers can cross graph break boundaries only if they are local non-closure variables, or are intermediate values.
Hint: File an issue to PyTorch. Simple context managers can potentially be supported, but note that context managers can't be supported in general
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH on ConstantVariable(int: 3)
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on ConstantVariable(int: 3)
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0142.html

View File

@ -127,6 +127,8 @@ class GraphModule(torch.nn.Module):
def fn(x):
local_rank = device_mesh.get_local_rank()
global_rank = device_mesh.get_rank()
if "dp" not in device_mesh.mesh_dim_names:
x = x * 2
return x + local_rank + global_rank
x = torch.ones(10)

View File

@ -95,7 +95,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
transformed_code = code_map1[frame.f_code]
return wrap_guarded_code(
GuardedCode(
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
transformed_code,
empty_guard_manager,
CompileId(
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
),
)
)
return ConvertFrameReturn()
@ -105,7 +109,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
transformed_code = code_map2[frame.f_code]
return wrap_guarded_code(
GuardedCode(
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
transformed_code,
empty_guard_manager,
CompileId(
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
),
)
)
return ConvertFrameReturn()
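
For reference, a minimal sketch of the keyword-argument CompileId construction these hunks switch to. It assumes CompileId is importable from torch._guards and that compiled_autograd_id is only needed for the compiled-autograd variant, as the diff suggests:

# Keyword-argument CompileId construction (assumed import path: torch._guards).
from torch._guards import CompileId

# An ordinary Dynamo frame compile: frame 0, first compile of that frame.
cid = CompileId(frame_id=0, frame_compile_id=0)

# A compiled-autograd graph: no Dynamo frame id, but a compiled_autograd_id.
ca_cid = CompileId(frame_id=None, frame_compile_id=0, compiled_autograd_id=0)

print(cid, ca_cid)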

View File

@ -329,7 +329,9 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
package=None,
)
with (
compile_context(CompileContext(CompileId(0, 0))),
compile_context(
CompileContext(CompileId(frame_id=0, frame_compile_id=0))
),
tracing(tracer.output.tracing_context),
tracer.set_current_tx(),
get_metrics_context(),

View File

@ -5448,7 +5448,8 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
# check ShapeEnv counters compared to binding indices
shape_env = _get_shape_env_from_gm(ep.graph_module)
next_index = next(shape_env.unbacked_symint_counter)
next_index = shape_env.unbacked_symint_counter
shape_env.unbacked_symint_counter += 1
for symbol in bound:
self.assertTrue(symbol_is_type(symbol, SymT.UNBACKED_INT))
self.assertTrue(
@ -10293,6 +10294,28 @@ graph():
ep = export(m, args)
self.assertEqual(ep.module()(*args), m(*args))
def test_cdist_forward_compute_mode_zero_export(self):
class CDistModel(torch.nn.Module):
def __init__(self):
super(CDistModel, self).__init__()
def forward(self, x, y, compute_mode):
return torch.ops.aten._cdist_forward(
x, y, p=2.0, compute_mode=compute_mode
)
x = torch.ones([3, 3])
y = torch.ones([3, 3])
model = CDistModel()
expected_none = model(x, y, None)
ep_none = torch.export.export(model, (x, y, None))
self.assertTrue(torch.equal(ep_none.module()(x, y, None), expected_none))
expected_0 = model(x, y, 0)
ep_0 = torch.export.export(model, (x, y, 0))
self.assertTrue(torch.equal(ep_0.module()(x, y, 0), expected_0))
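
As a side note, a small sketch relating the private aten._cdist_forward call above to the public torch.cdist API, whose compute_mode argument takes strings rather than the integer/None values passed to the private op; shown only for orientation:

# Public cdist API with explicit compute_mode strings.
import torch

x = torch.ones(3, 3)
y = torch.ones(3, 3)

d_default = torch.cdist(x, y, p=2.0)  # defaults to the mm-based heuristic
d_mm = torch.cdist(x, y, p=2.0, compute_mode="use_mm_for_euclid_dist")
torch.testing.assert_close(d_default, d_mm)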
def test_export_then_compile_tensor_ctor(self):
class M(torch.nn.Module):
def forward(self, scores, mask):

View File

@ -56,8 +56,6 @@ fake_export_failures = {
xfail("masked.var"),
xfail("nn.functional.grid_sample"),
xfail("to_sparse"),
# cannot xfail as it is passing for cpu-only build
skip("nn.functional.scaled_dot_product_attention"),
# following are failing due to OptionalDeviceGuard
xfail("__getitem__"),
xfail("nn.functional.batch_norm"),
@ -80,8 +78,7 @@ def _test_export_helper(self, dtype, op):
sample_inputs_itr = op.sample_inputs("cpu", dtype, requires_grad=False)
mode = FakeTensorMode(allow_non_fake_inputs=True)
# intentionally avoid cuda:0 to flush out some bugs
target_device = "cuda:1"
target_device = "cuda:0"
def to_fake_device(x):
return x.to(target_device)
@ -135,8 +132,10 @@ instantiate_device_type_tests(TestExportOpInfo, globals(), only_for="cpu")
selected_ops = {
"__getitem__",
# "nn.functional.batch_norm", # needs to fix
"nn.functional.conv2d",
"nn.functional.instance_norm",
"nn.functional.multi_margin_loss",
"nn.functional.scaled_dot_product_attention",
"nonzero",
}
selected_op_db = [op for op in op_db if op.name in selected_ops]

View File

@ -924,6 +924,26 @@ def forward(self, x):
loaded_ep = load(buffer)
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
def test_non_float_weight(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = torch.nn.Parameter(
torch.ones(2, 2, dtype=torch.int8), requires_grad=False
)
def forward(self, x):
return x + self.p
m = M()
sample_inputs = (torch.randn(2, 2),)
ep = torch.export.export(m, sample_inputs)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
loaded_ep = load(buffer)
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
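
A small aside on why the int8 parameter above is built with requires_grad=False: autograd supports only floating-point and complex dtypes, so the default requires_grad=True would be rejected. A minimal sketch (not part of the diff):

# Integer-dtype parameters must opt out of gradients.
import torch

p = torch.nn.Parameter(torch.ones(2, 2, dtype=torch.int8), requires_grad=False)
assert p.dtype == torch.int8 and not p.requires_grad

try:
    torch.nn.Parameter(torch.ones(2, 2, dtype=torch.int8))  # requires_grad=True by default
except RuntimeError as err:
    print("expected failure:", err)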
def test_complex_constant(self) -> None:
class M(torch.nn.Module):
def forward(self, x):
@ -1166,7 +1186,8 @@ class TestDeserialize(TestCase):
# check ShapeEnv counters
shape_env = _get_shape_env_from_gm(loaded_ep.graph_module)
next_index = next(shape_env.unbacked_symint_counter)
next_index = shape_env.unbacked_symint_counter
shape_env.unbacked_symint_counter += 1
for symbol in bound:
self.assertTrue(symbol_is_type(symbol, SymT.UNBACKED_INT))
self.assertTrue(

View File

@ -42,6 +42,7 @@ from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
IS_SM90,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@ -1238,6 +1239,72 @@ class AOTInductorTestsTemplate:
dynamic_shapes=dynamic_shapes,
)
@unittest.skipIf(
TEST_WITH_ROCM or not IS_SM90,
"scaled_grouped_mm is only supported on SM90",
)
@skipIfXpu
def test_scaled_grouped_mm(self):
# Test torch._scaled_grouped_mm AOTI lowering
# cuda only
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight, scale_a, scale_b, offsets):
# x: [num_groups, batch, in_features] - FP8 inputs
# weight: [total_out_features, in_features] - FP8 weights (transposed)
# scale_a: [num_groups, batch] - per-row input scales
# scale_b: [total_out_features] - per-output-channel weight scales
# offsets: [num_groups] - cumulative output sizes
output = torch._scaled_grouped_mm(
x,
weight.t(),
scale_a=scale_a,
scale_b=scale_b,
offs=offsets,
use_fast_accum=True,
)
return output.half()
dtype = torch.float16
num_groups = 3
batch_size = 64
in_features = 128
out_features_list = [64, 128, 256] # Different output sizes for each group
device = GPU_TYPE
# Calculate offsets (cumulative output sizes)
offsets = torch.cumsum(torch.tensor(out_features_list), dim=0).to(
device, dtype=torch.int32
)
total_out_features = sum(out_features_list)
# Create FP8 input tensors - stacked for all groups
x_fp16 = torch.randn(
num_groups, batch_size, in_features, dtype=dtype, device=device
)
x_fp8 = x_fp16.to(torch.float8_e4m3fn)
# Create FP8 weight tensor - concatenated and transposed
weight_fp16 = torch.randn(
total_out_features, in_features, dtype=dtype, device=device
)
weight_fp8 = weight_fp16.to(torch.float8_e4m3fn)
# Create scales
scale_a = torch.ones(num_groups, batch_size, device=device, dtype=torch.float32)
scale_b = torch.ones(total_out_features, device=device, dtype=torch.float32)
self.check_model(
Model(),
(x_fp8, weight_fp8, scale_a, scale_b, offsets),
)
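
A hedged aside on the offs argument: per the comments in the test, it holds cumulative group boundaries, so group g owns rows offsets[g-1]:offsets[g] of the concatenated weight (with 0 for g=0). A CPU-only sketch of that bookkeeping, with no FP8 tensors and no _scaled_grouped_mm call:

# How cumulative offsets partition the concatenated weight into groups.
import torch

out_features_list = [64, 128, 256]
offsets = torch.cumsum(torch.tensor(out_features_list), dim=0)  # [64, 192, 448]

weight = torch.randn(int(offsets[-1]), 128)  # concatenated [total_out, in_features]
start = 0
for g, end in enumerate(offsets.tolist()):
    group_weight = weight[start:end]  # rows owned by group g
    assert group_weight.shape[0] == out_features_list[g]
    start = end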
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
@ -7265,14 +7332,16 @@ class AOTInductorLoggingTest(LoggingTestCase):
class TestAOTInductorConfig(TestCase):
def test_no_compile_standalone(self):
with config.patch({"aot_inductor.compile_standalone": False}):
with config.patch({"aot_inductor_mode.compile_standalone": False}):
result = maybe_aoti_standalone_config({})
self.assertEqual(result, {})
def test_compile_standalone_sets_package_cpp(self):
result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True})
result = maybe_aoti_standalone_config(
{"aot_inductor_mode.compile_standalone": True}
)
self.assertEqual(result["aot_inductor.package_cpp_only"], True)
self.assertEqual(result["aot_inductor.compile_standalone"], True)
self.assertEqual(result["aot_inductor_mode.compile_standalone"], True)
self.assertEqual(result["aot_inductor.embed_kernel_binary"], True)
self.assertEqual(
result["aot_inductor.emit_multi_arch_kernel"], not torch.version.hip
@ -7280,12 +7349,15 @@ class TestAOTInductorConfig(TestCase):
self.assertEqual(
result["aot_inductor.model_name_for_generated_files"], "aoti_model"
)
self.assertEqual(result["aot_inductor.dynamic_linkage"], False)
def test_compile_standalone_explicit_set(self):
patches = {
"aot_inductor.compile_standalone": True,
"aot_inductor_mode.compile_standalone": True,
"aot_inductor.package_cpp_only": True,
"aot_inductor.embed_kernel_binary": True,
"aot_inductor.dynamic_linkage": False,
"aot_inductor.link_libtorch": False,
"aot_inductor.emit_multi_arch_kernel": not torch.version.hip,
"aot_inductor.model_name_for_generated_files": "aoti_model",
}
@ -7294,7 +7366,7 @@ class TestAOTInductorConfig(TestCase):
def test_compile_standalone_package_cpp_false_raises(self):
patches = {
"aot_inductor.compile_standalone": True,
"aot_inductor_mode.compile_standalone": True,
"aot_inductor.package_cpp_only": False,
}
with self.assertRaises(RuntimeError):
@ -7302,7 +7374,7 @@ class TestAOTInductorConfig(TestCase):
with config.patch({"aot_inductor.package_cpp_only": False}):
patches = {
"aot_inductor.compile_standalone": True,
"aot_inductor_mode.compile_standalone": True,
}
with self.assertRaises(RuntimeError):
maybe_aoti_standalone_config(patches)

View File

@ -393,7 +393,7 @@ class TestAOTInductorPackage(TestCase):
# Test compilation when no name is passed in
options = {
"aot_inductor.compile_standalone": True,
"aot_inductor_mode.compile_standalone": True,
}
with (
tempfile.TemporaryDirectory() as tmp_dir,
@ -407,7 +407,7 @@ class TestAOTInductorPackage(TestCase):
# Test compilation when model name is passed in
options = {
"aot_inductor.compile_standalone": True,
"aot_inductor_mode.compile_standalone": True,
"aot_inductor.model_name_for_generated_files": "linear",
}
with (
@ -422,7 +422,7 @@ class TestAOTInductorPackage(TestCase):
# test invalid model name
options = {
"aot_inductor.compile_standalone": True,
"aot_inductor_mode.compile_standalone": True,
"aot_inductor.model_name_for_generated_files": "linear/linear",
}
with self.assertRaisesRegex(Exception, "Invalid AOTI model name"):
@ -448,7 +448,7 @@ class TestAOTInductorPackage(TestCase):
# Test compilation when model name is passed in
options = {
"aot_inductor.compile_standalone": True,
"aot_inductor_mode.compile_standalone": True,
"aot_inductor.model_name_for_generated_files": "cos",
}
with (

View File

@ -0,0 +1,69 @@
# Owner(s): ["module: inductor"]
import tempfile
import unittest
import zipfile
import torch
import torch._inductor.config
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import IS_CI
from torch.testing._internal.inductor_utils import HAS_GPU, requires_gpu
class Simple(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(16, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
class TestAOTInductorWindowsCrossCompilation(TestCase):
@requires_gpu()
def test_simple_so(self):
if IS_CI:
raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
# TODO: enable in CI
with torch.no_grad():
device = "cuda"
model = Simple().to(device=device)
example_inputs = (torch.randn(8, 10, device=device),)
batch_dim = torch.export.Dim("batch", min=1, max=1024)
exported = torch.export.export(
model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
)
package_path = torch._inductor.aoti_compile_and_package(
exported,
inductor_configs={
"aot_inductor.model_name_for_generated_files": "model",
"aot_inductor.cross_target_platform": "windows",
"aot_inductor.link_libtorch": False,
"aot_inductor.aoti_shim_library": "executorch",
# no fallback ops
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON,CPP",
"max_autotune_conv_backends": "TRITON,CPP",
# simplify things for now
"aot_inductor.precompile_headers": False,
},
)
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(package_path, "r") as zf:
zf.extractall(tmpdir)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU:
run_tests(needs="filelock")
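
One small, hedged sketch: the .pt2 package produced by aoti_compile_and_package is a zip archive (the test above opens it with zipfile), so its entries can be listed without extracting; entry names vary by build and none are assumed here:

# List the entries of a .pt2 package without extracting it.
import zipfile


def list_pt2_entries(package_path: str) -> list[str]:
    with zipfile.ZipFile(package_path, "r") as zf:
        return zf.namelist()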

View File

@ -0,0 +1,346 @@
# Owner(s): ["module: inductor"]
import operator
import torch
import torch.fx as fx
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
from torch.testing._internal.common_utils import TestCase
class TestAugmentedGraphHelper(TestCase):
"""Test suite for AugmentedGraphHelper dependency and merge management."""
def setUp(self):
"""Create a simple graph structure for testing."""
# Create a torch.fx.Graph with multiple nodes
self.graph = fx.Graph()
# Create placeholder nodes (inputs)
self.x = self.graph.placeholder("x")
self.y = self.graph.placeholder("y")
# Create computation nodes with specific names for easy reference
self.node_a = self.graph.call_function(
torch.add, args=(self.x, self.y), name="A"
)
self.node_b = self.graph.call_function(
torch.mul, args=(self.node_a, self.x), name="B"
)
self.node_c = self.graph.call_function(
torch.sub, args=(self.node_a, self.y), name="C"
)
self.node_d = self.graph.call_function(
torch.div, args=(self.node_b, self.node_c), name="D"
)
self.node_e = self.graph.call_function(
operator.neg, args=(self.node_d,), name="E"
)
self.node_f = self.graph.call_function(torch.abs, args=(self.node_e,), name="F")
self.node_g = self.graph.call_function(
torch.relu, args=(self.node_f,), name="G"
)
self.node_h = self.graph.call_function(
torch.sigmoid, args=(self.node_g,), name="H"
)
# Create output
self.graph.output(self.node_h)
# Create a mapping of nodes by name for easier access in tests
self.nodes = {}
for node in self.graph.nodes:
if hasattr(node, "name") and node.name in [
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
]:
self.nodes[node.name] = node
# Get all nodes and create tracker
self.all_nodes = list(self.graph.nodes)
self.tracker = AugmentedGraphHelper(self.graph)
def get_deps(self, node):
"""Helper to get dependencies for a node."""
return list(getattr(node, "args", []))
# ========== Basic Functionality Tests ==========
def test_initial_state(self):
"""Test that nodes start as singletons."""
for node in self.all_nodes:
merge_set = self.tracker.merge_sets[node]
self.assertEqual(merge_set, {node})
self.assertEqual(len(merge_set), 1)
def test_simple_merge(self):
"""Test merging two nodes."""
node_a = self.nodes["A"]
node_b = self.nodes["B"]
self.merge_nodes(self.tracker, [node_a, node_b])
# Both should be in same merge set
self.assertEqual(self.tracker.merge_sets[node_a], {node_a, node_b})
self.assertEqual(self.tracker.merge_sets[node_b], {node_a, node_b})
self.assertEqual(
self.tracker.merge_sets[node_a], self.tracker.merge_sets[node_b]
)
def test_transitive_merge(self):
"""Test merging already merged nodes."""
node_a = self.nodes["A"]
node_b = self.nodes["B"]
node_c = self.nodes["C"]
node_d = self.nodes["D"]
# Merge A-B and C-D separately
for node in node_b, node_c, node_d:
self.tracker.merge_to_set(node_a, node)
expected_set = {node_a, node_b, node_c, node_d}
for node in [node_a, node_b, node_c, node_d]:
self.assertEqual(self.tracker.merge_sets[node], expected_set)
def merge_nodes(self, tracker, nodes):
for n in nodes[1:]:
tracker.merge_to_set(nodes[0], n)
def test_unmerge_node(self):
"""Test removing a node from its merge set."""
node_a = self.nodes["A"]
node_b = self.nodes["B"]
node_c = self.nodes["C"]
# Merge all three
self.merge_nodes(self.tracker, [node_a, node_b, node_c])
self.assertEqual(len(self.tracker.merge_sets[node_a]), 3)
# Unmerge B
self.tracker.unmerge_node(node_b)
# B should be singleton
self.assertEqual(self.tracker.merge_sets[node_b], {node_b})
# A and C should still be together
self.assertEqual(self.tracker.merge_sets[node_a], {node_a, node_c})
self.assertEqual(self.tracker.merge_sets[node_c], {node_a, node_c})
def test_unmerge_from_singleton(self):
"""Test unmerging a node that's already singleton."""
node_a = self.nodes["A"]
# Should be no-op
self.tracker.unmerge_node(node_a)
self.assertEqual(self.tracker.merge_sets[node_a], {node_a})
# ========== Dependency Propagation Tests ==========
def test_merged_deps_collection(self):
"""Test that dependencies are collected from all merged nodes."""
node_a = self.nodes["A"]
node_b = self.nodes["B"]
node_c = self.nodes["C"]
# B already depends on A (and x) from graph construction
# C already depends on A (and y) from graph construction
# Merge B and C
self.merge_nodes(self.tracker, [node_b, node_c])
# Get merged deps for B - should include deps from both B and C
deps = self.tracker.get_merged_deps(node_b)
# Should include all dependencies from both nodes
self.assertIn(node_a, deps) # From both B and C
self.assertIn(self.x, deps) # From B
self.assertIn(self.y, deps) # From C
def test_extra_deps_with_merge(self):
"""Test extra dependencies work correctly with merged nodes."""
node_a = self.nodes["A"]
node_b = self.nodes["B"]
node_c = self.nodes["C"]
node_d = self.nodes["D"]
# Add extra dep from A to C
self.tracker.add_extra_dep(n=node_a, dep=node_c)
# Merge A and B
self.merge_nodes(self.tracker, [node_a, node_b])
# Add extra dep from D to the merged node (via B)
self.tracker.add_extra_dep(n=node_d, dep=node_b)
# D should depend on B through extra deps
deps = self.tracker.get_merged_deps(node_d)
self.assertIn(node_b, deps)
# A should still have its dep on C
deps = self.tracker.get_merged_deps(node_a)
self.assertIn(node_c, deps)
# ========== Path Finding Tests ==========
def test_has_path_direct(self):
"""Test path finding for direct dependencies."""
# In our graph: B depends on A
node_a = self.nodes["A"]
node_b = self.nodes["B"]
self.assertTrue(self.tracker.has_path(node_a, node_b))
self.assertFalse(self.tracker.has_path(node_b, node_a))
def test_has_path_transitive(self):
"""Test path finding through multiple nodes."""
# In our graph: A -> B -> D and A -> C -> D -> E
node_a = self.nodes["A"]
node_e = self.nodes["E"]
self.assertTrue(self.tracker.has_path(node_a, node_e))
self.assertFalse(self.tracker.has_path(node_e, node_a))
def test_has_path_through_merge(self):
"""Test path finding when nodes are merged."""
# Create a new graph for this specific test
graph2 = fx.Graph()
x2 = graph2.placeholder("x")
a2 = graph2.call_function(torch.neg, args=(x2,), name="A2")
b2 = graph2.call_function(torch.abs, args=(a2,), name="B2")
c2 = graph2.call_function(torch.relu, args=(x2,), name="C2")
d2 = graph2.call_function(torch.sigmoid, args=(c2,), name="D2")
graph2.output(d2)
tracker2 = AugmentedGraphHelper(graph2)
# Initially no path from B2 to D2
self.assertFalse(tracker2.has_path(b2, d2))
# Merge B2 and C2
tracker2.merge_to_set(b2, c2)
# Now there should be a path B2/C2 -> D2
self.assertTrue(tracker2.has_path(b2, d2))
def test_has_path_with_extra_deps(self):
"""Test path finding with extra dependencies."""
graph2 = fx.Graph()
x2 = graph2.placeholder("x")
a2 = graph2.call_function(torch.neg, args=(x2,), name="A2")
b2 = graph2.call_function(torch.abs, args=(a2,), name="B2")
c2 = graph2.call_function(torch.relu, args=(x2,), name="C2")
d2 = graph2.call_function(torch.sigmoid, args=(c2,), name="D2")
graph2.output(d2)
tracker2 = AugmentedGraphHelper(graph2)
# Initially no path from B2 to D2
self.assertFalse(tracker2.has_path(b2, d2))
tracker2.add_extra_dep(n=c2, dep=b2)
# Now there should be a path B2/C2 -> D2
self.assertTrue(tracker2.has_path(b2, d2))
# ========== Cycle Detection Tests ==========
def test_no_cycle_in_dag(self):
"""Test that DAG has no cycles."""
# Our original graph is a DAG, should have no cycles
self.assertFalse(self.tracker.has_cycle())
def test_simple_cycle_detection(self):
"""Test detection of simple cycle."""
# Create a graph with a cycle
graph3 = fx.Graph()
x3 = graph3.placeholder("x")
# We can't create true cycles in fx.Graph directly,
# but we can simulate with extra_deps
a3 = graph3.call_function(torch.neg, args=(x3,))
b3 = graph3.call_function(torch.abs, args=(a3,))
c3 = graph3.call_function(torch.relu, args=(b3,))
graph3.output(c3)
tracker3 = AugmentedGraphHelper(graph3)
self.assertFalse(tracker3.has_cycle())
# Add extra dep to create cycle: a3 -> c3
tracker3.add_extra_dep(n=a3, dep=c3)
self.assertTrue(tracker3.has_cycle())
def test_cycle_through_merge(self):
"""Test that merging can create cycles."""
# Create specific graph for this test
graph4 = fx.Graph()
x4 = graph4.placeholder("x")
a4 = graph4.call_function(torch.neg, args=(x4,))
b4 = graph4.call_function(torch.abs, args=(a4,))
c4 = graph4.call_function(torch.relu, args=(x4,))
d4 = graph4.call_function(torch.sigmoid, args=(c4,))
graph4.output(d4)
tracker4 = AugmentedGraphHelper(graph4)
# Add extra dep d4 -> a4
tracker4.add_extra_dep(n=a4, dep=d4)
# Now: a4 -> b4, c4 -> d4 -> a4
# Merging b4 and c4 would create cycle
tracker4.merge_to_set(b4, c4)
self.assertTrue(tracker4.has_cycle())
def test_cycle_with_extra_deps(self):
"""Test cycle detection with extra dependencies."""
node_a = self.nodes["A"]
node_b = self.nodes["B"]
# B already depends on A naturally
# Add reverse dependency to create cycle
self.tracker.add_extra_dep(n=node_a, dep=node_b)
self.assertTrue(self.tracker.has_cycle())
def test_multiple_merge_unmerge(self):
"""Test sequence of merge and unmerge operations."""
nodes = [self.nodes[c] for c in ["A", "B", "C", "D", "E"]]
# Merge A, B, C
self.merge_nodes(self.tracker, nodes[:3])
self.assertEqual(len(self.tracker.merge_sets[nodes[0]]), 3)
# Merge D, E
self.merge_nodes(self.tracker, nodes[3:5])
self.assertEqual(len(self.tracker.merge_sets[nodes[3]]), 2)
# Merge the two groups via B and D
try:
self.merge_nodes(self.tracker, [nodes[1], nodes[3]])
thrown = False
except AssertionError:
thrown = True
self.assertTrue(thrown)
# Unmerge C
self.tracker.unmerge_node(nodes[2])
self.assertEqual(len(self.tracker.merge_sets[nodes[0]]), 2)
self.assertEqual(self.tracker.merge_sets[nodes[2]], {nodes[2]})
# Unmerge A
self.tracker.unmerge_node(nodes[0])
self.assertEqual(self.tracker.merge_sets[nodes[0]], {nodes[0]})
self.assertEqual(len(self.tracker.merge_sets[nodes[1]]), 1)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
run_tests()
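
For orientation, a minimal sketch of the AugmentedGraphHelper calls exercised above (merge_to_set, has_path, add_extra_dep, has_cycle) on a tiny fx.Graph; the import path and the behavioral claims mirror test_has_path_through_merge and test_cycle_through_merge, and nothing beyond those tests is assumed:

# Merge two nodes into one scheduling unit, then manufacture a cycle.
import torch
import torch.fx as fx
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper

graph = fx.Graph()
x = graph.placeholder("x")
a = graph.call_function(torch.neg, args=(x,))
b = graph.call_function(torch.abs, args=(a,))
c = graph.call_function(torch.relu, args=(x,))
d = graph.call_function(torch.sigmoid, args=(c,))
graph.output(d)

helper = AugmentedGraphHelper(graph)
assert not helper.has_path(b, d)   # b and d live on separate chains
helper.merge_to_set(b, c)          # treat b and c as one unit
assert helper.has_path(b, d)       # now b reaches d through the merge
helper.add_extra_dep(n=a, dep=d)   # extra edge: a depends on d ...
assert helper.has_cycle()          # ... closing a cycle with a -> b/c -> d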

Some files were not shown because too many files have changed in this diff.