Compare commits

..

144 Commits

Author SHA1 Message Date
75b8295868 Revert "Warn if AccumulateGrad stream does not match producer node stream (#165065)"
This reverts commit 12f742941d6aecb72c18d8e602f90ac9b4f00af0.

Reverted https://github.com/pytorch/pytorch/pull/165065 on behalf of https://github.com/clee2000 due to broke internal builds D85273204 usages of TORCH_API void add need to be updated? ([comment](https://github.com/pytorch/pytorch/pull/165065#issuecomment-3438061854))
2025-10-23 17:02:49 +00:00
defb6a80d8 Enable torch.Generator to support pytorch/xla generator implementation (#161369)
Currently, the implementation of `torch.Generator` only support "cpu" and "cuda" device type.  https://github.com/pytorch/pytorch/blob/main/torch/csrc/Generator.cpp#L55-L61

This change enables `torch.Generator` to support more device type by allowing any device backend to register their own generator factory through a Generator Registry. This is similar to what "DeviceGuardImpl registry" does today.

# Key Changes:

## New registry API:

* Added GeneratorRegistry.h and GeneratorRegistry.cpp in c10/core/impl.
* API supports registerGenerator(DeviceType, GeneratorFactory), unregisterGenerator(DeviceType), and getGeneratorFactory(DeviceType).
* Uses c10::DeviceType as the key and stores a factory function returning c10::intrusive_ptr<c10::GeneratorImpl>.

## Python/C++ integration:

* The registry is consulted in the torch.Generator constructor path for non-CPU/CUDA devices.
* If a factory is registered for the requested device, it constructs the appropriate generator; otherwise, raises an error.

## Backend extensibility:

* Out-of-tree backends (e.g., torch_xla, torch-directml, torch_npu) can now register their custom generator implementation at module load via a static registrar object.
Example usage:
```
C++
namespace {
  struct Registrar {
    Registrar() {
      at::detail::registerGenerator(c10::DeviceType::XLA, &CreateXlaGenerator);
    }
  } registrar_instance;
}
```

This allows torch.Generator(device='xla') to return an XlaGeneratorImpl when the torch_xla extension is imported.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161369
Approved by: https://github.com/FFFrog, https://github.com/qihqi, https://github.com/albanD
2025-10-23 16:49:28 +00:00
f8fccb1e48 [Code Clean] Clean asserts in torch/optim. (#165629)
Replaces 50 assert statements across 15 files in torch.optim with explicit  if-checks raising AssertionError to prevent assertions from being disabled with Python -O flag.

fix partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165629
Approved by: https://github.com/albanD
2025-10-23 15:56:29 +00:00
5aac4cfce4 Use is rather than == to work around slow enum comparion in _ops.py (#165936)
This shows up (under _are_we_tracing) in DTensor dispatch. I have some work in flight to speed up enum comparison in pybind11, but `is` is just much faster and easy to use.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165936
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2025-10-23 15:01:55 +00:00
baf91bbbfc Revert "[inductor][choices] lookup table choices 1/3 (#164978)"
This reverts commit ab9e466928e7a37844c4f2a8bf90c76d16ac3c34.

Reverted https://github.com/pytorch/pytorch/pull/164978 on behalf of https://github.com/malfet due to Looks like it broke slow tests, see cbcb4f7768/1 ([comment](https://github.com/pytorch/pytorch/pull/164978#issuecomment-3437424559))
2025-10-23 14:47:07 +00:00
cbcb4f7768 [pytorch][torchelastic] Duplicate stdout and stderr and apply custom filter in torchrun (#160712)
Summary:
Part of an effort to extract some important error logs (e.g. [#157996](https://github.com/pytorch/pytorch/pull/157996)) that was `tee`'ed to `stdout` and `stderr`.

The general idea is to:

- Duplicate the `tee`s on `stdout` and `stderr` to a separate file, `filtered_stdout.log` and `filtered_stderr.log`, respectively.
- In these files, as its name suggests, only log lines matching a customizable filter.
- Later on in another PR, append the contents of these files to the reply file.

Outline of changes in this PR:

- Enhance `TailLog` to be able to 1) stream to a file, and 2) only write when the line matches the passed filter.
- Add `filtered_stdout` and `filtered_stderr` to `LogsDest` and have `LogsSpecs` `reify` them.
- In `start_processes()` and `PContext`, add params `duplicate_stdout_filters` and `duplicate_stderr_filters` to filter and write the duplicated stream to the files above. When no filters are passed in, no duplicated streams are created.

Test Plan:
```
$ buck test 'fbcode//mode/opt' caffe2/test/distributed/elastic/multiprocessing:api_test
```
```
Buck UI: https://www.internalfb.com/buck2/f5c6b7da-217d-4a0b-872a-c7cd3d05587f
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4222124951617688
Network: Up: 398B  Down: 44MiB  (reSessionID-a489a961-b602-45be-b851-3490ebb7a26a)
Analyzing targets. Remaining     0/200
Executing actions. Remaining     0/12856                                                                                                                                        0.1s exec time total
Command: test.     Finished 1 local
Time elapsed: 17:37.9s
Tests finished: Pass 52. Fail 0. Fatal 0. Skip 0. Build failure 0
```
```
$ buck test 'fbcode//mode/opt' caffe2/test/distributed/elastic/multiprocessing:tail_log_test
```
```
Buck UI: https://www.internalfb.com/buck2/d6d5c1c1-db98-4d9c-b608-7ba6fbb5e3ee
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13510798985149262
Network: Up: 94KiB  Down: 417MiB  (reSessionID-27b46fba-d31c-4c04-8ede-a506454e6922)
Analyzing targets. Remaining     0/3                                                                                                                                            536 actions, 555 artifacts declared
Executing actions. Remaining     0/186                                                                                                                                          1:05.5s exec time total
Command: test.     Finished 7 local, 1 remote, 115 cache (93% hit)                                                                                                              37.0s exec time cached (56%)
Time elapsed: 1:11.5s
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Rollback Plan:

Differential Revision: D80188995

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160712
Approved by: https://github.com/fduwjj
2025-10-23 14:22:21 +00:00
2b93d5b450 [FlexAttention][CUDA] Add flex configs for Blackwell (#165760)
This PR fixes ULFs on `max_autotune` mode for high head-dim sizes on B200. Closes https://github.com/pytorch/torchtitan/issues/1791

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165760
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
2025-10-23 10:22:06 +00:00
6b7cd48e7e [ROCm] Deserialize loads in planer sum portion of reduce() of norm. (#165927)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165927
Approved by: https://github.com/jeffdaily
2025-10-23 09:45:01 +00:00
bf5aa9e42e [dynamo] Remove ID guard on method object (#166096)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166096
Approved by: https://github.com/tugsbayasgalan
2025-10-23 06:22:49 +00:00
b1eb6dede5 [vision hash update] update the pinned vision hash (#166046)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vision hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166046
Approved by: https://github.com/pytorchbot
2025-10-23 04:27:44 +00:00
673060beae [inductor] turn Inductor deterministic mode on with torch.use_deterministic_algorithms (#165950)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165950
Approved by: https://github.com/v0i0, https://github.com/eellison
2025-10-23 02:48:42 +00:00
2e8e9a59a8 Revert "[dynamo][easy] Support torch.accelerator.current_accelerator (#165734)" (#166094)
This reverts commit c18ddfc5721dd91bf29c769e850a99c4fdb6f380.

Discovers some latent issues causing internal failures. Will fix those issues first and resend the PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166094
Approved by: https://github.com/bdhirsh
2025-10-23 01:24:46 +00:00
fb277a5916 Enable new tracer by default (#165332)
Differential Revision: [D84516080](https://our.internmc.facebook.com/intern/diff/D84516080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165332
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #165582, #163580
2025-10-23 00:40:29 +00:00
73fa0d0c63 test for #165446 (#165853)
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165853
Approved by: https://github.com/drisspg
2025-10-23 00:08:18 +00:00
36c21cc84e state dict staging fixes (#166025)
Summary:
This PR contains three changes -
1. We are losing non-blocking flag value and defaulting to False during the deep_copy. This is introducing a cuda synchronize after each tensor. This is slowing the staging.
2. Adding the capability to skip pinning for scalar tensors to reduce initial staging buffer creation cost. Setting it by default to 65 to avoid pinning small tensors.
3. Tensor share storage but each storage needs to be processed only once in the deep_copy with offloading logic. so, use the memoization table to cache storage ids.

Test Plan:
1. Verified non-blocking copies via kineto profile.
2. ran A/B jobs old and new staging with fixes such that it crashes after ever 2 checkpoints and restarts for several hours and compared loss curves and they are exactly identical.
3. tests

Differential Revision: D85180484

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166025
Approved by: https://github.com/pradeepfn
2025-10-22 23:32:41 +00:00
0b68814b44 Forward fix to D80948073 (#166023)
Summary:
realize tensor before accessing layout.

Differential Revision: D85172267

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166023
Approved by: https://github.com/laithsakka
2025-10-22 22:00:53 +00:00
e64a814ae7 [CUDA] Add experimental green context support for SM carveout (#159104)
Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here...

Built on top of @drisspg 's branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104
Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/kwen2501

Co-authored-by: drisspg <drisspguessous@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-10-22 21:38:52 +00:00
0b58d87aec [Submodule] Bump FBGEMM to latest (#165544)
Summary:

* FBGEMM submodule updated to main
* CMake updated to reflect necessary changes
* Notably pulls in NVFP4 grouped gemm kernels

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165544
Approved by: https://github.com/cyyever, https://github.com/jeffdaily
2025-10-22 20:57:15 +00:00
757975ad50 [export] Unified graph capture with fullgraph_capture. (#165562)
Summary:
_dynamo_graph_capture_for_export in the current form has the compability issue
with the main torch.compile() path despite we reuse fullgraph_capture as the
bytecode tracer. The reason is that we flip on many export specific flags
and even trace with a wrapped function which will cause divergence with
torch.compile() again.

This PR instead creates a new implementation of dynamo_graph_capture_for_export
which 100% relies on fullgraph capture and post-processing on CaptureOutput so
that we can avoid the inversion of phases in PT2 compiler stack.

This also benefits precompile workflow since we want to have a feature that
only accepts pytree inputs and ship portable python wrappers in package. In
other words, I think the code here is sharable between export and precompile
for exporting portable graph.

Test Plan:
===================================================================== test session starts =====================================================================
platform linux -- Python 3.12.11, pytest-7.3.2, pluggy-1.6.0
rootdir: /data/users/zhxchen17/pytorch
configfile: pytest.ini
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, subtests-0.13.1, rerunfailures-14.0, flakefinder-1.1.0, cpp-2.3.0, anyio-4.10.0
collected 9 items
Running 9 items in this shard

test/distributed/tensor/test_dtensor_export.py ........x                                                                                                [100%]

================================================================ 8 passed, 1 xfailed in 11.42s ================================================================

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165562
Approved by: https://github.com/tugsbayasgalan
2025-10-22 20:44:55 +00:00
291712026b [dynamo][user_defined] Replace UserFunctionVariable with VariableTracker build (#165706)
Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165706
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
2025-10-22 19:28:27 +00:00
3e77a2b478 [PyTorch] Improve aarch64 performance of bfloat16 ops (#166028)
Summary:
PR allows compiler to better optimize some bfloat16-based operations, when ran on NEON

Benchmarks show measurable improvements:

Before:
bfloat16 add: 250.503us
bfloat16 sub: 245.674us
bfloat16 neg: 113.945us

After:
bfloat16 add: 203.862us ---> 23% higher throughput
bfloat16 sub: 201.526us ---> 22% higher throughput
bfloat16 neg: 74.986us ---> 52% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

 binary_test.py has been updated, to run bfloat16 benchmarks using basic arithmetic functions

Differential Revision: D85186786

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166028
Approved by: https://github.com/Skylion007
2025-10-22 19:25:33 +00:00
82ef1b5db3 [DebugMode] refactor logs into _DebugCalls (#165376)
Refactors `DebugMode.operators` to be more structured `_DebugCall` objects, instead of (op, args, kwargs, call_depth) tuples. Useful going forward for attaching more information (e.g. output info, call metadata).

Is BC-breaking, but attaches an `__iter__` method for `_OpCall` and `_RedistributeCall` so previous tuple usage is accessible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165376
Approved by: https://github.com/yushangdi
2025-10-22 19:01:56 +00:00
5f370f5c42 inductor_provenance: Correctly handle null provenance (#166019)
Summary:
If the provenance is null, we're getting crashes of the form
```
[trainers0]:E1021 10:51:31.990525  2752 PythonApi.h:87] Exception caught in
GeneratedDynamoCompileLoggerConfig: <class
'dsi.logger.py3.GeneratedDynamoCompile.LogEntry.thrift_types.GeneratedDynamoCompileLogEntryThriftBase'>:
error initializing Thrift struct field 'inductor_provenance_thrift_safe':
Cannot create internal string data representation. Expected type <class 'str'>,
got: <class 'NoneType'>.
```

Also fixed a type signature that wasn't being enforced. (It's still not
enforced, but it's accurate).

Test Plan:
Added a new test which reproduces the logging issue

Differential Revision: D85173596

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166019
Approved by: https://github.com/ppanchalia, https://github.com/yushangdi
2025-10-22 18:21:57 +00:00
05b2e02cb4 Revert "[lint] workflow consistency linter to look at all files instead of just changed files (#165171)"
This reverts commit c746feb86a1459db5f6294730d1d72ed15f16dd3.

Reverted https://github.com/pytorch/pytorch/pull/165171 on behalf of https://github.com/clee2000 due to broke lint [GH job link](https://github.com/pytorch/pytorch/actions/runs/18723760085/job/53402955955) [HUD commit link](c746feb86a) ([comment](https://github.com/pytorch/pytorch/pull/165171#issuecomment-3433501457))
2025-10-22 17:47:29 +00:00
12f742941d Warn if AccumulateGrad stream does not match producer node stream (#165065)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165065
Approved by: https://github.com/ngimel
2025-10-22 17:33:27 +00:00
35180fafee Allow GraphPickler to pickle graph modules containing AOTCompiled subgraphs (#165844)
This PR allows GraphPickler to pickle aot_eager graph modules that have regional inductor bits in them, with a few exceptions:
- FlexAttentionBackward isn't marked cacheable, so those tests don't work immediately since we're not sure how to serialize it. But it's safe to serialize/cache, so the next PR fixes those unit tests.
- It seems that when reloading a GraphPickled object, we don't recompile subgraphs. Will investigate this in a future PR

All unit tests in test_regional_inductor are parameterized so that we try serializing and deserializing the returned graph module before returning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165844
Approved by: https://github.com/oulgen
ghstack dependencies: #165843
2025-10-22 17:03:49 +00:00
c746feb86a [lint] workflow consistency linter to look at all files instead of just changed files (#165171)
As in title

If you change only one workflow file, lintrunner (default arg, also the one in CI since it only inputs changed files) won't look at other files in the repo, but the sync-tag might come from those other files

This makes it so that it looks at all workflow files so it will catch those failures

Pros:
catches errors

Cons:
unusual behavior (getting around what lintrunner says the linter should run on)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165171
Approved by: https://github.com/malfet
2025-10-22 16:57:59 +00:00
c5f26db5bf fix #166057: add tmp ptr to avoid gcc internal compiler error (#165717)
Fixes #166057

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165717
Approved by: https://github.com/malfet
2025-10-22 16:38:26 +00:00
18e99b6d45 [dirsync] Switch to top-level xplat/third-party/pthreadpool (#165995)
Summary: `fbcode//xplat/third-party/pthreadpool:` just redirects to the xplat version. Switch to the real location

Test Plan: This should be a no-op, so CI?

Differential Revision: D83999534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165995
Approved by: https://github.com/bigfootjon, https://github.com/Skylion007
2025-10-22 16:18:23 +00:00
ab9e466928 [inductor][choices] lookup table choices 1/3 (#164978)
\# why

- enable users to control which choices get used on which inputs
- reduce lowering time, and pin kernel selection, by selecting
  them for the inputs

\# what

- a new InductorChoices subclass that implements a lookup table
- a README explaining the usage
- corresponding testing

- currently only supports templates that go through
  `V.choices.get_template_configs`

\# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -v
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164978
Approved by: https://github.com/PaulZhang12, https://github.com/eellison
2025-10-22 16:11:31 +00:00
af4ba78543 [scan x vmap] support scan in vmap (#165580)
This is required by the chunked_with_scan work where two nested vmap(vmap) with chunk sizes > 1 are invoked, which produces a scan-> vmap -> scan -> vmap chain and we need to handle the case of vmap(scan) and scan(vmap).

The way we handle vmap(scan) is to turn it into scan(vmap(combine_fn)). The idea being that the combine_fn no longer do the combine_fn for a single slice, it vmaps over the combine_fn and do multiple combine_fns in one step. We need to need know how combine_fn propagates the batched tensor and what are the batched dims of the output. For this purpose, we use restore_vmap to give us the out_dims information.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165580
Approved by: https://github.com/zou3519
ghstack dependencies: #165675
2025-10-22 09:46:00 +00:00
282f39a4bc [vmap][dynamo] use create_proxy instead of create_node in vmap increate nesting ctx manager (#165675)
create_node won't do the auto closure lifting, this cause problems when the context manager is used in a hop region. Switch to create_proxy instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165675
Approved by: https://github.com/zou3519, https://github.com/guilhermeleobas
2025-10-22 09:46:00 +00:00
a479769488 [dynamo] Clean up assert in dynamo [2/N] (#165745)
Extend from #165430
* #165903(Clean up for graph break)
* ->#165745
* #165430

One main refractor from the previous PR:
* For assertions like checking `len(args)` or `len(kwargs)`, using `raise_args_mismatch` instead of `raise_type_error_exc`

I am also considering moving `raise_type_error_exc` into `utils.py` for consistency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165745
Approved by: https://github.com/Lucaskabela
2025-10-22 07:12:37 +00:00
26c7375477 Remove the branch of IS_CUSPARSE11_AVAILABLE is False (#166048)
This PR removes the branch when `IS_CUSPARSE11_AVAILABLE` is 0. Note that the condition `ROCM_VERSION >= 60300` holds currently as the minimum supported ROCm is 6.3 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166048
Approved by: https://github.com/Skylion007
2025-10-22 07:10:11 +00:00
d01f15152c Move toUnderlying to headeronly (#165694)
As in the title. Required in upper PRs of this ghstack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165694
Approved by: https://github.com/janeyx99
2025-10-22 05:31:16 +00:00
4fae6968b1 Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405) (#166018)
This PR is created to replace the reverted PR https://github.com/pytorch/pytorch/pull/164405
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166018
Approved by: https://github.com/janeyx99
2025-10-22 05:16:58 +00:00
f9953e0f61 Enable PLC0414 on ruff (#165828)
This PR enables `PLC0414` that fixes redundant import aliases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165828
Approved by: https://github.com/albanD
2025-10-22 04:56:52 +00:00
34ed7a8f0d [ROCm] Skip test_blockwise_nvfp4_with_global_scale (#165968)
Disable the fp4 global_scale test till the feature is enabled on ROCm.

Fixes #166027.
Not really, but we're trading an issue for a test skip decorator since the test is parameterized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165968
Approved by: https://github.com/jeffdaily, https://github.com/drisspg
2025-10-22 04:23:05 +00:00
2fde10d914 [ROCm] fix test_allocator_backend (#166035)
Fixes #165872.

Forward fix PR #165298. hipify was causing some symbols to be replaced.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166035
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-22 03:46:23 +00:00
0a93295da0 Update doc (#166024)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166024
Approved by: https://github.com/yiming0416
2025-10-22 03:41:31 +00:00
4b898b51b9 [12/n][take2] : Remove fbandroid_compiler_flags platform args (#165916)
Summary: This diff removes the `fbandroid_compiler_flags` and merges its content with `compiler_flags` and wraps it in a android select. My first attempt at this got reverted - D84626885.

Test Plan:
CI and failing builds are now passing
```
buck2 build --target-universe fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore_redex_postprocessed_repack_resign @//fbandroid/mode/nosan @//fbandroid/mode/opt @//fbandroid/mode/milan_build_rdk @//fbandroid/mode/relr-relocations fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore_redex_postprocessed_repack_resign fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore_redex_genrule fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore-mobileconfig-definition-resource-gen fbsource//fbandroid/apps/wearable/system/healthservices:healthservices_target30_mosnative_xhdpi_arm64_release_debug_keystore
File changed: fbsource//tools/build_defs/fb_xplat_cxx_library.bzl
Buck UI: https://www.internalfb.com/buck2/509c0b7b-ada3-421a-8c32-2f1d3a7babdd
Network: Up: 1.3MiB  Down: 293MiB  (reSessionID-17f73b81-3c34-4c01-9f6c-2b4f3c8332e3)
Loading targets.   Remaining     0/1311                                                                                                                                                                                                292986 targets declared
Analyzing targets. Remaining     0/13515                                                                                                                                                                                               216715 actions, 359204 artifacts declared
Executing actions. Remaining     0/40415                                                                                                                                                                                               6:33.3s exec time total
Command: build.    Finished 40 local, 790 remote
Time elapsed: 32.0s
BUILD SUCCEEDED
```

Reviewed By: jaejunku

Differential Revision: D84868234

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165916
Approved by: https://github.com/malfet
2025-10-22 03:01:55 +00:00
550e3e6efb [dynamo] Fix MATCH_KEYS for dict pattern matching (#165956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165956
Approved by: https://github.com/guilhermeleobas, https://github.com/cyyever
2025-10-22 02:52:07 +00:00
715449ca76 [MPS] Fix parity between CPU and MPS on singular matrices in linalg.lu_factor (#165871)
Fixes #165870. Follow up from #165254.

This PR [a] removes the MPS specific version of `lu_factor` in favor of the version in BatchedLinearAlgebra.cpp which uses `lu_factor_ex`, and [b] updates `lu_factor_ex` error codes to match expectations.

When `lu_factor` was first implemented for MPS (#99269), it bypassed the implementation in BatchedLinearAlgebra.cpp since we did not have `lu_factor_ex`. Since #144651 implements `lu_factor_ex`, we can now remove the MPS specific wrapper.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165871
Approved by: https://github.com/kulinseth, https://github.com/albanD
2025-10-22 02:48:40 +00:00
84d8d06fc3 Fixes floating point exception in torch.nn.PixelShuffle (#163154)
Fixes #162251

**Previous Output:**
`Floating point exception (core dumped)`

**Now Output:**
`RuntimeError: upscale factor is too large, (upscale_factor}^2 overflowed: upscale_factor=545460846592`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163154
Approved by: https://github.com/cyyever, https://github.com/albanD
2025-10-22 02:22:16 +00:00
60992d98b2 [dynamo][remaining] Replace UserFunctionVariable with VariableTracker build (#165896)
Audit: To prevent future issues with functools.partial or callable objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165896
Approved by: https://github.com/Lucaskabela
2025-10-22 02:13:00 +00:00
59e015e3a1 Remove outdated CUB macros (#164656)
This PR removes `CUB_SUPPORTS_NV_BFLOAT16` and `CUB_SUPPORTS_FUTURE_VALUE` because they are always true on CUDA >=12 installations with its CUB version. Their branches are also removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164656
Approved by: https://github.com/albanD, https://github.com/eqy, https://github.com/jeffdaily
2025-10-22 02:02:50 +00:00
8904a5a7c9 Move allocation size config to AllocatorConfig for cross-allocator sharing (#159553)
# Motivation
Make CUDA and XPU share the same config and code. And allow the other backends to reuse them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159553
Approved by: https://github.com/albanD
ghstack dependencies: #160067
2025-10-22 01:48:56 +00:00
f5df9ca03a Fix creation of BINARY_SUBSCR in Python 3.14+ (#165864)
Python 3.14 replaced `BINARY_SUBSCR` by `BINARY_OP(opcode=BN_SUBSCR)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165864
Approved by: https://github.com/williamwen42
2025-10-22 01:43:03 +00:00
2998abd777 [Code Clean] Better error handling in torch/csrc/distributed (#165053)
Replace the runtime_error of the vallina C++ exceptions with TORCH_CEHCK
Including:

torch/csrc/distributed/*

fix partialy #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165053
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-22 01:40:36 +00:00
e13580e41c [AMD] Run int4_mm tests only for compatible arch (#165630)
Such tests should be skipped for rest including gfx1100(Navi3x)

Fixes for CI HUD for gfx1100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165630
Approved by: https://github.com/jeffdaily

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2025-10-22 01:38:55 +00:00
f3b8e15f20 [AMD][gfx1100] test_decompose_mem_bound_mm.py tolerance increase (#165625)
test_decompose_mem_bound_mm.py tolerance increase for navi3x(gfx11x)

(cherry picked from commit 03c7da05f61890bbf5ae41e23c8df6d5f6805bac) from

Fixes for CI HUD for gfx1100

Signed-off-by: Artem Kuzmitckii <artem.kuzmitckii@amd.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165625
Approved by: https://github.com/jeffdaily

Co-authored-by: iupaikov-amd <Iurii.Paikov@amd.com>
Co-authored-by: Dmitry Nikolaev <139769634+dnikolaev-amd@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-22 01:38:48 +00:00
5211f4c108 [MPS] Fix SDPA fp16 overflow (#165961)
Do not cast intermediate result back to lower precision data data until
softmax is finished, otherwise it might produce NaN

Adjust the test to use 256 as filler value rather than 64

Fixes https://github.com/pytorch/pytorch/issues/160841
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165961
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #165960
2025-10-22 01:29:42 +00:00
ad9027b80d [BE] Remove unused 'rows' parameter from spmm_bmm_coo_rows_grouped (#166041)
To fix following compilation warning
```
Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/sparse/mps/kernels/Mul.metal:76:14: warning: unused variable 'B' [-Wunused-variable]
  const uint B = dims.x;
             ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/sparse/mps/kernels/Mul.metal:65:26: warning: unused parameter 'rows' [-Wunused-parameter]
    device const long*   rows      [[buffer(0)]],
                         ^
2 warnings generated.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166041
Approved by: https://github.com/Skylion007
2025-10-22 00:59:41 +00:00
a1005427bf [xpu] Support high stream for ProcessGroupXCCL (#163049)
Add high priority stream support for ProcessGroupXCCL. Just like CUDA, XPU streams also support execution with higher priority compared to other streams. Implementation in https://github.com/intel/torch-xpu-ops/pull/1715, add register here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163049
Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
2025-10-22 00:54:25 +00:00
35153d0846 Simplify c10::guts::apply (#164566)
There is only one call site of `c10::guts::apply` that can be replaced by `:std::apply` except for ROCm. This PR therefore simplifies the implementation of `c10::guts::apply`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164566
Approved by: https://github.com/Aidyn-A, https://github.com/albanD
2025-10-22 00:47:43 +00:00
7773a22cdb Revert "[AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)"
This reverts commit 4be1e3bf926b8e798fede3be6a3051560e9e00c5.

Reverted https://github.com/pytorch/pytorch/pull/165221 on behalf of https://github.com/clee2000 due to I think this broke test_openreg [GH job link](https://github.com/pytorch/pytorch/actions/runs/18698271058/job/53322459496) [HUD commit link](4be1e3bf92) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165221#issuecomment-3430012693))
2025-10-22 00:26:57 +00:00
7cb467a169 [CI] Update ONNX CI packages to latest (#165883)
This PR updates ONNX related packages to their latest versions used in CI environments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165883
Approved by: https://github.com/justinchuby, https://github.com/albanD
2025-10-22 00:25:35 +00:00
12aac12b8d [Code Clean] Replace std::runtime_error with TORCH_CHECK (#165209)
Including:
1. `aten/src/ATen/core`
2. `c10/core`

Fixes part of #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165209
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-22 00:05:22 +00:00
2b748d0a56 Add operator name to output json (#164583)
The benchmarks, model_name on dashboard needs to be grouped with operator_name. This PR passed an additional argument operator_name to the json for grouping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164583
Approved by: https://github.com/yangw-dev
2025-10-21 23:58:39 +00:00
16745a882a [aoti][win] add support for a list of shim libraries (#165914)
As title, support passing in a list of shim libraries when cross compiling artifacts

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165914
Approved by: https://github.com/desertfire
2025-10-21 22:55:17 +00:00
8daef35cf1 Revert "[Code Clean] Clean asserts in torch/ao/quantization (root, quantizer, backend_config) (#165433)"
This reverts commit df64c0c4649984093bd1a46f1e9c658c72018200.

Reverted https://github.com/pytorch/pytorch/pull/165433 on behalf of https://github.com/clee2000 due to I think this broke some quantization tests ([comment](https://github.com/pytorch/pytorch/pull/165433#issuecomment-3429741770))
2025-10-21 22:10:19 +00:00
51319ca090 [Pytorch] Add NEON Vectorized<uint> family of translation layers (#165690)
Summary:
Adding NEON specializations of Vectorized<T> for uint8, uint16, uint32 and uint64.

Correcness has been checked using test_ops.py

operator_benchmark_test.py, which uses the PyTorch API, shows significant enhancements in some operations:

Before:

uint8 mul: 1460.751us
uint8 add: 2359.565us
uint8 lsl: 2151.206us

After:

uint8 mul: 194.792us ---> 650% higher throughput
uint8 add: 195.609us ---> 1100% higher throughput
uint8 lsl: 186.249us ---> 1055% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Reviewed By: mcfi

Differential Revision: D84770153

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165690
Approved by: https://github.com/malfet
2025-10-21 21:46:55 +00:00
d311a3d1dc A temporary fix to autotune out of range and related IMA (#165943)
Summary:
Autotune issue during lowering w/ AOTI:
```
setStorage: sizes [1536, 32, 8192], strides [8192, 8192, 1], storage offset 0, and itemsize 2 requiring a storage size of 25673728 are out of bounds for storage of size 25362432
```
Need a hack to create new base tensor with sufficient storage

Test Plan: Finally be able to see the e2e test passes on CI. See the detailed Test Plan in D83520844

Differential Revision: D84872792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165943
Approved by: https://github.com/laithsakka
2025-10-21 21:40:20 +00:00
04adfe5ba9 Make Backend::setGroupUid virtual (#165957)
As titled, so that we may customize this function in custom backends

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165957
Approved by: https://github.com/d4l3k
2025-10-21 21:33:24 +00:00
4be1e3bf92 [AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)
This PR refactors the autocast context manager in autocast_mode.py to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now, a single mapping device_supported_dtypes is used to associate device types with their supported dtypes, and the validation logic is unified.

**The former PR #163446 was merged but reverted due to failed CI test on `openreg` related tests.**

This RR additionally slightly modified some test assertions for passing the CI tests. CI failed due to assertion for the exactly same error message. For example:
```
File "/var/lib/jenkins/workspace/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py", line 9, in test_autocast_with_unsupported_type
    with self.assertWarnsRegex(
        AssertionError: "In openreg autocast, but the target dtype torch.float32 is not supported." does not match "In openreg autocast, but the target dtype is not supported. Disabling autocast."
```

Sorry for the inconvenience again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165221
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 21:32:12 +00:00
e7592f4005 [CI] Move the periodic debug tests to newer runner (#165158)
Previously g3 = NVIDIA Tesla M60
Now g6 = NVIDIA L4
Also change cuda arch list accordingly

Pros:
More memory, newer GPU

Cons:
That was one of the few remaining tests on g3 runners, so we probably lost coverage?

We can probably run more tests in parallel now but I'm not going to do that here

Disabled a bunch of sparse tests and nestedtensor tests that were previously skipped due to not having sufficient hardware?  They are now failing with
```
Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3293, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3292, in wrapper
    with policy():
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2532, in __enter__
    self.beforeStreams[-1].synchronize()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/cuda/streams.py", line 105, in synchronize
    super().synchronize()
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from stream_synchronize at /var/lib/jenkins/workspace/c10/cuda/CUDAFunctions.h:120 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, unsigned int, bool) [clone .cold] from CUDAException.cpp:0
#7 THCPStream_synchronize(_object*, _object*) from Stream.cpp:0
#8 cfunction_vectorcall_NOARGS from /usr/local/src/conda/python-3.10.14/Objects/methodobject.c:489
#9 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114
#10 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:46
#11 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114
#12 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:46
```
when run with cuda launch blocking I got a ton of stuff like
```

/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [5,3,0], thread: [2,7,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [5,3,0], thread: [3,7,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [2,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [2,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,3,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,3,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,4,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,4,0] Assertion `value < upper_bound` failed.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165158
Approved by: https://github.com/seemethere
2025-10-21 21:28:12 +00:00
d334c3649d [CUDA] fix reflection padding for large batch size (#165942)
Fixes [#165861](https://github.com/pytorch/pytorch/issues/165861)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165942
Approved by: https://github.com/eqy
2025-10-21 21:07:38 +00:00
9f82535c5a [ROCm] [Normalization] Update block size (#165941)
* Seeing upto 6x improvement

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165941
Approved by: https://github.com/jeffdaily
2025-10-21 20:53:05 +00:00
5b35fc8777 Support multiple commits on push events in trunk tagging workflow (#165937)
Context:
* this workflow is used to create tags like `trunk/{sha}` for all `main` commits
* those tags are used by [autorevert](https://github.com/pytorch/test-infra/blob/main/aws/lambda/pytorch-auto-revert/README.md) to rerun selected workflows

Problem: currently the workflow creates only a single tag per push event, while ghstack pushes multiple commits per single push.

This PR supports tag creation for all commits in the push event.

Complimentary autorevert PR: https://github.com/pytorch/test-infra/pull/7291

---

### Testing

I created an identical copy of this workflow in my personal repo: https://github.com/izaitsevfb/pr-head-test/actions/workflows/trunk-tagging.yml

See action runs there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165937
Approved by: https://github.com/huydhn
2025-10-21 20:52:34 +00:00
2f38eece7c [CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)
As per title. Additionally, some Lt selection conditions are revisited, and some redundancy removed (especially in the ROCm vs non-ROCm paths).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163955
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-21 20:48:12 +00:00
830e789a55 [dynamo][annotate] Graph break cleanly on fx.traceback.annotate reconstruction (#166006)
This avoids generation of bad bytecode, leading to really confusing
error. I am not sure why we can't reconstruct cleanly, it has to do with
the input being a dict, while other supported ctx managers take bools.

Fixing that is for another day. Lets give a good error message for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166006
Approved by: https://github.com/yushangdi, https://github.com/SherlockNoMad
2025-10-21 20:48:04 +00:00
ad4dc52bf6 Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit 4e643422f63a3cdd71bd141615f98de6bb54d15f.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/albanD due to Breaks lint ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3429426503))
2025-10-21 20:24:14 +00:00
dac9ed9790 Bump uv from 0.8.6 to 0.9.5 in /.ci/lumen_cli (#166017)
Bumps [uv](https://github.com/astral-sh/uv) from 0.8.6 to 0.9.5.
- [Release notes](https://github.com/astral-sh/uv/releases)
- [Changelog](https://github.com/astral-sh/uv/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/uv/compare/0.8.6...0.9.5)

---
updated-dependencies:
- dependency-name: uv
  dependency-version: 0.9.5
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-21 13:16:30 -07:00
1c7fe8f861 [BugFix] chunk_size should always be int64_t (#165971)
aspired by https://github.com/pytorch/pytorch/pull/156872
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165971
Approved by: https://github.com/albanD
2025-10-21 19:52:47 +00:00
4e643422f6 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-21 19:47:33 +00:00
3c3b278872 [reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882)
Relands #148261 that was reverted by #150542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165882
Approved by: https://github.com/ezyang
2025-10-21 19:43:55 +00:00
0bd12c1168 [CI] Extend test_transfomers to MPS (#165960)
Just skip grad_checks as they need float64
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165960
Approved by: https://github.com/Skylion007
2025-10-21 19:27:44 +00:00
ce8a7764e2 Revert "[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)"
This reverts commit 1290b077f26543a34262587137ef64ca9ca5e17d.

Reverted https://github.com/pytorch/pytorch/pull/165707 on behalf of https://github.com/clee2000 due to failing internal tests D85160820 ([comment](https://github.com/pytorch/pytorch/pull/165707#issuecomment-3429084393))
2025-10-21 19:25:03 +00:00
d1269a0434 update fr trace analysis (#165994)
Summary:
- allow empty entries from ranks
- allow not all ranks to provide dump

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/165994).
* #165638
* #165640
* #165642
* __->__ #165994
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165994
Approved by: https://github.com/fduwjj
2025-10-21 19:14:33 +00:00
c87cf1be32 Update workaround to old CUDA bug (#164354) (#165984)
The workaround cannot be removed because of BC. Here we'll
update PyTorch code base to not use the workaround.

See https://github.com/pytorch/pytorch/pull/164354 for the BC breakage issue.

Resolves https://github.com/pytorch/pytorch/issues/164348.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165984
Approved by: https://github.com/janeyx99
2025-10-21 19:09:43 +00:00
2fc5e45a41 better error message when there is no pytree impl (#165955)
Differential Revision: [D85117597](https://our.internmc.facebook.com/intern/diff/D85117597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165955
Approved by: https://github.com/avikchaudhuri
2025-10-21 18:49:22 +00:00
f9022ba93b [PyTorch] Add user_metadata display to memory visualizer (#165939)
Summary: Enhanced the PyTorch CUDA memory visualizer to display user_metadata alongside stack frames when inspecting allocations. The user_metadata field is now shown in all views (Allocator State History, Active Memory Timeline, etc.) with consistent formatting. The implementation handles both string and object metadata types, displaying strings directly and objects as key-value pairs.

Test Plan:
1. Generate a memory snapshot with user_metadata
2. Open the memory visualizer in a browser
3. Load the snapshot file
4. Verify user_metadata appears
5. Test with both string metadata ("testing") and object metadata ({"key": "value"})
6. Verify formatting shows "User Metadata:\n  <value>" for strings

 {F1982860439}

Differential Revision: D85095152

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165939
Approved by: https://github.com/yushangdi
2025-10-21 18:48:33 +00:00
ff8be889ad Remove unused exception parameter from some files, to work with -Wunused-exception-parameter (#165770)
Summary: address compiler complains that were coming up to unblock the build

Test Plan:
before the change
```
aten/src/ATen/native/LinearAlgebra.cpp:3623:36: error: unused exception parameter 'e' [-Werror,-Wunused-exception-parameter]
 3623 |     } catch (const std::exception& e) {
      |
```

after: targets build with `-Wunused-exception-parameter`

Differential Revision: D84876246

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165770
Approved by: https://github.com/Skylion007, https://github.com/cyyever

Co-authored-by: Tony Targonski <tony.targonski@meta.com>
2025-10-21 18:30:29 +00:00
292454942e [CD] Introduce windows.12xlarge runners for CD Windows build (#165287)
Follows https://github.com/pytorch/test-infra/pull/7174. Windows CD build time cost comparison as below

|Runner|cpu|cuda|xpu|
|-|-|-|-|
|windows.4xlarge|1.5h| 4.0h| 5.5h|
|windows.12xlarge|0.5h|1.5h|2.5h|

Fixes #162962
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165287
Approved by: https://github.com/zxiiro, https://github.com/malfet, https://github.com/seemethere
2025-10-21 18:28:23 +00:00
6c4412f72b Revert "[Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)"
This reverts commit e9d89734274a4a2640fa77b898c800a87d1d874e.

Reverted https://github.com/pytorch/pytorch/pull/163316 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 17:44:42 +00:00
78bf6186f2 Revert "[Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)"
This reverts commit e8cb34dd52c063a130f3e659576c313bbe4b4981.

Reverted https://github.com/pytorch/pytorch/pull/163324 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 17:44:42 +00:00
c40048472c Remove AOTI cross compilation time from internal CI (#165935)
Summary: as title

Test Plan: CI

Differential Revision: D85088451

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165935
Approved by: https://github.com/desertfire
2025-10-21 16:58:28 +00:00
3dfd0c7584 Improve PATH hints in FindvecLib.cmake (#165881)
Change  /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk to /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk in `cmake/Modules/FindvecLib.cmake` which is more general (and MacOSX10.9 is not supported now). Otherwise, vecLib can't be found on MacOS 26.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165881
Approved by: https://github.com/ezyang
2025-10-21 16:44:12 +00:00
e6ba4d0725 Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)" (#165910)
Summary:
Original commit changeset: d6d62d0c96dd

Original Phabricator Diff: D84468451 and D84613184

D84468451 caused CUDA OutOfMemoryError in model.

Test Plan:
D84468451 was found through bisect.  Also double checked on recent trunk 9866939225248c2adc307be7a804b26db0b9b555: f815887517

With this diff that backs out D84468451 and D84613184 : f816114560

Differential Revision: D85025378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165910
Approved by: https://github.com/clee2000
2025-10-21 16:36:38 +00:00
bdf7cb9d9c Revert "[torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)"
This reverts commit e20c9bf2889b9252ac45ae6af35c93c795eab701.

Reverted https://github.com/pytorch/pytorch/pull/165410 on behalf of https://github.com/clee2000 due to sorry I'm going to revert this since I want to try to back out some other things that are conflicting with this, there is nothing wrong with this PR, rebasing and resolving the merge conflicts should be enough, sorry for the churn ([comment](https://github.com/pytorch/pytorch/pull/165410#issuecomment-3427532373))
2025-10-21 16:27:54 +00:00
6aed378958 [export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)
fx.Interpreter doesn't handle kwargs... not sure how this code worked previously

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165334
Approved by: https://github.com/tugsbayasgalan, https://github.com/ezyang
2025-10-21 15:53:05 +00:00
8b3dc0d1b0 Better error handling in torch/csrc/jit/runtime/* (#165118)
Refactor error handling by using TORCH_CHECK for improved clarity in constants and scope management in some files in torch/csrc/jit/runtime/*

Fixes some parts of ISSUE https://github.com/pytorch/pytorch/issues/148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165118
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 15:22:49 +00:00
06773663b5 Implement an AOT precompile mode for standalone_compile (#165843)
This PR introduces an `aot` flag to standalone_compile that uses BundledAOTAutogradCacheEntry, and then allows regional_inductor to use this so that we can start aot compiling regional compiler graphs. The diff above this will attempt to allow GraphPickler to fully serialize graphs that have regionally compiled subgraphs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843
Approved by: https://github.com/oulgen
2025-10-21 15:02:45 +00:00
0bff65503c Move hardware_destructive_interference_size to c10/core/alignment.h (#160067)
# Motivation
Move `hardware_destructive_interference_size` to `c10/core/alignment.h`, which gives a chance to reuse it across different accelerators.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160067
Approved by: https://github.com/Skylion007, https://github.com/EikanWang
2025-10-21 14:39:46 +00:00
21131a2444 Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)"
This reverts commit ffa90d46e61650834d5f926008f48f50c6a7e87a.

Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/jeffdaily due to timeouts after merge ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3426898171))
2025-10-21 14:15:55 +00:00
1009790ad8 [pytree][dynamo] trace on native optree functions for community pytree support (#165860)
Resolves #164972

- #164972

All `torch.utils._cxx_pytree` functions are based on `optree` functions with hardcoded `none_is_leaf=True` and `namespace="torch"`. This PR changes the polyfills to generic `optree` functions with those arguments unhardcoded. This means `torch.utils._cxx_pytree` functions are still traceable while the community `optree` usages can get dynamo support additionally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860
Approved by: https://github.com/Lucaskabela
2025-10-21 14:13:08 +00:00
410e6a4321 Better error handling in torch/csrc/jit/frontend/* (#165213)
Refactor error handling by using TORCH_CHECK for improved clarity in constants and scope management in some files in torch/csrc/jit/frontend/*

Fixes some parts of ISSUE https://github.com/pytorch/pytorch/issues/148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165213
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 13:54:59 +00:00
23c55c5b66 [Code Clean]Replace assert statements with explicit if/raise patterns (#165735)
Fix part of #164878

Replace 75 assert statements with explicit if/raise patterns in `torch/ao/ns` , include:

- `torch/ao/ns/_numeric_suite_fx.py`  - 5 asserts

- `torch/ao/ns/fx/graph_matcher.py` - 6 asserts

- `torch/ao/ns/fx/graph_passes.py` -12 asserts

- `torch/ao/ns/fx/n_shadows_utils.py` - 20 asserts

- `torch/ao/ns/fx/pattern_utils.py` - 2 asserts

- `torch/ao/ns/fx/utils.py` - 21 asserts

- `torch/ao/ns/fx/weight_utils.py` - 19 asserts

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165735
Approved by: https://github.com/albanD
2025-10-21 11:21:57 +00:00
1290b077f2 [dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)
Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165707
Approved by: https://github.com/Lucaskabela
2025-10-21 09:27:41 +00:00
9f9ab881b2 [ROCm][inductor] heuristic improvements for reduction kernels (#161280)
Improvements to reduction kernel heuristics for MI350.

Contributions from several members of the AMD Inductor and Triton teams: @jataylo @iupaikov-amd @AmdSampsa @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161280
Approved by: https://github.com/jansel, https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jeffdaily
2025-10-21 07:48:54 +00:00
f2bb22ff84 [Inductor-FX] Support Tensor.item (#165599)
# Feature
This PR supports compiling `Tensor.item` with Inductor's FX backend. This maps to a custom WrapperCodeGen method called `codegen_dynamic_scalar`.

# Implementation
The implementation is fairly mechanical, following the usual flow for these types of PRs.
1. Introduce a new Wrapper IR line for this, called `DynamicScalarLine`.
2. Split `PythonWrapperCodegen.codegen_dynamic_scalar` into 2 parts: a public method which generates the Wrapper IR line, and a private one generating Python from Wrapper IR.
3. Implement an FX codegen method for the wrapper IR line. This one calls `aten.where.Scalar` to handle code like `1 if x.item() else 0`, which is a bit tricky. It also calls `aten.item.default` to convert tensors to scalars.

# Test plan
Added CI tests mirroring the AOTI ones. They test float, int and bool types, the latter taking a distinct codegen path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165599
Approved by: https://github.com/angelayi, https://github.com/jansel
2025-10-21 07:09:56 +00:00
03f3f7899c [ATen] Add reduction tag to reduction operators (#165155)
Add a new 'reduction' tag to tags.yaml and apply it to 98 reduction
operator variants across 21 operator families (sum, mean, min, max,
argmin, argmax, amin, amax, aminmax, prod, all, any, norm, var, std,
std_mean, var_mean, nansum, logsumexp, count_nonzero, linalg_vector_norm).

This tag categorizes operators that perform reduction operations,
computing aggregate values across one or more dimensions of input
tensor(s).

Based on PR #153342 - co-written with @AlonSardas.

Just as we have pointwise tag - this can be useful for compiler passes, or for opting into sharding rules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165155
Approved by: https://github.com/ezyang, https://github.com/zou3519, https://github.com/mlazos
2025-10-21 04:35:03 +00:00
771170807b [dynamo][nn_module] Replace UserFunctionVariable with VariableTracker build (#165708)
Audit: To prevent future issues with functools.partial or callable objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165708
Approved by: https://github.com/Lucaskabela
2025-10-21 04:13:12 +00:00
ffa90d46e6 [ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)
* Moving rocm.yml from using persistent non-ARC runners from the combined MI2xx (MI210 + MI250) cluster to the ARC runners from the MI250 cluster. This halves the number of nodes, but provides access to approximately 4 times the runners, since every 8-GPU MI250 node now provides 8 1-GPU runners. This should help with concurrent capacity and queueing on the MI2xx jobs.

Tested here successfully: https://github.com/pytorch/pytorch/actions/runs/18620814622/job/53092469720

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165481
Approved by: https://github.com/jeffdaily

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2025-10-21 04:02:04 +00:00
0e083942cc Enable PLW0127 in ruff (#165851)
This PR enables `PLW0127` in ruff, which checks self-assignment of variables with the form `var=var`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165851
Approved by: https://github.com/Lucaskabela
2025-10-21 03:30:57 +00:00
ce1fcff03e [ROCm] Keep amdgpu-coerce-illegal-types flag if rocm version is less than 7.2 (#165789)
The `-amdgpu-coerce-illegal-types=1` flag is for LLVM that is in ROCm 6.3, 6.4, 7.0, and 7.1. It will not be in ROCm7.2. It was added to enable performance improvements for composable kernel. ROCm7.2 and newer changed the compiler so that the flag isn't needed to achieve those performance improvements. Keeping the flag with ROCm 7.2 breaks the PyTorch build.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165789
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
2025-10-21 03:17:33 +00:00
a238a9a100 Add clang-tidy misc-definitions-in-headers check (#164959)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164959
Approved by: https://github.com/Skylion007, https://github.com/mikaylagawarecki
ghstack dependencies: #164882, #164956
2025-10-21 02:59:46 +00:00
fe69a2bbbd Move from/to to torch::stable::detail (#164956)
To not pollute the global namespace, we should move the `from`/`to` APIs into torch::stable::detail. We are also following our normal deprecation cycle and choosing to continue exposing the global `from`/`to` for the time being as people who onboard their extensions onto 2.9 would not be able to build with 2.10 otherwise.

Note that this means that within libtorch, we do not get the luxury of tacking on a `using torch::stable::detail::from` because then it leads to build time ambiguous calls --> both the global and namespace APIs are exposed, which one do I want? So that is why you see every local site is updated.

Note that the update is _not_ necessary from a custom op writer point of view. FA3 can continue to build on torch nightlies without changing any code. (Since this is a header change, this PR has no implication on runtime, a previously built FA3 ABI stable wheel will continue to work fine with newer torch versions after this PR.)

Once TORCH_BOX lands, we would be free to remove these global APIs when the deprecation cycle is up (April 2026) and encourage people to use TORCH_BOX and avoid from/to entirely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164956
Approved by: https://github.com/malfet
ghstack dependencies: #164882
2025-10-21 02:59:46 +00:00
0be0de4ffa Add type suppressions to _inductor/runtime (#165918)
Original PR that did this was reverted due to merge conflicts.

Trying it again

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165918
Approved by: https://github.com/oulgen
2025-10-21 02:54:22 +00:00
7406d2e665 [DeviceMesh] Clean up the call into mesh_resouces to get root mesh (#165787)
We moved the method to get root mesh into class in https://github.com/pytorch/pytorch/pull/164510. This is to further clean code up.

Differential Revision: [D85090191](https://our.internmc.facebook.com/intern/diff/D85090191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165787
Approved by: https://github.com/fegin
2025-10-21 02:54:04 +00:00
303c9cf048 Save Python refcount bump on each arg in maybe_handle_torch_function (#164625)
Pybind's API entails a small unnecessary overhead when working with args. (Similarly, we should probably be using vectorcall, but that's a bigger change for both us and pybind11.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164625
Approved by: https://github.com/albanD
ghstack dependencies: #164624
2025-10-21 02:40:12 +00:00
d7d4bb7c51 Add XPU part for persons_of_interest (#165920)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165920
Approved by: https://github.com/albanD
2025-10-21 01:57:17 +00:00
0b1c462979 Making Numpy depedency in Local Tensor optional to fix broken Torchao CI (#165938)
In recent change LocalTensor introduced dependency on Numpy and has broken Torchao CI.
This dependency cna be made optional and required only when Local Tensor is used.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165938
Approved by: https://github.com/atalman
2025-10-21 01:46:53 +00:00
4a6cf0a93e Fix dynamo stack trace (#165930)
Fixes #165911

- Add message to Attribute error so we see `  Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])` instead of just `Developer debug context: raised exception AttributeError([])`
- Add stack trace in `ObservedException` so we display the inner most error stack trace back to user code

Output:

```
/data/users/shangdiy/pytorch/torch/__init__.py:2641: UserWarning: You are calling torch.compile inside torch.export region. To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)
  warnings.warn(
Traceback (most recent call last):
  File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1385, in var_getattr
    subobj = self._getattr_static(name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1256, in _getattr_static
    subobj = type(self.value).__getattribute__(self.value, name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Linear' object has no attribute 'w'

During handling of the above exception, another exception occurred:

torch._dynamo.exc.ObservedAttributeError: 'Linear' object has no attribute 'w'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/shangdiy/pytorch/test.py", line 34, in <module>
    mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 481, in inner
    out = fullgraph_capture(
          ^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1053, in fullgraph_capture
    return _fullgraph_capture_frame(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1115, in _fullgraph_capture_frame
    raise e.with_traceback(None) from e.__cause__  # User compiler error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Observed exception
  Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
  Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

  Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html

from user code:
   File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 171, in forward
    res = self._export_root(*args, **kwargs)
  File "/data/users/shangdiy/pytorch/test.py", line 31, in forward
    weight = self.linear.w

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165930
Approved by: https://github.com/anijain2305
2025-10-21 01:32:23 +00:00
4c963a68d7 Use inline instead of anon namespace for stableivalue from/to (#164882)
Fixes https://github.com/pytorch/pytorch/issues/163343.

After some consideration, I propose we remove the anonymous namespace around from/to in favor of:
1. Adding inline to the function implementations, assuming that they will not change in the near future
2. If we decide to change them, we will wrap the code in inline versioned namespaces such that the implementations within any versioned namespace will be guaranteed identical.

Note that:
- We eventually intend to abstract away usage of `from`/`to` (related: @lw's TORCH_BOX work)
- The from/to implementations are now powered through class template specializations, where adding a specialization does not change the from/to signatures.

I do plan to deprecate top-level from/to in favor of torch::stable::details::from/to consequently. This way we can stop polluting the global namespace.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164882
Approved by: https://github.com/lw, https://github.com/albanD
2025-10-21 00:12:15 +00:00
b20deec3d1 [PP] Add optional argument to not save outputs (#165822)
Fix https://github.com/pytorch/pytorch/issues/159251

Add an optional argument `return_outputs` to the schedule `step`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165822
Approved by: https://github.com/wconstab
2025-10-21 00:09:31 +00:00
51d0d8ee67 [ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the 2 easier.
PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```

Differential Revision: [D85022826](https://our.internmc.facebook.com/intern/diff/D85022826)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-21 00:09:13 +00:00
70592c6819 [ROCm][CI] Move gfx1100 workflows to own yaml file (#165699)
This should allow us to move gfx1100 workflow to a lower frequency and also allow it to be triggered on PRs via a dedicated label, for any PRs that target Navi fixes such as [this](https://github.com/pytorch/pytorch/pull/165630) or [this](https://github.com/pytorch/pytorch/pull/165625).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165699
Approved by: https://github.com/jeffdaily
2025-10-20 23:52:48 +00:00
259cb945f5 [stage 2c] make autograd and inference functions (#165668)
Add final stage of aot_stage2_compile for autograd and inference.

Differential Revision: D84844699

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165668
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2025-10-20 23:50:31 +00:00
e20c9bf288 [torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)
Including:
- `torch/utils/*.py`

Fixes part of #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165410
Approved by: https://github.com/albanD
2025-10-20 23:29:17 +00:00
99c8640b5d [1/N] Change C-style casts to static_cast or reinterpret_cast (#165750)
This series of changes try to cover C style casts into C++ alternatives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750
Approved by: https://github.com/Skylion007
2025-10-20 23:27:13 +00:00
96b0e7aaa6 [Code Clean] Clean asserts in torch/ao/quantization/experimental/* and torch/ao/quantization/pt2e/* (#165317)
Replace assert statements with explicit if/raise patterns in:
- torch/ao/quantization/experimental/* (11 errors)
- torch/ao/quantization/pt2e/* (68 errors)

fix partialy #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165317
Approved by: https://github.com/albanD
2025-10-20 23:07:11 +00:00
850ba8c96d [Code Clean] Clean asserts in torch/autograd. (#165627)
Replaces 78 assert statements across 10 files in torch.autograd with explicit if-checks raising AssertionError to prevent assertions from being disabled with Python -O flag. This ensures error checking remains active in optimized builds.

fix partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165627
Approved by: https://github.com/albanD
2025-10-20 23:03:47 +00:00
1bcd736f91 fix bad merge duplicate pre pass (#165917)
fix for https://github.com/pytorch/pytorch/issues/165624 - we were applying pre pass multiple times.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165917
Approved by: https://github.com/bdhirsh
2025-10-20 22:54:36 +00:00
df64c0c464 [Code Clean] Clean asserts in torch/ao/quantization (root, quantizer, backend_config) (#165433)
Replace assert statements with explicit if/raise patterns in:

- torch/ao/quantization/~
- torch/ao/quantization/quantizer/
- torch/ao/quantization/backend_config/

fix partialy #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165433
Approved by: https://github.com/albanD
2025-10-20 22:42:51 +00:00
1891239a1d [Graph Partition] fix graph partition input signature for fallback kernels (#165815)
Scheduler relies on node.last_usage to free buffers. `last_usage` may contain a buffer that is allocated in previous graph partition AND not directly accessed in the current graph partition.

## Example
```python
def f(x):
    y = x + 1
    z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
    z_cpu = z.cpu()
    u_cuda = z_cpu.cuda()
    return u_cuda
```

In the generated code, we have
```
def partition_0(args):
    ...
    # Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.add, aten.view]
    buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn) # < ------ buf1 is a view of buf0
    buf2 = buf1 # <------- buf2 is buf1
    assert_size_stride(buf2, (8, ), (1, ), 'torch.ops.aten.view.dtype')
    assert_alignment(buf2, 16, 'torch.ops.aten.view.dtype')
    return (buf2, )

def call(self, args):
    ...
    (buf2,) = self.partitions[0](partition0_args)
    ...
    buf3.copy_(buf2, False)
    del buf0
    del buf1
    del buf2  # <---- `del buf2` leads to `del buf0`. BUT `buf0` is not returned from partition_0.
    ...
```

Note: view is treated as a fallback kernel due to its special dtype.
de09bab4b6/torch/_inductor/lowering.py (L841-L843)

## Fix

This PR fixes the issue by also returning these buffers to be freed later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165815
Approved by: https://github.com/eellison
2025-10-20 22:23:29 +00:00
cf280ca1e8 Revert "[Inductor] Naive foreach autotune support (#162053)"
This reverts commit 779296a3fce5db0829377c792f13a8eafe537b30.

Reverted https://github.com/pytorch/pytorch/pull/162053 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/162053#issuecomment-3423808492))
2025-10-20 21:36:44 +00:00
efc277cac7 [annotation] add logging for debugging annotation (#165797)
Add logging for debugging annotation bugs. Log will show with `TORCH_LOGS="+annotation" `

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165797
Approved by: https://github.com/ezyang, https://github.com/Skylion007, https://github.com/SherlockNoMad
2025-10-20 21:27:38 +00:00
4f7f43253d Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)"
This reverts commit 8700d68fef855850e2e0aa65056a77b8f80adbdb.

Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/malfet due to Broke lint somehow, see 8f06a1308f/1 ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3423642456))
2025-10-20 20:39:56 +00:00
779296a3fc [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd
2025-10-20 20:39:04 +00:00
8f06a1308f [MPS] slightly faster cholesky (#165867)
Slightly faster cholesky, removed one redundant simdgroup_multiply
<img width="721" height="593" alt="Screenshot 2025-10-19 at 22 00 19" src="https://github.com/user-attachments/assets/e3a9005b-9347-4e62-a24d-16ba5e28849a" />

Generate benchmarks with(measured on M1 Pro):
```
import torch
import numpy as np
import time
import csv

matrix_sizes = [512, 1024, 2048, 4096]
batch_sizes = [1, 2, 4, 8, 16]
num_runs = 10
warmup_runs = 3

def create_spd_matrix(n, batch_size):
    torch.manual_seed(42)
    A = torch.randn(batch_size, n, n, dtype=torch.float32)
    return A @ A.transpose(-2, -1) + n * torch.eye(n).expand(batch_size, -1, -1)

def run_cholesky_mps(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    b = torch.linalg.cholesky(A, upper=False)
    torch.mps.synchronize()
    end = time.perf_counter()
    return b, end - start

results = {
    'N': [],
    'batch_size': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    for batch_size in batch_sizes:
        print(f"\nBenchmarking N={n}, batch_size={batch_size}")

        try:
            A_cpu = create_spd_matrix(n, batch_size)
            A_mps = A_cpu.to("mps")

            for _ in range(warmup_runs):
                _, _ = run_cholesky_mps(A_mps)

            times = []
            for _ in range(num_runs):
                _, t = run_cholesky_mps(A_mps)
                times.append(t)

            mean_time = np.mean(times)
            std_time = np.std(times)

            results['N'].append(n)
            results['batch_size'].append(batch_size)
            results['mean_time'].append(mean_time)
            results['std_time'].append(std_time)

            print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")

        except RuntimeError as e:
            print(f"Error for N={n}, batch_size={batch_size}: {e}")
            continue

with open('cholesky_benchmark_times.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'batch_size', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['batch_size'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165867
Approved by: https://github.com/malfet
2025-10-20 18:56:17 +00:00
240c13394e Revert "[inductor] require shape in TritonCSEVariable (#162275)"
This reverts commit 3af2f0c12accc6bd10ef2b76fb5c51aa0f6b73a3.

Reverted https://github.com/pytorch/pytorch/pull/162275 on behalf of https://github.com/clee2000 due to still failing due to the above D84932446 ([comment](https://github.com/pytorch/pytorch/pull/162275#issuecomment-3423153819))
2025-10-20 17:55:54 +00:00
150682ba7f Revert "Remove workaround to old CUDA bug (#164354)"
This reverts commit 26f38034332a99f2bdcc67ce1f4ba9403d420e52.

Reverted https://github.com/pytorch/pytorch/pull/164354 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164354#issuecomment-3423132083))
2025-10-20 17:48:08 +00:00
ca7360e996 Revert "Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405)"
This reverts commit ca8bd5dbedb5b46f78026e0378b0f47500ddba38.

Reverted https://github.com/pytorch/pytorch/pull/164405 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164354#issuecomment-3423132083))
2025-10-20 17:48:08 +00:00
0bf604320f Revert "[dynamo][user_defined] Replace UserFunctionVariable with VariableTracker build (#165706)"
This reverts commit 1dc9a05d0323ee3c7a20945c62463959d40f1a51.

Reverted https://github.com/pytorch/pytorch/pull/165706 on behalf of https://github.com/clee2000 due to breaking internal tests D84961097 ([comment](https://github.com/pytorch/pytorch/pull/165706#issuecomment-3423059867))
2025-10-20 17:28:58 +00:00
9875e70da8 Revert "[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)"
This reverts commit 630520b346b8883db7821562e589ccde7d12687a.

Reverted https://github.com/pytorch/pytorch/pull/165707 on behalf of https://github.com/clee2000 due to breaking internal tests D84961097 ([comment](https://github.com/pytorch/pytorch/pull/165706#issuecomment-3423059867))
2025-10-20 17:28:58 +00:00
69a4bfe8bb Revert "Refactor out headeronly ArrayRef (#164991)"
This reverts commit 3806e9767b03d06edc317cb90a3a996abdf192a0.

Reverted https://github.com/pytorch/pytorch/pull/164991 on behalf of https://github.com/clee2000 due to breaking internal tests D84961075 ([comment](https://github.com/pytorch/pytorch/pull/164991#issuecomment-3423058017))
2025-10-20 17:26:42 +00:00
62a263b8d4 Revert "Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)"
This reverts commit e4454947e2c692db1a249591121f8583fefe7df1.

Reverted https://github.com/pytorch/pytorch/pull/165152 on behalf of https://github.com/clee2000 due to breaking internal tests D84961075 ([comment](https://github.com/pytorch/pytorch/pull/164991#issuecomment-3423058017))
2025-10-20 17:26:42 +00:00
0da1f911dc Revert "[Submodule] Bump FBGEMM to latest (#165544)"
This reverts commit 23417ae50f5d9bc02e988d916c103ff3a03c5903.

Reverted https://github.com/pytorch/pytorch/pull/165544 on behalf of https://github.com/clee2000 due to failing in internal D84996252, probably needs some sort of update to fbgemm internally? ([comment](https://github.com/pytorch/pytorch/pull/165544#issuecomment-3422993703))
2025-10-20 17:06:07 +00:00
8700d68fef [ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)
* Moving rocm.yml from using persistent non-ARC runners from the combined MI2xx (MI210 + MI250) cluster to the ARC runners from the MI250 cluster. This halves the number of nodes, but provides access to approximately 4 times the runners, since every 8-GPU MI250 node now provides 8 1-GPU runners. This should help with concurrent capacity and queueing on the MI2xx jobs.

Tested here successfully: https://github.com/pytorch/pytorch/actions/runs/18620814622/job/53092469720

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165481
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony, https://github.com/albanD

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2025-10-20 16:06:37 +00:00
ab82456c16 Revert "[1/N] Change C-style casts to static_cast or reinterpret_cast (#165750)"
This reverts commit e1e8491b316df810388d9fa24f135cdba27ab40e.

Reverted https://github.com/pytorch/pytorch/pull/165750 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165750#issuecomment-3422413890))
2025-10-20 14:51:58 +00:00
b23f4687fd [Inductor][CuTeDSL] Move load_template up two directories (#165868)
Summary:
This is a reland of https://github.com/pytorch/pytorch/pull/165347

Moves the function used to load CuTeDSL Jinja templates up one level out of the flex attention folder. This way it can be used for more generate Inductor templates in the future.

Test Plan: test/inductor/test_flex_flash

Differential Revision: D85013024

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165868
Approved by: https://github.com/jananisriram
2025-10-20 12:14:38 +00:00
2705937080 [CI] Add rocm CI back to trunk for pre-submit/PR jobs (#165674)
Only adding single-GPU shards for now, to observe how current capacity handles it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165674
Approved by: https://github.com/jeffdaily
2025-10-20 12:14:06 +00:00
c1eda348be [cuda] fix triu/tril int32 overflow for large matrices (#164705)
Fixes #136611

Cast blockIdx.x to int64_t before multiplication to prevent overflow when computing linear_idx for matrices larger than 2^31 elements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164705
Approved by: https://github.com/eqy, https://github.com/ngimel
2025-10-20 07:17:41 +00:00
361 changed files with 8405 additions and 4438 deletions

View File

@ -19,7 +19,7 @@ pip_install \
transformers==4.36.2
pip_install coloredlogs packaging
pip_install onnxruntime==1.23.0
pip_install onnxruntime==1.23.1
pip_install onnxscript==0.5.4
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers

View File

@ -334,12 +334,12 @@ sympy==1.13.3
#Pinned versions:
#test that import:
onnx==1.18.0
onnx==1.19.1
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
onnxscript==0.5.3
onnxscript==0.5.4
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:

View File

@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.8.6"
"uv==0.9.5"
]
[tool.setuptools]

View File

@ -163,8 +163,13 @@ if [[ "$(uname)" != Darwin ]]; then
MEMORY_LIMIT_MAX_JOBS=12
NUM_CPUS=$(( $(nproc) - 2 ))
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
if [[ "$(uname)" == Linux ]]; then
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
else
# For other builds
export MAX_JOBS=${NUM_CPUS}
fi
cat >>"$envfile" <<EOL
export MAX_JOBS="${MAX_JOBS}"

View File

@ -1 +1 @@
faffd5cf673615583da6517275e361cb3dbc77e6
1752fe6809b74921644866275ab80244b96e80bc

View File

@ -33,6 +33,7 @@ ciflow_push_tags:
- ciflow/rocm
- ciflow/rocm-mi300
- ciflow/rocm-mi355
- ciflow/rocm-navi31
- ciflow/s390
- ciflow/slow
- ciflow/torchbench

View File

@ -79,9 +79,9 @@ jobs:
runs-on: "windows-11-arm64-preview"
{%- else %}
{%- if branches == "nightly" %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
{%- else %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
{%- endif %}
{%- endif %}
timeout-minutes: !{{ common.timeout_minutes_windows_binary }}

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
wheel-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -279,7 +279,7 @@ jobs:
wheel-py3_10-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -517,7 +517,7 @@ jobs:
wheel-py3_10-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -755,7 +755,7 @@ jobs:
wheel-py3_10-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -993,7 +993,7 @@ jobs:
wheel-py3_10-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1229,7 +1229,7 @@ jobs:
wheel-py3_11-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1464,7 +1464,7 @@ jobs:
wheel-py3_11-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1702,7 +1702,7 @@ jobs:
wheel-py3_11-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1940,7 +1940,7 @@ jobs:
wheel-py3_11-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2178,7 +2178,7 @@ jobs:
wheel-py3_11-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2414,7 +2414,7 @@ jobs:
wheel-py3_12-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2649,7 +2649,7 @@ jobs:
wheel-py3_12-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2887,7 +2887,7 @@ jobs:
wheel-py3_12-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3125,7 +3125,7 @@ jobs:
wheel-py3_12-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3363,7 +3363,7 @@ jobs:
wheel-py3_12-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3599,7 +3599,7 @@ jobs:
wheel-py3_13-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3834,7 +3834,7 @@ jobs:
wheel-py3_13-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4072,7 +4072,7 @@ jobs:
wheel-py3_13-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4310,7 +4310,7 @@ jobs:
wheel-py3_13-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4548,7 +4548,7 @@ jobs:
wheel-py3_13-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4784,7 +4784,7 @@ jobs:
wheel-py3_13t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5019,7 +5019,7 @@ jobs:
wheel-py3_13t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5257,7 +5257,7 @@ jobs:
wheel-py3_13t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5495,7 +5495,7 @@ jobs:
wheel-py3_13t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5733,7 +5733,7 @@ jobs:
wheel-py3_13t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5969,7 +5969,7 @@ jobs:
wheel-py3_14-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6204,7 +6204,7 @@ jobs:
wheel-py3_14-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6442,7 +6442,7 @@ jobs:
wheel-py3_14-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6680,7 +6680,7 @@ jobs:
wheel-py3_14-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6918,7 +6918,7 @@ jobs:
wheel-py3_14-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7154,7 +7154,7 @@ jobs:
wheel-py3_14t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7389,7 +7389,7 @@ jobs:
wheel-py3_14t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7627,7 +7627,7 @@ jobs:
wheel-py3_14t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7865,7 +7865,7 @@ jobs:
wheel-py3_14t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -8103,7 +8103,7 @@ jobs:
wheel-py3_14t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -147,15 +147,16 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
]}
secrets: inherit

63
.github/workflows/rocm-navi31.yml vendored Normal file
View File

@ -0,0 +1,63 @@
name: rocm-navi31
on:
push:
tags:
- ciflow/rocm-navi31/*
workflow_dispatch:
schedule:
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
# Also run less frequently on weekends.
- cron: 45 */2 * * 1-5
- cron: 45 4,12 * * 0,6
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
target-determination:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/target_determination.yml
permissions:
id-token: write
contents: read
linux-jammy-rocm-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3_10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: >-
${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
inductor/test_flex_attention inductor/test_max_autotune' || '' }}
secrets: inherit

View File

@ -59,29 +59,3 @@ jobs:
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-rocm-py3_10-gfx1100-test:
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3_10-gfx1100
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
]}
tests-to-include: >
test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
inductor/test_flex_attention inductor/test_max_autotune
secrets: inherit

View File

@ -58,8 +58,10 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"
- name: Validate commit SHA
run: |
@ -87,7 +89,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi
- name: Create and push tag with retry
- name: Create and push tag(s) with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@ -112,14 +114,23 @@ jobs:
return 1
}
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT
# Retry configuration
MAX_RETRIES=5
@ -194,31 +205,111 @@ jobs:
}
}
# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.
# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true
if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after
# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf ' %s\n' "${sha}"
done < "${commits_file}"
while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"
# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"
rm -f "${commits_file}"
if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
exit 0
else
echo "Tag creation failed after all retry attempts"
exit 1
# workflow_dispatch path (single SHA tagging preserved)
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
fi
- name: Tag creation summary
if: always()
run: |
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
fi

View File

@ -190,6 +190,40 @@ jobs:
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
secrets: inherit
linux-jammy-rocm-py3_10-build:
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
secrets: inherit
inductor-build:
name: inductor-build
uses: ./.github/workflows/_linux-build.yml

View File

@ -314,13 +314,14 @@ IF(USE_FBGEMM_GENAI)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0")
list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
endif()
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(

View File

@ -19,6 +19,7 @@
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XLAHooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
@ -88,6 +89,8 @@ class TORCH_API Context {
return at::detail::getHIPHooks();
} else if (opt_device_type == at::kHPU) {
return at::detail::getHPUHooks();
} else if (opt_device_type == at::kXLA) {
return at::detail::getXLAHooks();
} else {
TORCH_CHECK(
false,
@ -196,7 +199,7 @@ class TORCH_API Context {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
}
static bool hasXLA() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
return detail::getXLAHooks().hasXLA();
}
static bool hasXPU() {
return detail::getXPUHooks().hasXPU();

View File

@ -39,7 +39,7 @@ struct HostBlock {
};
template <typename B>
struct alignas(64) FreeBlockList {
struct alignas(hardware_destructive_interference_size) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};
@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(64) HostStatsStaged {
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(64) std::mutex blocks_mutex_;
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(64) std::mutex events_mutex_;
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(64) HostStatsStaged stats_;
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
struct TORCH_API HostAllocator : public at::Allocator {

View File

@ -59,9 +59,7 @@ struct TORCH_API Generator {
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
: impl_(std::move(gen_impl)) {
if (impl_.get() == nullptr) {
throw std::runtime_error("GeneratorImpl with nullptr is not supported");
}
TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
}
bool operator==(const Generator& rhs) const {

View File

@ -111,9 +111,7 @@ class TORCH_API TensorBase {
explicit TensorBase(
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
: impl_(std::move(tensor_impl)) {
if (impl_.get() == nullptr) {
throw std::runtime_error("TensorImpl with nullptr is not supported");
}
TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
}
TensorBase(const TensorBase&) = default;
TensorBase(TensorBase&&) noexcept = default;

View File

@ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) {
return it->second;
auto pos = s.find("::");
if (pos == std::string::npos) {
std::stringstream ss;
ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
throw std::runtime_error(ss.str());
}
TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
Symbol sym(sym_to_info_.size());
@ -121,12 +117,7 @@ std::string Symbol::domainString() const {
}
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
std::ostringstream ss;
ss << "Symbol: domain string is expected to be prefixed with '"
<< domain_prefix() << "', e.g. 'org.pytorch.aten'";
throw std::runtime_error(ss.str());
}
TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
return fromQualString(qualString);
}

View File

@ -7,6 +7,7 @@
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
@ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) {
case Tag::Enum:
case Tag::Stream:
case Tag::Uninitialized:
throw std::runtime_error(
TORCH_CHECK(false,
"unhashable type: '" + v.type()->repr_str() + "'");
}
// the above switch should be exhaustive

View File

@ -8,6 +8,7 @@
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h>
#include <c10/util/Exception.h>
#include <optional>
#include <c10/core/SymFloat.h>
#include <c10/core/SymBool.h>
@ -116,10 +117,8 @@ struct SingleElementType : public SharedType {
protected:
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
if (!this->elem) {
throw std::runtime_error(c10::str(
TORCH_CHECK(this->elem, c10::str(
"Can not create ", typeKindToString(Kind), " with None type"));
}
}
private:
@ -416,16 +415,12 @@ struct TORCH_API SymbolicShape {
}
ShapeSymbol operator[](size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
ShapeSymbol at(size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
@ -520,9 +515,7 @@ struct VaryingShape {
}
const std::optional<T> &operator[](size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
@ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType {
TypePtr createWithContained(
std::vector<TypePtr> contained_types) const override {
if (contained_types.size() != 2) {
throw std::runtime_error("Expected 2 contained types");
}
TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
}

View File

@ -8,6 +8,7 @@
#include <ATen/core/jit_type.h>
#include <c10/macros/Macros.h>
#include <c10/util/env.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <array>
@ -826,9 +827,7 @@ TupleType::TupleType(
: NamedType(TypeKind::TupleType, std::move(name)),
elements_(std::move(elements)),
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
if (!v) {
throw std::runtime_error("Can not create tuple with None type");
}
TORCH_CHECK(v, "Can not create tuple with None type");
return v->hasFreeVariables();
})), schema_(std::move(schema)) {

View File

@ -9,6 +9,7 @@
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
#endif
#include <ATen/cpu/vec/vec128/vec128_convert.h>

View File

@ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
Vectorized frac() const;
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
#ifdef __ARM_FEATURE_BF16
Vectorized<c10::BFloat16> neg() const {
return -values;
}
Vectorized<c10::BFloat16> reciprocal() const {
return 1.0f / values;
}
Vectorized<c10::BFloat16> operator==(
const Vectorized<c10::BFloat16>& other) const {
return values == other.values;
}
Vectorized<c10::BFloat16> operator!=(
const Vectorized<c10::BFloat16>& other) const {
return values != other.values;
}
Vectorized<c10::BFloat16> operator<(
const Vectorized<c10::BFloat16>& other) const {
return values < other.values;
}
Vectorized<c10::BFloat16> operator<=(
const Vectorized<c10::BFloat16>& other) const {
return values <= other.values;
}
Vectorized<c10::BFloat16> operator>(
const Vectorized<c10::BFloat16>& other) const {
return values > other.values;
}
Vectorized<c10::BFloat16> operator>=(
const Vectorized<c10::BFloat16>& other) const {
return values >= other.values;
}
#else
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
@ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
@ -412,28 +451,52 @@ template <>
Vectorized<c10::BFloat16> inline operator+(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x + y;
#else
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator-(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x - y;
#else
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator*(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x * y;
#else
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator/(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x / y;
#else
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
#endif
}
// frac. Implement this here so we can use subtraction
@ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y + z;
#else
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
// elements, not the bottom and top half, so they don't seem
// particularly useful here. Ideally we would include dot product in
// the Vectorized interface...
return a * b + c;
#endif
}
template <>
@ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y + z;
#else
// See NOTE [BF16 FMA] above.
return -a * b + c;
#endif
}
template <>
@ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y - z;
#else
// See NOTE [BF16 FMA] above.
return a * b - c;
#endif
}
template <>
@ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y - z;
#else
// See NOTE [BF16 FMA] above.
return -a * b - c;
#endif
}
#endif // !defined(C10_MOBILE) && defined(__aarch64__)

View File

@ -0,0 +1,378 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
template <> \
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
\
template <> \
class Vectorized<uint##bit##_t> { \
using neon_type = uint##bit##x##vl##_t; \
\
private: \
neon_type values; \
\
public: \
using value_type = uint##bit##_t; \
using size_type = int; \
static constexpr size_type size() { \
return vl; \
} \
Vectorized() { \
values = vdupq_n_u##bit(0); \
} \
Vectorized(neon_type v) : values(v) {} \
Vectorized(uint##bit##_t val); \
template < \
typename... Args, \
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
Vectorized(Args... vals) { \
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
values = vld1q_u##bit(buffer); \
} \
operator neon_type() const { \
return values; \
} \
static Vectorized<uint##bit##_t> loadu( \
const void* ptr, \
uint64_t count = size()); \
void store(void* ptr, uint64_t count = size()) const; \
template <uint64_t mask> \
static Vectorized<uint##bit##_t> blend( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b); \
static Vectorized<uint##bit##_t> blendv( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
const Vectorized<uint##bit##_t>& mask_) { \
return vbslq_u##bit(mask_.values, b, a); \
} \
template <typename step_t> \
static Vectorized<uint##bit##_t> arange( \
value_type base = 0, \
step_t step = static_cast<step_t>(1)); \
static Vectorized<uint##bit##_t> set( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
uint64_t count = size()); \
const uint##bit##_t& operator[](uint idx) const = delete; \
uint##bit##_t& operator[](uint idx) = delete; \
Vectorized<uint##bit##_t> abs() const { \
return values; \
} \
Vectorized<uint##bit##_t> real() const { \
return values; \
} \
Vectorized<uint##bit##_t> imag() const { \
return vdupq_n_u##bit(0); \
} \
Vectorized<uint##bit##_t> conj() const { \
return values; \
} \
Vectorized<uint##bit##_t> neg() const { \
return vreinterpretq_u##bit##_s##bit( \
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
} \
uint##bit##_t reduce_add() const { \
return vaddvq_u##bit(values); \
} \
uint##bit##_t reduce_max() const; \
Vectorized<uint##bit##_t> operator==( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator!=( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> operator<( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator<=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> eq( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ne( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> gt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ge( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> lt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> le( \
const Vectorized<uint##bit##_t>& other) const; \
}; \
template <> \
Vectorized<uint##bit##_t> inline operator+( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vaddq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator-( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vsubq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator&( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vandq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator|( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vorrq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator^( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return veorq_u##bit(a, b); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this == other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this != other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this > other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this < other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
}
VEC_UINT_NEON_TEMPLATE(16, 8)
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
return vmaxvq_u8(values);
}
template <>
Vectorized<uint8_t> inline operator*(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmulq_u8(a, b);
}
template <>
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
return vmvnq_u8(a);
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
const Vectorized<uint8_t>& other) const {
return ~(*this == other);
}
template <>
Vectorized<uint8_t> inline minimum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vminq_u8(a, b);
}
template <>
Vectorized<uint8_t> inline maximum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmaxq_u8(a, b);
}
template <uint64_t mask>
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
(mask & 1LL) ? 0xFF : 0,
(mask & 2LL) ? 0xFF : 0,
(mask & 4LL) ? 0xFF : 0,
(mask & 8LL) ? 0xFF : 0,
(mask & 16LL) ? 0xFF : 0,
(mask & 32LL) ? 0xFF : 0,
(mask & 64LL) ? 0xFF : 0,
(mask & 128LL) ? 0xFF : 0,
(mask & 256LL) ? 0xFF : 0,
(mask & 512LL) ? 0xFF : 0,
(mask & 1024LL) ? 0xFF : 0,
(mask & 2048LL) ? 0xFF : 0,
(mask & 4096LL) ? 0xFF : 0,
(mask & 8192LL) ? 0xFF : 0,
(mask & 16384LL) ? 0xFF : 0,
(mask & 32768LL) ? 0xFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
#define VEC_UINT_NEON_OPS(vl, bit) \
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
values = vdupq_n_u##bit(val); \
} \
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
const void* ptr, uint64_t count) { \
if (count == size()) { \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
} else { \
__at_align__ uint##bit##_t tmp_values[size()]; \
for (const auto i : c10::irange(size())) { \
tmp_values[i] = 0; \
} \
std::memcpy( \
tmp_values, \
reinterpret_cast<const uint##bit##_t*>(ptr), \
count * sizeof(uint##bit##_t)); \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
} \
} \
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
const { \
if (count == size()) { \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
} else { \
uint##bit##_t tmp_values[size()]; \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
} \
}
VEC_UINT_NEON_OPS(16, 8)
template <typename step_t>
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
uint8_t base,
step_t step) {
const Vectorized<uint8_t> base_vec(base);
const Vectorized<uint8_t> step_vec(step);
const uint8x16_t step_sizes = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
return vmlaq_u8(base_vec, step_sizes, step_vec);
}
template <>
Vectorized<uint8_t> inline operator>>(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return x >> z;
}
template <>
Vectorized<uint8_t> inline operator<<(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return vshlq_u8(a, vreinterpretq_s8_u8(z));
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b,
uint64_t count) {
if (count == 0) {
return a;
} else if (count >= 16) {
return b;
} else {
// Build an array of flags: each bit of element is 1 if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
}
template <>
Vectorized<uint8_t> inline operator/(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t y = b;
return x / y;
}
template <>
Vectorized<uint8_t> inline clamp(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min,
const Vectorized<uint8_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<uint8_t> inline clamp_max(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<uint8_t> inline clamp_min(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min) {
return maximum(min, a);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u8x8 = vget_low_u8(src);
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
Vectorized<float> inline convert_int8_half_register_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u8x8 = vget_low_u8(src);
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));

View File

@ -0,0 +1,192 @@
#include <ATen/cuda/CUDAGreenContext.h>
namespace at::cuda {
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
std::unique_ptr<GreenContext> GreenContext::create(
uint32_t num_sms,
std::optional<uint32_t> device_id) {
#if CUDA_HAS_GREEN_CONTEXT
if (!device_id.has_value()) {
device_id = at::cuda::current_device();
}
return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Implement move operations
GreenContext::GreenContext(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
if (this != &other) {
// Clean up current resources
if (green_ctx_) {
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (current == context_) {
TORCH_CHECK(
false,
"attempting to overwrite current green ctx "
"when it is active!");
}
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
}
// Take ownership of other's resources
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
}
return *this;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext::~GreenContext() noexcept{
#if CUDA_HAS_GREEN_CONTEXT
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying CUDA context
CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
return green_ctx_;
}
#endif
// Make this context current
void GreenContext::setContext() {
#if CUDA_HAS_GREEN_CONTEXT
auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream();
at::cuda::CUDAEvent ev;
ev.record(current_stream);
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (!current) {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
} else {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
}
// currently hardcodes the new green context to use the default stream
// TODO(eqy): consider creating a new stream if e.g., it allows interop
// with CUDA Graph captures etc.
auto default_stream = c10::cuda::getDefaultCUDAStream();
ev.block(default_stream);
c10::cuda::setCurrentCUDAStream(default_stream);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
void GreenContext::popContext() {
#if CUDA_HAS_GREEN_CONTEXT
// see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream());
CUcontext popped;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
TORCH_INTERNAL_ASSERT(
popped == context_, "expected popped context to be the current ctx");
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
} // namespace at::cuda

View File

@ -0,0 +1,53 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif
namespace at::cuda {
class TORCH_CUDA_CPP_API GreenContext {
public:
GreenContext(uint32_t device_id, uint32_t num_sms);
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
// Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete;
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
~GreenContext() noexcept;
// Get the underlying CUDA context
CUcontext getContext() const;
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx getGreenContext() const;
#endif
// Make this context current
void setContext();
void popContext();
private:
#if CUDA_HAS_GREEN_CONTEXT
int32_t device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
#endif
};
} // namespace at::cuda

View File

@ -70,11 +70,7 @@
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
#endif
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif
#if defined(USE_ROCM)
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
@ -96,10 +92,6 @@ template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
#if !defined(USE_ROCM)
} // namespace at_cuda_detail
#endif
#endif
#if !defined(USE_ROCM)
@ -121,7 +113,7 @@ struct cuda_type<c10::Half> {
using type = __half;
};
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
#if !defined(USE_ROCM)
template<>
struct cuda_type<c10::BFloat16> {
@ -203,36 +195,6 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera
*out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}
#if !CUB_SUPPORTS_FUTURE_VALUE()
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = ValueT;
using pointer = ValueT*;
using reference = ValueT&;
InputIteratorT iter;
ValueT *first;
difference_type offset = 0;
__device__ ValueT operator[](difference_type i) {
i += offset;
if (i == 0) {
return *first;
} else {
return ValueT(iter[i - 1]);
}
}
__device__ chained_iterator operator+(difference_type i) {
return chained_iterator{iter, first, i};
}
__device__ ValueT operator*() {
return (*this)[0];
}
};
#endif
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
@ -277,25 +239,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
using tuple = typename ArgIndexInputIterator::value_type;
auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
if (x.key == 0) {
return *first_elem_ptr;
} else {
return x.value;
}
};
auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
ArgIndexInputIterator(input + i), input_iter_transform);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i + 1,
output + i,
@ -303,7 +246,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}
@ -555,16 +497,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
input + i, first_elem_ptr};
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i,
output + i,
@ -572,7 +504,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}

View File

@ -10,14 +10,6 @@
#define CUB_VERSION 200001
#endif
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
// https://github.com/NVIDIA/cub/pull/306
#if CUB_VERSION >= 101300
#define CUB_SUPPORTS_NV_BFLOAT16() true
#else
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// https://github.com/NVIDIA/cub/pull/326
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
@ -28,14 +20,6 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif
// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_FUTURE_VALUE() true
#else
#define CUB_SUPPORTS_FUTURE_VALUE() false
#endif
// There were many bc-breaking changes in major version release of CCCL v3.0.0
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
#if CUB_VERSION >= 200800

View File

@ -0,0 +1,23 @@
#include <ATen/detail/XLAHooksInterface.h>
namespace at {
namespace detail {
const XLAHooksInterface& getXLAHooks() {
auto create_impl = [] {
// Create XLA hooks using the registry
auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
if (hooks) {
return hooks;
}
// If hooks creation fails, fall back to default implementation
return std::make_unique<XLAHooksInterface>();
};
static auto hooks = create_impl();
return *hooks;
}
} // namespace detail
C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)
} // namespace at

View File

@ -0,0 +1,79 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
constexpr const char* XLA_HELP =
"This error has occurred because you are trying "
"to use some XLA functionality, but the XLA library has not been "
"loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
~XLAHooksInterface() override = default;
void init() const override {
TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
}
virtual bool hasXLA() const {
return false;
}
virtual std::string showConfig() const {
TORCH_CHECK(
false,
"Cannot query detailed XLA version without torch_xla library. ",
XLA_HELP);
}
const Generator& getDefaultGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(
false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
}
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
}
virtual DeviceIndex getCurrentDevice() const override {
TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
}
Device getDeviceFromPtr(void* /*data*/) const override {
TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
}
Allocator* getPinnedMemoryAllocator() const override {
TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
}
bool isPinnedPtr(const void* data) const override {
return false;
}
bool hasPrimaryContext(DeviceIndex device_index) const override {
TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
}
};
struct TORCH_API XLAHooksArgs {};
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
#define REGISTER_XLA_HOOKS(clsname) \
C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
namespace detail {
TORCH_API const XLAHooksInterface& getXLAHooks();
} // namespace detail
} // namespace at
C10_DIAGNOSTIC_POP()

View File

@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
try {
mkldnn_matmul_i8i8i32(self, mat2, result);
dispatched = true;
} catch (const std::exception& e) {
} catch ([[maybe_unused]] const std::exception& e) {
TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
}
}

View File

@ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
"pixel_shuffle expects a positive upscale_factor, but got ",
upscale_factor);
int64_t c = self.size(-3);
TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
"upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
TORCH_CHECK(c % upscale_factor_squared == 0,
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "

View File

@ -259,11 +259,20 @@ inline void winograd_f2k3_input_transform_inplace__rvv(
const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
/* GCC 14.2 (RISC-V RVV) ICE workaround:
* Avoid single-statement read-modify-write on MEM_REF like:
* *input_tile_val =
* __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
* This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin)
* with -march=rv64gcv. Use a temporary then write back.
* Do NOT refactor into the single-statement form. Clang is unaffected.
*/
vfloat32m1x4_t tmp_input_tile_val = *input_tile_val;
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3);
*input_tile_val = tmp_input_tile_val;
}
inline void winograd_f2k3_output_transform_inplace__rvv(
@ -277,9 +286,15 @@ inline void winograd_f2k3_output_transform_inplace__rvv(
const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
* Keep the temporary + write-back pattern to avoid ICE.
* Do NOT rewrite into:
* *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
*/
vfloat32m1x4_t tmp_output_tile_val = *input_tile_val;
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0);
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1);
*input_tile_val = tmp_output_tile_val;
}
inline vfloat32m1_t
@ -300,11 +315,17 @@ inline void winograd_f2k3_kernel_transform__rvv(
const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
vfloat32m1_t half_g0_plus_g2 = __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
* Keep the temporary + write-back pattern to avoid ICE.
* Do NOT rewrite into:
* *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val);
*/
vfloat32m1x4_t tmp_transform = *transform;
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0);
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2);
*transform = tmp_transform;
}
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {

View File

@ -272,28 +272,110 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
}
static bool getDisableAddmmCudaLt() {
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (env_value == "1") {
return true;
}
return false;
/*
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
* Additionally, for ROCM we test whether the architecture supports the Lt.
*/
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
// When hipBLASLt is not supported on the architecture, return true
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
if (!is_hipblas_lt_arch_supported) {
return true;
}
#endif
// Check whether it is disabled in the env
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (is_addmm_cuda_lt_disabled == "1") {
return true;
}
return false;
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);
/*
* Check whether for the given input we want to enable the Lt interface
*/
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Implies 2D bias which we currently not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so, this condition can be ralexed in cases when a col-major
// copy of result is needed.
if (result.is_same(self)) {
return false;
}
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif
const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
// their row-/col-majorness.
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
// The last conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
// Related to avoiding the leading stride >> leading dim problematic case
// with 16b dtypes described above. For such dtypes we only allow inputs
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
// In that case the leading stride will be equal to the outer dim len.
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
// does not modify inputs as long as there is a stride of length 1
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
// check mat1/mat2 is row-/col-major
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
)
#endif
)
);
#endif
// no compliance by default
return false;
}
#endif
template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
@ -335,7 +417,70 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
}
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
const auto* self_ptr = self.const_data_ptr<scalar_t>();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
// TODO: maybe also return some success state?
launchTunableGemmAndBias<scalar_t>(
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
);
return true;
}
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self_ptr,
args.result->data_ptr<res_scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
// args contains result which is modified
cublasCommonArgs& args,
const Scalar& alpha,
const Scalar& beta
) {
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
beta.to<at::opmath_type<scalar_t>>(),
args.result->data_ptr<res_scalar_t>(),
args.result_ld
);
return true; // success!
}
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
// Shape checks {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
// expand().
@ -345,105 +490,62 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
)
if (result.is_same(self)) {
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
}
// } Shape checks
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, targs);
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
bool useLtInterface = false;
#if defined(USE_ROCM)
// When hipBLASLt is not supported on the architecture,
// disable_addmm_cuda_lt will always be to set to true
static bool disable_addmm_cuda_lt =
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
#else
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
#endif
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
// if lt path fails, we recurse back into this function here and force the lt path to off
// we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
cublasCommonArgs _args(mat1, mat2, result);
if (_args.transa == 't' && _args.transb == 't') {
disable_addmm_cuda_lt_final = true;
}
#endif
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
// }
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
// for cuda 11.4, cublasLtMatmul is activated
// the last two conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
#else
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
// avoid leading dim >> rows bugs
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16)) &&
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16));
#endif
}
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
}
if (&result != &self) {
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
at::native::copy_(result, *self_);
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (disable_addmm_cuda_lt) {
// When in non-Lt path we do expand self even before
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
// copy next, should broadcast
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We copy bias when in the non-Lt path
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
}
}
IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
// Short circuit on empty result
if (result.numel() == 0) {
return result;
}
cublasCommonArgs args(mat1, mat2, result);
if (mat1.numel() == 0) {
// Short circuit if the reduction dim is empty
if (mat1.sizes()[1] == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.) {
@ -455,158 +557,64 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */));
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */
)
);
}
cublasCommonArgs args(mat1, mat2, result);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
if (useLtInterface) {
#if defined(USE_ROCM)
bool okay = true;
// The Lt path
if (!disable_addmm_cuda_lt) {
bool lt_success = false;
if (is_float_output_with_half_input) {
#ifdef USE_ROCM
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
bool okay = true;
if (is_float_output_with_half_input) {
#else
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
);
}});
);
#endif
} else {
// !is_float_output_with_half_input
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
}});
}
if (!okay) {
// lt path failed; recurse but disable lt path
);
} // end is_float_output_with_half_input
if (!lt_success) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#endif
} else
{
// end Lt path
} else {
// No Lt, we use a GEMM instead
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
float* result_ptr = args.result->mutable_data_ptr<float>();
at::cuda::blas::gemm<scalar_t, float>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t, float>(args, alpha, beta);
}
);
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
@ -614,28 +622,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t>(args, alpha, beta);
}
);
}
// Apply epilogue
switch (activation) {
case Activation::RELU:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@ -647,14 +639,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
break;
default: break;
}
}
} // end GEMM path
// Preprocessor gate here needs to match the inverse of the check
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
if (useLtInterface && activation == Activation::GELU) {
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
#endif

View File

@ -1,18 +1,17 @@
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/thread_constants.h>
#include <thrust/tuple.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <tuple>
namespace at::native {
template<int N>
@ -62,7 +61,11 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
#pragma unroll
for (int i = 0; i < elems_per_thread; i++) {
if (policy.check_inbounds(i)) {
#if defined(__HIP__)
results[i] = c10::guts::apply(f, args[i]);
#else
results[i] = std::apply(f, args[i]);
#endif
}
}

View File

@ -23,7 +23,7 @@ namespace at::native {
// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
constexpr int MAX_BLOCK_SIZE = 1024;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
// first the reductions each thread does separately
scalar_t sum = static_cast<scalar_t>(0);
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
constexpr int UNRL = 4; // load deserilize factor
scalar_t tmp[UNRL];
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
#pragma unroll
for (int u = 0; u < UNRL; u++)
tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
#pragma unroll
for (int u = 0; u < UNRL; u++)
if (x+u*blockDim.x < tensor.size(2))
sum += tmp[u];
}
#else
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
sum += op(batch, plane, x);
}
#endif
}
__shared__ scalar_t shared[C10_WARP_SIZE];
SumReduceOp<scalar_t> reduce_op;

View File

@ -653,8 +653,14 @@ struct ReduceOp {
}
__syncthreads();
// Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
// matching Triton, etc.
// TODO(PaulZhang12): AMD and internal
#if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = ops.warp_shfl_down(value[i], offset);

View File

@ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
output_offset + output_y * output_dim_x + output_x);
}
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
const int64_t two = (len - 1) * 2;
if (two <= 0) {
return 0;
}
int64_t m = x % two;
if (m < 0) m += two;
return (m < len) ? m : (two - m);
}
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
const scalar_t * input, scalar_t * output,
@ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel(
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_flat(
const scalar_t* __restrict__ input,
scalar_t* __restrict__ output,
int64_t input_w, int64_t pad_l, int64_t pad_r,
int64_t out_w, int64_t plane_count) {
const int64_t bx = blockDim.x;
const int64_t tx = threadIdx.x;
const int64_t total = plane_count * out_w;
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
for (; linear < total; linear += grid_stride) {
const int64_t plane = linear / out_w;
const int64_t x = linear - plane * out_w;
const int64_t j = reflect_index(x - pad_l, input_w);
output[plane * out_w + x] = input[plane * input_w + j];
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
scalar_t * grad_input, const scalar_t * grad_output,
@ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;
dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
Tensor input = input_.contiguous();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
reflection_pad1d_out_kernel<<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w,
pad_l,
pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int max_x = prop->maxGridSize[0];
const int max_y = prop->maxGridSize[1];
const int max_z = prop->maxGridSize[2];
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
auto stream = at::cuda::getCurrentCUDAStream();
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
if (fits3d) {
dim3 block(block_x, 1, 1);
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r);
} else {
dim3 block(block_x, 1, 1);
const int64_t plane_count = nplane * nbatch;
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
dim3 grid(grid_x, 1, 1);
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r, output_w, plane_count);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,

View File

@ -44,7 +44,7 @@ __global__ void triu_tril_kernel(
const int64_t k,
const int64_t N_padded,
const IndexType last_dim_padded) {
int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread;
if (linear_idx >= N_padded) {
return;
}

View File

@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
} // namespace
} // namespace at::native
} // namespace at::native

View File

@ -466,7 +466,11 @@ struct ReduceJitOp {
__syncthreads();
#if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = reducer::warp_shfl_down(value[i], offset);

View File

@ -441,7 +441,7 @@ kernel void applySYRK(
uint3 tid [[thread_position_in_threadgroup]],
uint3 tgid [[threadgroup_position_in_grid]],
uint3 tpg [[threads_per_threadgroup]],
uint sgitg [[simdgroup_index_in_threadgroup]]) {
uint warp_id [[simdgroup_index_in_threadgroup]]) {
const uint tx = tid.x;
const uint ty = tid.y;
const uint simdGroupsPerThreadgroup = (tpg.x * tpg.y + 31) / 32;
@ -474,11 +474,8 @@ kernel void applySYRK(
(actSize_j % 8 == 0) && (actSize_h % 8 == 0) && (actSize_k % 8 == 0);
if (use_simdgroup) {
uint warp_id = sgitg;
simdgroup_matrix<float, 8, 8> negative_identity =
simdgroup_matrix<float, 8, 8>(-1.0);
simdgroup_matrix<float, 8, 8> identity = simdgroup_matrix<float, 8, 8>(1.0);
simdgroup_matrix<float, 8, 8> Prod;
simdgroup_matrix<float, 8, 8> Afrag;
simdgroup_matrix<float, 8, 8> Bfrag;
@ -521,8 +518,7 @@ kernel void applySYRK(
/* transpose = */ upper);
simdgroup_multiply(Prod, Afrag, Bfrag);
simdgroup_multiply(Prod, Prod, negative_identity);
simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod);
simdgroup_multiply_accumulate(Cfrag, Prod, negative_identity, Cfrag);
}
simdgroup_store(

View File

@ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
}
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
if ([maskedMM dataType] != MPSDataTypeFloat32) {
maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
}
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
if ([maskedMM dataType] != qTensor.dataType) {
maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
}
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
@ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
}
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
@ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = output;
graph->attnTensor = sm;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);

View File

@ -338,6 +338,8 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
}
}
map_mps_decomposition_error_code_to_blas(info);
}
static void linalg_solve_out_mps_impl(const Tensor& A,
@ -1448,20 +1450,6 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
}
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
Tensor info = at::empty({}, A.options().dtype(kInt));
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
return std::tie(LU, pivots);
}
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
Tensor LU = at::empty({0}, A.options());
Tensor pivots = at::empty({0}, A.options().dtype(kInt));
Tensor info = at::empty({}, A.options().dtype(kInt));
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
return std::make_tuple(std::move(LU), std::move(pivots));
}
TORCH_IMPL_FUNC(lu_unpack_out_mps)
(const Tensor& LU_data,
const Tensor& LU_pivots,

View File

@ -706,6 +706,7 @@
variants: function, method
dispatch:
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_all
tags: reduction
- func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
@ -715,6 +716,7 @@
cpp_no_default_args: ['dim']
dispatch:
CompositeExplicitAutograd: all_dims_default
tags: reduction
- func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -723,6 +725,7 @@
CPU, CUDA: all_out
MPS: all_out_mps
MTIA: all_out_mtia
tags: reduction
- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -731,13 +734,16 @@
CPU, CUDA: all_dims_out
CompositeExplicitAutograd: all_dims_out_default
cpp_no_default_args: ['dim']
tags: reduction
- func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool
variants: function, method
@ -749,14 +755,14 @@
device_check: NoCheck # TensorIterator
structured_delegate: any.out
variants: function, method
tags: core
tags: [core, reduction]
- func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: any.dims_out
variants: function, method
cpp_no_default_args: ['dim']
tags: core
tags: [core, reduction]
dispatch:
CompositeExplicitAutograd: any_dims_default
@ -766,6 +772,7 @@
dispatch:
CPU, CUDA: any_out
MPS: any_out_mps
tags: reduction
- func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -774,13 +781,16 @@
CPU, CUDA: any_dims_out
CompositeExplicitAutograd: any_dims_out_default
cpp_no_default_args: ['dim']
tags: reduction
- func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch:
@ -826,25 +836,27 @@
structured_delegate: argmax.out
device_check: NoCheck # TensorIterator
variants: function, method
tags: core
tags: [core, reduction]
- func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: argmax_out
MPS: argmax_out_mps
tags: reduction
- func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
structured_delegate: argmin.out
device_check: NoCheck # TensorIterator
variants: function, method
tags: core
tags: [core, reduction]
- func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: argmin_out
MPS: argmin_out_mps
tags: reduction
- func: acosh(Tensor self) -> Tensor
variants: function, method
@ -1869,12 +1881,14 @@
CUDA: count_nonzero_cuda
MPS: count_nonzero_mps
autogen: count_nonzero.dim_IntList_out
tags: reduction
- func: count_nonzero(Tensor self, int? dim=None) -> Tensor
variants: function, method
dispatch:
CompositeExplicitAutograd: count_nonzero
autogen: count_nonzero.out
tags: reduction
- func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor
variants: function, method
@ -3795,19 +3809,23 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: logsumexp
tags: reduction
- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
# calls squeeze
CompositeExplicitAutogradNonFunctional: logsumexp_out
tags: reduction
- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
@ -3857,6 +3875,7 @@
device_check: NoCheck # TensorIterator
structured_delegate: aminmax.out
variants: function, method
tags: reduction
- func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)
device_check: NoCheck # TensorIterator
@ -3864,6 +3883,7 @@
dispatch:
CPU, CUDA, MTIA: aminmax_out
MPS: aminmax_out_mps
tags: reduction
- func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor
dispatch:
@ -3879,7 +3899,7 @@
variants: function, method
dispatch:
QuantizedCPU, QuantizedCUDA: qmax
tags: core
tags: [core, reduction]
- func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
@ -3889,13 +3909,16 @@
dispatch:
CPU, CUDA, MTIA: max_out
MPS: max_out_mps
tags: reduction
- func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
tags: reduction
- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor
variants: function
@ -3908,13 +3931,14 @@
- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
variants: function, method
structured_delegate: amax.out
tags: core
tags: [core, reduction]
- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA, MTIA: amax_out
MPS: amax_out_mps
tags: reduction
# Return: (Tensor output, Tensor indices)
- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
@ -3976,13 +4000,14 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: mean
tags: core
tags: [core, reduction]
# For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this.
- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CompositeExplicitAutograd: mean_dtype_out
tags: reduction
- func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: mean.out
@ -3990,7 +4015,7 @@
variants: function, method
dispatch:
QuantizedCPU: mean_quantized_cpu
tags: core
tags: [core, reduction]
- func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -3999,13 +4024,16 @@
CPU, CUDA: mean_out
MPS: mean_out_mps
QuantizedCPU: mean_out_quantized_cpu
tags: reduction
- func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # Composite
@ -4068,7 +4096,7 @@
variants: function, method
dispatch:
QuantizedCPU, QuantizedCUDA: qmin
tags: core
tags: [core, reduction]
- func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
@ -4078,24 +4106,28 @@
dispatch:
CPU, CUDA, MTIA: min_out
MPS: min_out_mps
tags: reduction
- func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
tags: reduction
- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
variants: function, method
structured_delegate: amin.out
tags: core
tags: [core, reduction]
- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA, MTIA: amin_out
MPS: amin_out_mps
tags: reduction
# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
# native_functions.yaml
@ -5860,6 +5892,7 @@
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sum_coo
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr
autogen: sum.out
tags: reduction
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
# TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype
@ -5870,11 +5903,12 @@
NestedTensorCPU: NestedTensor_sum_dim_CPU
SparseCPU, SparseCUDA, SparseMPS: sum_sparse_coo
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed
tags: core
tags: [core, reduction]
- func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -5882,9 +5916,11 @@
dispatch:
CPU, CUDA: sum_out
MPS: sum_out_mps
tags: reduction
- func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
# TODO: this function will be replaced once nested expand semantics have been settled on
- func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor
@ -5896,11 +5932,13 @@
dispatch:
CPU, CUDA: nansum
MPS: nansum_mps
tags: reduction
- func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: nansum_out
MPS: nansum_out_mps
tags: reduction
- func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor
variants: function, method
@ -5964,11 +6002,13 @@
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
@ -5977,16 +6017,19 @@
CPU, CUDA: std
MPS: std_mps
QuantizedCPU: std_quantized_cpu
tags: reduction
- func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
@ -5995,42 +6038,51 @@
CPU, CUDA: std_mean
MPS: std_mean_mps
autogen: std_mean.correction_out
tags: reduction
- func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: std_out
QuantizedCPU: std_out_quantized_cpu
tags: reduction
- func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
@ -6039,13 +6091,13 @@
CPU, CUDA: prod
MPS: prod_mps
autogen: prod.out
tags: core
tags: [core, reduction]
- func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: prod.int_out
device_check: NoCheck # TensorIterator
variants: function, method
tags: core
tags: [core, reduction]
- func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6053,13 +6105,16 @@
dispatch:
CPU, CUDA: prod_out
MPS: prod_out_mps
tags: reduction
- func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: t(Tensor(a) self) -> Tensor(a)
device_check: NoCheck
@ -6520,11 +6575,12 @@
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: core
tags: [core, reduction]
cpp_no_default_args: ["unbiased"]
- func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
@ -6534,43 +6590,51 @@
CPU, CUDA: var
MPS: var_mps
MTIA: var_mtia
tags: core
tags: [core, reduction]
- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: var_out
tags: reduction
- func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
@ -6579,15 +6643,18 @@
CPU, CUDA: var_mean
MPS: var_mean_mps
autogen: var_mean.correction_out
tags: reduction
- func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: view_as(Tensor(a) self, Tensor other) -> Tensor(a)
variants: method
@ -6847,6 +6914,7 @@
dispatch:
CompositeExplicitAutograd: norm
autogen: norm.ScalarOpt_dtype_out
tags: reduction
- func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor
device_check: NoCheck # TensorIterator
@ -6854,6 +6922,7 @@
dispatch:
CompositeExplicitAutograd: norm
autogen: norm.Scalar_out
tags: reduction
- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
structured_delegate: norm.dtype_out
@ -6861,6 +6930,7 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_dtype_norm
tags: reduction
- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
structured_delegate: norm.out
@ -6868,6 +6938,7 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_norm
tags: reduction
- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6875,6 +6946,7 @@
dispatch:
CPU, CUDA: norm_dtype_out
MPS: norm_dtype_out_mps
tags: reduction
- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6882,21 +6954,26 @@
dispatch:
CPU, CUDA: norm_out
MPS: norm_out_mps
tags: reduction
# These four redispatch in their implementation, so OK to be CompositeImplicitAutograd
- func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
variants: method, function
@ -10082,12 +10159,14 @@
CPU, CUDA: min
MPS: min_mps
QuantizedCPU: min_quantized_cpu
tags: [reduction]
- func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: min_unary_out
QuantizedCPU: min_quantized_unary_out
tags: [reduction]
- func: fmin(Tensor self, Tensor other) -> Tensor
structured_delegate: fmin.out
@ -10110,6 +10189,7 @@
CPU, CUDA: max
MPS: max_mps
QuantizedCPU: max_quantized_cpu
tags: [reduction]
- func: fmax(Tensor self, Tensor other) -> Tensor
structured_delegate: fmax.out
@ -10156,6 +10236,7 @@
dispatch:
CPU, CUDA: max_unary_out
QuantizedCPU: max_quantized_unary_out
tags: [reduction]
- func: minimum(Tensor self, Tensor other) -> Tensor
structured_delegate: minimum.out
@ -10275,6 +10356,7 @@
device_check: NoCheck # TensorIterator
structured_delegate: all.all_out
variants: method, function
tags: reduction
- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
@ -10283,6 +10365,7 @@
CPU, CUDA: all_all_out
MTIA: all_all_out_mtia
MPS: all_all_out_mps
tags: reduction
- func: any(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@ -10290,7 +10373,7 @@
variants: method, function
dispatch:
SparseCPU, SparseCUDA, SparseMPS: any_sparse
tags: core
tags: [core, reduction]
- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
@ -10298,6 +10381,7 @@
dispatch:
CPU, CUDA: any_all_out
MPS: any_all_out_mps
tags: reduction
- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -14073,16 +14157,10 @@
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor
MPS: linalg_lu_factor_mps
- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor_out
MPS: linalg_lu_factor_out_mps
- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
python_module: linalg
@ -14349,6 +14427,7 @@
python_module: linalg
variants: function
structured_delegate: linalg_vector_norm.out
tags: reduction
- func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
@ -14356,6 +14435,7 @@
dispatch:
CPU, CUDA: linalg_vector_norm_out
MPS: linalg_vector_norm_out_mps
tags: reduction
- func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
python_module: linalg

View File

@ -40,15 +40,7 @@
#include <thrust/iterator/discard_iterator.h>
#if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
#define IS_CUSPARSE11_AVAILABLE() 1
#else
#define IS_CUSPARSE11_AVAILABLE() 0
#endif
#if IS_CUSPARSE11_AVAILABLE()
#include <library_types.h>
#endif
namespace at::native {
@ -103,17 +95,9 @@ struct csrMatrixRef {
int nnz_{0};
std::vector<int> size_{};
#if IS_CUSPARSE11_AVAILABLE()
cusparseSpMatDescr_t description_{0};
#else
cusparseMatDescr_t description_{0};
#endif
cusparseSpMatDescr_t description_{0};
csrMatrixRef() {
#if !IS_CUSPARSE11_AVAILABLE()
create_general_description_(description_);
#endif
}
csrMatrixRef() = default;
csrMatrixRef(
int* csr_indices,
@ -126,7 +110,6 @@ struct csrMatrixRef {
csr_values_{csr_values},
nnz_{nnz},
size_{size} {
#if IS_CUSPARSE11_AVAILABLE()
cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
&description_,
@ -140,17 +123,10 @@ struct csrMatrixRef {
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO,
cuda_data_type));
#else
create_general_description_(description_);
#endif
}
~csrMatrixRef() {
#if IS_CUSPARSE11_AVAILABLE()
cusparseDestroySpMat(description_);
#else
cusparseDestroyMatDescr(description_);
#endif
cusparseDestroySpMat(description_);
}
int size(int index) const {
@ -196,8 +172,6 @@ struct csrOutput {
}
};
#if IS_CUSPARSE11_AVAILABLE()
// RAII guard helps to support cuSparse 11 API for `A @ B` operation
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
template <class scalar_t>
@ -396,284 +370,6 @@ template struct CusparseMatrixMultiplyOp<float>;
template struct CusparseMatrixMultiplyOp<double>;
#else // if not IS_CUSPARSE11_AVAILABLE()
using DcsrMatrixRef = csrMatrixRef<double>;
using ScsrMatrixRef = csrMatrixRef<float>;
// RAII guard helps to support cuSparse 10 API for `A @ B` operation
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
template <class scalar_t>
struct CusparseMatrixMultiplyOp {
csrOutput operator()(
const csrMatrixRef<scalar_t>& lhs,
const csrMatrixRef<scalar_t>& rhs,
Tensor &output_values,
Tensor &output_indices)
{
static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double.");
}
};
// Specializacion for `A @ B` operation for double values with cuSparse
template<> struct CusparseMatrixMultiplyOp<double> {
csrgemm2Info_t gemm2Info_;
CusparseMatrixMultiplyOp() {
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
}
~CusparseMatrixMultiplyOp() {
cusparseDestroyCsrgemm2Info(gemm2Info_);
}
csrOutput operator ()(
const DcsrMatrixRef& lhs,
const DcsrMatrixRef& rhs,
Tensor &output_values,
Tensor &output_indices) {
double alpha = 1.0;
DcsrMatrixRef empty;
return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
}
csrOutput Dgemm2(
const DcsrMatrixRef& A,
const DcsrMatrixRef& B,
const DcsrMatrixRef& C,
const double* alpha,
const double* beta,
Tensor &output_values,
Tensor &output_indices) {
void* buffer_{nullptr};
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
csrOutput out({A.size(0), B.size(1)});
int innerSize = confirm_mult_size(A.size_, B.size_);
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
// Compute needed buffer size
size_t new_bubber_sz;
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
gemm2Info_,
&new_bubber_sz));
// (Re)allocate buffer if needed
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
buffer_ = data_ptr.get();
// Find the resulting non-zero pattern.
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_pointers_.data_ptr<int>(),
&out.nnz_,
gemm2Info_,
buffer_));
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
// Perform the gemm2 operation for doubles
// out = alpha A B + beta C
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_values_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_values_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_values_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_values_.data_ptr<double>(),
out.csr_pointers_.data_ptr<int>(),
out.csr_indices_.data_ptr<int>(),
gemm2Info_,
buffer_));
return out;
}
};
// Specializacion for `A @ B` operation for float values with cuSparse
template<> struct CusparseMatrixMultiplyOp<float> {
csrgemm2Info_t gemm2Info_;
CusparseMatrixMultiplyOp() {
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
}
~CusparseMatrixMultiplyOp() {
cusparseDestroyCsrgemm2Info(gemm2Info_);
}
csrOutput operator()(
const ScsrMatrixRef& lhs,
const ScsrMatrixRef& rhs,
Tensor &output_values,
Tensor &output_indices) {
float alpha = 1.0;
ScsrMatrixRef empty;
return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
}
csrOutput Sgemm2(
const ScsrMatrixRef& A,
const ScsrMatrixRef& B,
const ScsrMatrixRef& C,
const float* alpha,
const float* beta,
Tensor &output_values,
Tensor &output_indices) {
void* buffer_{nullptr};
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
csrOutput out({A.size(0), B.size(1)});
int innerSize = confirm_mult_size(A.size_, B.size_);
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
// Compute needed buffer size
size_t new_bubber_sz;
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
gemm2Info_,
&new_bubber_sz));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
buffer_ = data_ptr.get();
// Find the resulting non-zero pattern.
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_pointers_.data_ptr<int>(),
&out.nnz_,
gemm2Info_,
buffer_));
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
// Perform the gemm2 operation for doubles
// out = alpha A B + beta C
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_values_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_values_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_values_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_values_.data_ptr<float>(),
out.csr_pointers_.data_ptr<int>(),
out.csr_indices_.data_ptr<int>(),
gemm2Info_,
buffer_));
return out;
}
};
#endif // IS_CUSPARSE11_AVAILABLE()
template <typename scalar_t>
void sparse_sparse_matmul_cuda_kernel(
Tensor& result,
@ -815,19 +511,15 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
auto output = at::native::empty_like(mat1_);
output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
#if IS_CUSPARSE11_AVAILABLE() && !defined(USE_ROCM)
#if !defined(USE_ROCM)
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#elif IS_CUSPARSE11_AVAILABLE() && defined(USE_ROCM)
#else
// ROCm does not support half and bfloat16 types for sparse_matmul
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#else
AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#endif
return output;
}

View File

@ -62,7 +62,6 @@ kernel void build_row_ptr_from_sorted_rows_by_batch(
template <typename T>
kernel void spmm_bmm_coo_rows_grouped(
device const long* rows [[buffer(0)]],
device const long* cols [[buffer(1)]],
device const T* vals [[buffer(2)]],
device const T* dense [[buffer(3)]],
@ -73,7 +72,6 @@ kernel void spmm_bmm_coo_rows_grouped(
uint3 ltid [[thread_position_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]])
{
const uint B = dims.x;
const uint I = dims.y;
const uint J = dims.z;
const uint K = dims.w;
@ -321,7 +319,6 @@ INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \
template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void \
spmm_bmm_coo_rows_grouped<DTYPE>( \
device const long* rows [[buffer(0)]], \
device const long* cols [[buffer(1)]], \
device const DTYPE* vals [[buffer(2)]], \
device const DTYPE* dense [[buffer(3)]], \

View File

@ -93,3 +93,7 @@
This operator does not support cudagraphs. The presence of this tag on an operator will cause
Inductor to split the graph around this operator. Note that operators without this tag may still
not support CUDAGraphs. Inductor may have other hardcoded lists around that.
- tag: reduction
desc: |
This tag indicates that an operator performs a reduction operation, computing aggregate values
(sum, mean, max, min, etc.) across one or more dimensions of the input tensor(s).

View File

@ -202,7 +202,6 @@ supported:
- select_backward
- _trilinear
- linalg_pinv.atol_rtol_tensor
- svd
- logsumexp.out
symint:
- empty.memory_format

View File

@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1719000000,0.1
update_hint_regression,compile_time_instruction_count,1645000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1

1 add_loop_eager compile_time_instruction_count 3070000000 3184000000 0.1
2 add_loop_eager_dynamic compile_time_instruction_count 4432000000 4595000000 0.1
3 add_loop_inductor compile_time_instruction_count 29660000000 29660000000 0.1
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 39910000000 39910000000 0.1
5 add_loop_inductor_gpu compile_time_instruction_count 26800000000 26800000000 0.1
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 1048000000 1096000000 0.1
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 15240000000 15240000000 0.1
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 17020000000 17720000000 0.1
18 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3442000000 3152000000 0.1
19 aotdispatcher_training_subclass_cpu compile_time_instruction_count 9239000000 8301000000 0.1
20 mm_loop_inductor_gpu compile_time_instruction_count 4820968837 4958000000 0.1
21 mm_loop_inductor_dynamic_gpu compile_time_instruction_count 9051000000 9051000000 0.1
22 basic_NestedModule_eager compile_time_instruction_count 9554000000 9990000000 0.1
23 basic_InlineMod_eager compile_time_instruction_count 7618000000 8126000000 0.1
24
26
27
28
29
30
31
32
34
35
36
37
38
39
40
41
42
43
44
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
82
83
84
85
86
87
88
89

View File

@ -43,6 +43,7 @@ tolerance:
- doctr_reco_predictor
- drq
- phlippe_resnet
- pytorch_CycleGAN_and_pix2pix
higher_bf16:
- doctr_reco_predictor

View File

@ -127,7 +127,7 @@ def trainbench(
bwd_time = bwd_start_event.elapsed_time(bwd_end_event)
return fwd_time, bwd_time
creator_args = creator_args = {
creator_args = {
"seqLength": seqLength,
"numLayers": numLayers,
"inputSize": inputSize,

View File

@ -12,7 +12,7 @@ def modeldef(request, net_name, executor, fuser):
# Given a 'net_name' provided by generate_tests, build the thing
name, rnn_creator, context = get_nn_runners(net_name)[0]
creator_args = creator_args = {
creator_args = {
"seqLength": 100,
"numLayers": 1,
"inputSize": 512,

View File

@ -48,17 +48,89 @@ PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,Fa
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
@ -71,6 +143,9 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000

1 Benchmarking Framework Benchmarking Module Name Case Name tag run_backward Execution Time Peak Memory (KB)
48 PyTorch div div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32 short False 58.529255 0.000000
49 PyTorch mul mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32 short False 54.645077 0.000000
50 PyTorch add add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 4.397014 0.000000
51 PyTorch add add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.739000 0.000000
52 PyTorch add add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.786000 0.000000
53 PyTorch add add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.911000 0.000000
54 PyTorch add add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 59.243500 0.000000
55 PyTorch add add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 105.066000 0.000000
56 PyTorch add add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.076000 0.000000
57 PyTorch add add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.225000 0.000000
58 PyTorch add add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 57.947691 0.000000
59 PyTorch add add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 107.291000 0.000000
60 PyTorch add add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 107.224000 0.000000
61 PyTorch add add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.912000 0.000000
62 PyTorch sub sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 1.925851 0.000000
63 PyTorch sub sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 8.0240000 0.000000
64 PyTorch sub sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 8.069000 0.000000
65 PyTorch sub sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.938000 0.000000
66 PyTorch sub sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 57.308320 0.000000
67 PyTorch sub sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 107.091000 0.000000
68 PyTorch sub sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 108.710000 0.000000
69 PyTorch sub sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.502000 0.000000
70 PyTorch sub sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 57.787743 0.000000
71 PyTorch sub sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 108.863000 0.000000
72 PyTorch sub sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 108.939000 0.000000
73 PyTorch sub sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.603000 0.000000
74 PyTorch div div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 7.978539 0.000000
75 PyTorch div div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 8.741000 0.000000
76 PyTorch div div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 8.757000 0.000000
77 PyTorch div div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 8.774000 0.000000
78 PyTorch div div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 159.754860 0.000000
79 PyTorch div div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 165.552000 0.000000
80 PyTorch div div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 165.755000 0.000000
81 PyTorch div div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 165.714000 0.000000
82 PyTorch div div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 165.360235 0.000000
83 PyTorch div div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 168.376000 0.000000
84 PyTorch div div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 169.604000 0.000000
85 PyTorch div div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 168.428000 0.000000
86 PyTorch mul mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 3.928136 0.000000
87 PyTorch mul mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.402000 0.000000
88 PyTorch mul mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.567000 0.000000
89 PyTorch mul mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 4.020000 0.000000
90 PyTorch mul mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 56.413499 0.000000
91 PyTorch mul mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 104.638000 0.000000
92 PyTorch mul mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 104.335000 0.000000
93 PyTorch mul mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.612000 0.000000
94 PyTorch mul mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 55.925090 0.000000
95 PyTorch mul mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 106.110000 0.000000
96 PyTorch mul mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.389000 0.000000
97 PyTorch mul mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.195000 0.000000
98 PyTorch asr asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 1.989000 0.000000
99 PyTorch asr asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.999000 0.000000
100 PyTorch asr asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.939000 0.000000
101 PyTorch asr asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.980000 0.000000
102 PyTorch asr asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 54.408000 0.000000
103 PyTorch asr asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 105.647000 0.000000
104 PyTorch asr asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.476000 0.000000
105 PyTorch asr asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.784000 0.000000
106 PyTorch asr asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 55.583000 0.000000
107 PyTorch asr asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 108.083000 0.000000
108 PyTorch asr asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 107.663000 0.000000
109 PyTorch asr asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.283000 0.000000
110 PyTorch lsl lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 1.986000 0.000000
111 PyTorch lsl lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.676000 0.000000
112 PyTorch lsl lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.618000 0.000000
113 PyTorch lsl lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.982000 0.000000
114 PyTorch lsl lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 54.698000 0.000000
115 PyTorch lsl lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 105.899000 0.000000
116 PyTorch lsl lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.741000 0.000000
117 PyTorch lsl lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 51.182000 0.000000
118 PyTorch lsl lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 55.290000 0.000000
119 PyTorch lsl lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 107.744000 0.000000
120 PyTorch lsl lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 107.820000 0.000000
121 PyTorch lsl lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 51.298000 0.000000
122 PyTorch xor xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 1.988000 0.000000
123 PyTorch xor xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 7.689000 0.000000
124 PyTorch xor xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 7.695000 0.000000
125 PyTorch xor xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 1.978000 0.000000
126 PyTorch xor xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 54.934000 0.000000
127 PyTorch xor xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 105.217000 0.000000
128 PyTorch xor xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 104.215000 0.000000
129 PyTorch xor xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 47.115000 0.000000
130 PyTorch xor xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32 short False 55.974000 0.000000
131 PyTorch xor xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8 short False 106.828000 0.000000
132 PyTorch xor xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32 short False 106.879000 0.000000
133 PyTorch xor xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8 short False 48.197000 0.000000
134 PyTorch logical_and logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool short False 78.404254 0.000000
135 PyTorch logical_and logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool short False 5.354032 0.000000
136 PyTorch logical_and logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool short False 54.072783 0.000000
143 PyTorch baddbmm baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16 short False 6.476986 0.000000
144 PyTorch baddbmm baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32 short False 266.065131 0.000000
145 PyTorch baddbmm baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16 short False 295.503063 0.000000
146 PyTorch all all_M1_N1_K1_cpu short False 5.773000 0.000000
147 PyTorch all all_M64_N64_K64_cpu short False 89.427000 0.000000
148 PyTorch all all_M64_N64_K128_cpu short False 120.119000 0.000000
149 PyTorch cat cat_sizes(1,1,1)_N2_dim0_cpu short False 4.301950 0.000000
150 PyTorch cat cat_sizes(512,512,2)_N2_dim1_cpu short False 99.093415 0.000000
151 PyTorch cat cat_sizes(128,1024,2)_N2_dim1_cpu short False 96.771578 0.000000

View File

@ -580,6 +580,9 @@ class BenchmarkRunner:
else "unknown"
)
# Extract operator name from test_name
operator_name = test_name.split("_")[0]
# Create the record
@dataclass
class BenchmarkInfo:
@ -593,6 +596,7 @@ class BenchmarkRunner:
name: str
type: str
origins: list[str]
extra_info: dict[str, Any]
@dataclass
class MetricInfo:
@ -618,10 +622,14 @@ class BenchmarkRunner:
"device": device,
"arch": device_arch,
"use_compile": use_compile,
"operator_name": operator_name,
},
),
model=ModelInfo(
name=test_name, type="micro-benchmark", origins=["pytorch"]
name=test_name,
type="micro-benchmark",
origins=["pytorch"],
extra_info={"operator_name": operator_name},
),
metric=MetricInfo(
name="latency",

View File

@ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu"],
"dtype": [torch.float],
"dtype": [torch.float, torch.bfloat16],
},
tags=["short"],
)
@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.int32],
"dtype_two": [torch.int32],
"dtype_one": [torch.int32, torch.uint8],
"dtype_two": [torch.int32, torch.uint8],
},
tags=["short"],
)
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.int8, torch.int32],
dtype_two=[torch.int8, torch.int32],
dtype_one=[torch.int8, torch.int32, torch.uint8],
dtype_two=[torch.int8, torch.int32, torch.uint8],
tags=["long"],
)

View File

@ -176,8 +176,8 @@ THIRD_PARTY_LIBS = {
"omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
"pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
"psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
"pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
"pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
"pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
"pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
"moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
"pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
"rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
@ -1729,8 +1729,10 @@ def define_buck_targets(
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/backends/backend_interface.cpp",
],
compiler_flags = get_pt_compiler_flags(),
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
}),
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = get_no_as_needed_linker_flag(),
@ -2023,6 +2025,9 @@ def define_buck_targets(
"ovr_config//os:android-x86_64": [
"-mssse3",
],
}) + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
}),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [

View File

@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
"torch/csrc/cuda/Stream.cpp",
"torch/csrc/cuda/Graph.cpp",
"torch/csrc/cuda/MemPool.cpp",
"torch/csrc/cuda/GreenContext.cpp",
"torch/csrc/cuda/shared/cudart.cpp",
"torch/csrc/cuda/shared/nvtx.cpp",
"torch/csrc/cuda/utils.cpp",

View File

@ -9,6 +9,7 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

View File

@ -13,7 +13,17 @@
namespace c10::CachingAllocator {
// "large" allocations may be packed in 20 MiB blocks
const size_t kLargeBuffer = 20971520;
constexpr size_t kLargeBuffer = 20971520;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// all sizes are rounded to at least 512 bytes
constexpr size_t kMinBlockSize = 512;
// largest "small" allocation is 1 MiB
constexpr size_t kSmallSize = 1048576;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
constexpr size_t kRoundLarge = 2097152;
// A utility class for tokenizing allocator configuration strings into discrete
// parts. For example, the config string:

View File

@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
case Backend::PrivateUse1:
return DispatchKey::PrivateUse1;
default:
throw std::runtime_error("Unknown backend");
TORCH_CHECK(false, "Unknown backend");
}
}

View File

@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(

View File

@ -336,7 +336,7 @@ class C10_API Scalar {
} else if (isBoolean()) {
return ScalarType::Bool;
} else {
throw std::runtime_error("Unknown scalar type.");
TORCH_CHECK(false, "Unknown scalar type.");
}
}

View File

@ -228,7 +228,7 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
case c10::ScalarType::Float4_e2m1fn_x2:
return std::make_pair("float4_e2m1fn_x2", "");
default:
throw std::runtime_error("Unimplemented scalar type");
TORCH_CHECK(false, "Unimplemented scalar type");
}
}

View File

@ -137,22 +137,6 @@ inline ScalarType toQIntType(ScalarType t) {
}
}
inline ScalarType toUnderlying(ScalarType t) {
switch (t) {
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
[[fallthrough]];
case ScalarType::QUInt2x4:
return ScalarType::Byte;
case ScalarType::QInt8:
return ScalarType::Char;
case ScalarType::QInt32:
return ScalarType::Int;
default:
return t;
}
}
inline bool isSignedType(ScalarType t) {
#define CASE_ISSIGNED(name) \
case ScalarType::name: \

View File

@ -1,6 +1,7 @@
#pragma once
#include <cstddef>
#include <new>
namespace c10 {
@ -18,4 +19,12 @@ constexpr size_t gPagesize = 4096;
// since the default thp pagesize is 2MB, enable thp only
// for buffers of size 2MB or larger to avoid memory bloating
constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;
// Cache line size used to avoid false sharing between threads. Falls back to 64
// bytes if C++17 feature is unavailable.
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
} // namespace c10

View File

@ -87,9 +87,7 @@ bool ThreadPool::inThreadPool() const {
}
void ThreadPool::run(std::function<void()> func) {
if (threads_.empty()) {
throw std::runtime_error("No threads to run a task");
}
TORCH_CHECK(threads_.size() > 0, "No threads to run a task");
std::unique_lock<std::mutex> lock(mutex_);
// Set task and signal condition variable so that a worker thread will

View File

@ -131,15 +131,6 @@ namespace Native {
* notifyCaptureDestroy.
*/
constexpr size_t kMinBlockSize =
512; // all sizes are rounded to at least 512 bytes
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
constexpr size_t kSmallBuffer =
2097152; // "small" allocations are packed in 2 MiB blocks
constexpr size_t kMinLargeAlloc =
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
static char SHAREABLE_HANDLE_VERSION = 2;
enum ShareableHandleType : char {
SHAREABLE_CUDA_MALLOC = 'c',
@ -941,7 +932,7 @@ class EventPool {
private:
struct PerDevicePool {
alignas(64) std::mutex mutex_;
alignas(hardware_destructive_interference_size) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;
@ -3758,11 +3749,6 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
static constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
class NativeCachingAllocator : public CUDAAllocator {
private:
@ -4483,7 +4469,10 @@ struct BackendStaticInitializer {
if (key == "backend") {
tokenizer.checkToken(++i, ":");
i++; // Move to the value after the colon
if (tokenizer[i] == "cudaMallocAsync"
// break up token to trick hipify
if (tokenizer[i] ==
"c"
"udaMallocAsync"
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
|| tokenizer[i] == "hipMallocAsync"

View File

@ -913,7 +913,9 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
}
}
std::string name() override {
return "cudaMallocAsync";
// break up token to trick hipify
return "c"
"udaMallocAsync";
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
C10_CUDA_CHECK(

View File

@ -51,6 +51,17 @@
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
_(cuCtxFromGreenCtx, 12080) \
_(cuCtxGetCurrent, 12080) \
_(cuCtxPopCurrent, 12080) \
_(cuCtxPushCurrent, 12080) \
_(cuCtxSetCurrent, 12080) \
_(cuGreenCtxCreate, 12080) \
_(cuGreenCtxDestroy, 12080) \
_(cuDevSmResourceSplitByCount, 12080) \
_(cuDeviceGet, 12080) \
_(cuDeviceGetDevResource, 12080) \
_(cuDevResourceGenerateDesc, 12080) \
_(cuMulticastAddDevice, 12030) \
_(cuMulticastBindMem, 12030) \
_(cuMulticastCreate, 12030) \

View File

@ -18,7 +18,6 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <array>
#include <cstddef>
@ -41,106 +40,200 @@ namespace c10 {
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
///
/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct
/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of
/// the underlying constexpr calls, we rely on apparent-type dispatch for
/// inheritance. This should be fine because their memory format is the same,
/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods.
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
class ArrayRef final : public HeaderOnlyArrayRef<T> {
class ArrayRef final {
public:
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
/// SmallVector.
using iterator = const T*;
using const_iterator = const T*;
using size_type = size_t;
using value_type = T;
using reverse_iterator = std::reverse_iterator<iterator>;
private:
/// The start of the array, in an external buffer.
const T* Data;
/// The number of elements.
size_type Length;
void debugCheckNullptrInvariant() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Data != nullptr || Length == 0,
"created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
}
public:
/// @name Constructors
/// @{
using HeaderOnlyArrayRef<T>::HeaderOnlyArrayRef;
/// Construct an empty ArrayRef.
/* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
/// Construct an ArrayRef from a std::vector.
/// This constructor is identical to the one in HeaderOnlyArrayRef, but we
/// include it to help with Class Template Argument Deduction (CTAD).
/// Without it, CTAD can fail sometimes due to the indirect constructor
/// inheritance. So we explicitly include this constructor.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}
/// Construct an ArrayRef from a single element.
// TODO Make this explicit
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
/// Construct an ArrayRef from a pointer and length.
constexpr ArrayRef(const T* data, size_t length)
: Data(data), Length(length) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a range.
constexpr ArrayRef(const T* begin, const T* end)
: Data(begin), Length(end - begin) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a SmallVector. This is templated in order to
/// avoid instantiating SmallVectorTemplateCommon<T> whenever we
/// copy-construct an ArrayRef.
/// NOTE: this is the only constructor that is not inherited from
/// HeaderOnlyArrayRef.
template <typename U>
/* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
: HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}
: Data(Vec.data()), Length(Vec.size()) {
debugCheckNullptrInvariant();
}
template <
typename Container,
typename U = decltype(std::declval<Container>().data()),
typename = std::enable_if_t<
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
/* implicit */ ArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a std::vector.
// The enable_if stuff here makes sure that this isn't used for
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
// bitfield.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static_assert(
!std::is_same_v<T, bool>,
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}
/// Construct an ArrayRef from a std::array
template <size_t N>
/* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}
/// Construct an ArrayRef from a C array.
template <size_t N>
// NOLINTNEXTLINE(*c-arrays*)
/* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
/// Construct an ArrayRef from a std::initializer_list.
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
: Data(
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}
/// @}
/// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef
/// @name Simple Operations
/// @{
constexpr iterator begin() const {
return Data;
}
constexpr iterator end() const {
return Data + Length;
}
// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return Data;
}
constexpr const_iterator cend() const {
return Data + Length;
}
constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
constexpr reverse_iterator rend() const {
return reverse_iterator(begin());
}
/// Check if all elements in the array satisfy the given expression
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
return std::all_of(cbegin(), cend(), pred);
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return Length == 0;
}
constexpr const T* data() const {
return Data;
}
/// size - Get the array size.
constexpr size_t size() const {
return Length;
}
/// front - Get the first element.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& front() const {
TORCH_CHECK(
!this->empty(), "ArrayRef: attempted to access front() of empty list");
return this->Data[0];
!empty(), "ArrayRef: attempted to access front() of empty list");
return Data[0];
}
/// back - Get the last element.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& back() const {
TORCH_CHECK(
!this->empty(), "ArrayRef: attempted to access back() of empty list");
return this->Data[this->Length - 1];
TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
return Data[Length - 1];
}
/// equals - Check for element-wise equality.
constexpr bool equals(ArrayRef RHS) const {
return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
}
/// slice(n, m) - Take M elements of the array starting at element N
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr ArrayRef<T> slice(size_t N, size_t M) const {
TORCH_CHECK(
N + M <= this->size(),
N + M <= size(),
"ArrayRef: invalid slice, N = ",
N,
"; M = ",
M,
"; size = ",
this->size());
return ArrayRef<T>(this->data() + N, M);
size());
return ArrayRef<T>(data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr ArrayRef<T> slice(size_t N) const {
TORCH_CHECK(
N <= this->size(),
"ArrayRef: invalid slice, N = ",
N,
"; size = ",
this->size());
return slice(N, this->size() - N); // should this slice be this->slice?
N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
return slice(N, size() - N);
}
/// @}
/// @name Operator Overloads
/// @{
constexpr const T& operator[](size_t Index) const {
return Data[Index];
}
/// Vector compatibility
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& at(size_t Index) const {
TORCH_CHECK(
Index < this->Length,
Index < Length,
"ArrayRef: invalid index Index = ",
Index,
"; Length = ",
this->Length);
return this->Data[Index];
Length);
return Data[Index];
}
/// Disallow accidental assignment from a temporary.
@ -160,6 +253,13 @@ class ArrayRef final : public HeaderOnlyArrayRef<T> {
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const {
return std::vector<T>(Data, Data + Length);
}
/// @}
};

View File

@ -45,14 +45,7 @@ constexpr bool is_pod_v = is_pod<T>::value;
namespace guts {
#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__)
template <class F, class Tuple>
C10_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
}
#else
#if defined(__HIP__)
// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but
// modified)

View File

@ -14,16 +14,6 @@ using namespace c10::CachingDeviceAllocator;
// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
// all sizes are rounded to at least 512 bytes
constexpr size_t kMinBlockSize = 512;
// largest "small" allocation is 1 MiB
constexpr size_t kSmallSize = 1048576;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
constexpr size_t kRoundLarge = 2097152;
namespace {
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
@ -554,7 +544,7 @@ static void local_raw_delete(void* ptr);
class XPUAllocator : public DeviceAllocator {
private:
std::mutex mutex;
alignas(hardware_destructive_interference_size) std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
void add_allocated_block(Block* block) {

View File

@ -607,6 +607,12 @@ if(USE_CUDA)
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
endif()
endif()
if(NOT WIN32)
set_source_files_properties(
${TORCH_ROOT}/aten/src/ATen/cuda/CUDAGreenContext.cpp
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
endif()
set_source_files_properties(
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"

View File

@ -16,7 +16,7 @@ find_path(vecLib_INCLUDE_DIR vecLib.h
DOC "vecLib include directory"
PATHS /System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
/System/Library/${__veclib_include_suffix}
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
${CMAKE_OSX_SYSROOT}/System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
NO_DEFAULT_PATH)

View File

@ -224,6 +224,12 @@ AMD/ROCm/HIP
- Jithun Nair (`jithunnair-amd <https://github.com/jithunnair-amd>`__)
- (emeritus) Junjie Bai (`bddppq <https://github.com/bddppq>`__)
XPU/Intel GPU
~~~~~~~~~~~~~
- Eikan Wang (`EikanWang <https://github.com/EikanWang>`__)
- Guangye Yu (`guangyey <https://github.com/guangyey>`__)
Build + CI
~~~~~~~~~~

View File

@ -258,6 +258,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
```
## Green Contexts (experimental)
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
to enable more general carveout of SM resources for CUDA kernels.
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
```{eval-rst}
.. currentmodule:: torch.cuda.green_contexts
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
GreenContext
```
% This module needs to be documented. Adding here in the meantime
% for tracking purposes
@ -270,6 +292,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
.. py:module:: torch.cuda.gds
```
```{eval-rst}
.. py:module:: torch.cuda.green_contexts
```
```{eval-rst}
.. py:module:: torch.cuda.jiterator
```

View File

@ -44,9 +44,9 @@ following invariants. More specifications about the IR can be found
- **Normalized**: There are no Python semantics within the graph. Submodules
from the original programs are inlined to form one fully flattened
computational graph.
- **Graph properties**: The graph is purely functional, meaning it does not
contain operations with side effects such as mutations or aliasing. It does
not mutate any intermediate values, parameters, or buffers.
- **Graph properties**: By default, the graph may contain both functional and
non-functional operators (including mutations). To obtain a purely functional
graph, use `run_decompositions()` which removes mutations and aliasing.
- **Metadata**: The graph contains metadata captured during tracing, such as a
stacktrace from user's code.
@ -56,8 +56,8 @@ Under the hood, `torch.export` leverages the following latest technologies:
called the Frame Evaluation API to safely trace PyTorch graphs. This
provides a massively improved graph capturing experience, with much fewer
rewrites needed in order to fully trace the PyTorch code.
- **AOT Autograd** provides a functionalized PyTorch graph and ensures the graph
is decomposed/lowered to the ATen operator set.
- **AOT Autograd** ensures the graph is decomposed/lowered to the ATen operator
set. When using `run_decompositions()`, it can also provide functionalization.
- **Torch FX (torch.fx)** is the underlying representation of the graph,
allowing flexible Python-based transformations.
@ -444,23 +444,31 @@ saved_exported_program = torch.export.load('exported_program.pt2')
(training-export)=
## Export IR, Decompositions
## Export IR: Training vs Inference
The graph produced by `torch.export` returns a graph containing only
[ATen operators](https://pytorch.org/cppdocs/#aten), which are the basic unit of
computation in PyTorch. As there are over
3000 ATen operators, export provides a way to narrow down the operator set used
in the graph based on certain characteristics, creating different IRs.
computation in PyTorch. Export provides different IR levels based on your use case:
By default, export produces the most generic IR which contains all ATen
operators, including both functional and non-functional operators. A functional
operator is one that does not contain any mutations or aliasing of the inputs.
| IR Type | How to Obtain | Properties | Operator Count | Use Case |
|---------|---------------|------------|----------------|----------|
| Training IR | `torch.export.export()` (default) | May contain mutations | ~3000 | Training with autograd |
| Inference IR | `ep.run_decompositions(decomp_table={})` | Purely functional | ~2000 | Inference deployment |
| Core ATen IR | `ep.run_decompositions(decomp_table=None)` | Purely functional, highly decomposed | ~180 | Minimal backend support |
### Training IR (Default)
By default, export produces a **Training IR** which contains all ATen
operators, including both functional and non-functional (mutating) operators.
A functional operator is one that does not contain any mutations or aliasing
of the inputs, while non-functional operators may modify their inputs in-place.
You can find a list of all ATen operators
[here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml)
and you can inspect if an operator is functional by checking
`op._schema.is_mutable`.
This generic IR can be used to train in eager PyTorch Autograd.
This Training IR, which may contain mutations, is designed for training use
cases and can be used with eager PyTorch Autograd.
```{code-cell}
import torch
@ -480,15 +488,18 @@ ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
```
However, if you want to use the IR for inference, or decrease the amount of
operators being used, you can lower the graph through the
{func}`ExportedProgram.run_decompositions` API. This method decomposes the
ATen operators into the ones specified in the decomposition table, and
functionalizes the graph.
### Inference IR (via run_decompositions)
By specifying an empty set, we're only performing functionalization, and does
not do any additional decompositions. This results in an IR which contains ~2000
operators (instead of the 3000 operators above), and is ideal for inference cases.
To obtain an **Inference IR** suitable for deployment, use the
{func}`ExportedProgram.run_decompositions` API. This method automatically:
1. Functionalizes the graph (removes all mutations and converts them to functional equivalents)
2. Optionally decomposes ATen operators based on the provided decomposition table
This produces a purely functional graph ideal for inference scenarios.
By specifying an empty decomposition table (`decomp_table={}`), you get just
the functionalization without additional decompositions. This produces an
Inference IR with ~2000 functional operators (compared to 3000+ in Training IR).
```{code-cell}
import torch
@ -514,11 +525,14 @@ As we can see, the previously in-place operator,
`torch.ops.aten.add_.default` has now been replaced with
`torch.ops.aten.add.default`, a functional operator.
We can also further lower this exported program to an operator set which only
contains the
### Core ATen IR
We can further lower the Inference IR to the
`Core ATen Operator Set <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__,
which is a collection of only ~180 operators. This IR is optimal for backends
who do not want to reimplement all ATen operators.
which contains only ~180 operators. This is achieved by passing `decomp_table=None`
(which uses the default decomposition table) to `run_decompositions()`. This IR
is optimal for backends who want to minimize the number of operators they need
to implement.
```{code-cell}
import torch

View File

@ -208,11 +208,13 @@ select = [
"PLC1802", # len({expression}) used as condition without comparison
"PLC0205", # string as __slots__
"PLC3002", # unnecessary-direct-lambda-call
"PLC0414", # Import alias does not rename original package
"PLE",
"PLR0133", # constant comparison
"PLR0206", # property with params
"PLR1722", # use sys exit
"PLR1736", # unnecessary list index
"PLW0127", # Self-assignment of variable
"PLW0129", # assert on string literal
"PLW0131", # named expr without context
"PLW0133", # useless exception statement

View File

@ -23,10 +23,12 @@ project-includes = [
project-excludes = [
# ==== below will be enabled directory by directory ====
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/runtime",
"torch/_inductor/codegen/triton.py",
"tools/linter/adapters/test_device_bias_linter.py",
"tools/code_analyzer/gen_operators_yaml.py",
"torch/_inductor/runtime/triton_heuristics.py",
"torch/_inductor/runtime/triton_helpers.py",
"torch/_inductor/runtime/halide_helpers.py",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"tools/flight_recorder/components/types.py",

View File

@ -7,7 +7,6 @@ set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp

View File

@ -1,52 +0,0 @@
#include <gtest/gtest.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <vector>
using torch::headeronly::HeaderOnlyArrayRef;
TEST(TestHeaderOnlyArrayRef, TestEmpty) {
HeaderOnlyArrayRef<float> arr;
ASSERT_TRUE(arr.empty());
}
TEST(TestHeaderOnlyArrayRef, TestSingleton) {
float val = 5.0f;
HeaderOnlyArrayRef<float> arr(val);
ASSERT_FALSE(arr.empty());
EXPECT_EQ(arr.size(), 1);
EXPECT_EQ(arr[0], val);
}
TEST(TestHeaderOnlyArrayRef, TestAPIs) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr(vec);
ASSERT_FALSE(arr.empty());
EXPECT_EQ(arr.size(), 7);
for (size_t i = 0; i < arr.size(); i++) {
EXPECT_EQ(arr[i], i + 1);
EXPECT_EQ(arr.at(i), i + 1);
}
EXPECT_EQ(arr.front(), 1);
EXPECT_EQ(arr.back(), 7);
ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3)));
}
TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr({1, 2, 3, 4, 5, 6, 7});
auto res_vec = arr.vec();
for (size_t i = 0; i < vec.size(); i++) {
EXPECT_EQ(vec[i], res_vec[i]);
}
}
TEST(TestHeaderOnlyArrayRef, TestFromRange) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr(vec.data() + 3, vec.data() + 7);
auto res_vec = arr.vec();
for (size_t i = 0; i < res_vec.size(); i++) {
EXPECT_EQ(vec[i + 3], res_vec[i]);
}
}

View File

@ -74,3 +74,19 @@ TEST(TestScalarType, operator_left_shift) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
TEST(TestScalarType, toUnderlying) {
using torch::headeronly::ScalarType;
using torch::headeronly::toUnderlying;
EXPECT_EQ(toUnderlying(ScalarType::QUInt8), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QUInt4x2), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QUInt2x4), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QInt8), ScalarType::Char);
EXPECT_EQ(toUnderlying(ScalarType::QInt32), ScalarType::Int);
#define DEFINE_CHECK(_, name) \
EXPECT_EQ(toUnderlying(ScalarType::name), ScalarType::name);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
#undef DEFINE_CHECK
}

View File

@ -311,9 +311,10 @@ void boxed_fill_infinity(
}
Tensor my_pad(Tensor t) {
std::vector<int64_t> padding = {1, 2, 2, 1};
std::string mode = "constant";
double value = 0.0;
return pad(t, {1, 2, 2, 1}, mode, value);
return pad(t, padding, mode, value);
}
void boxed_my_pad(
@ -341,9 +342,6 @@ void boxed_my_narrow(
}
Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);
@ -355,8 +353,9 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
}
Tensor my_new_zeros_dtype_variant(Tensor t) {
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(at::ScalarType::Float);
return new_zeros(t, {2, 5}, dtype);
return new_zeros(t, sizes, dtype);
}
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@ -430,7 +429,8 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
}
Tensor my_amax_vec(Tensor t) {
return amax(t, {0,1}, false);
std::vector<int64_t> v = {0,1};
return amax(t, v, false);
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {

View File

@ -1166,7 +1166,7 @@ class TestFullyShardPrefetch(FSDPTest):
loss = model(inp)
events.clear()
loss.sum().backward()
expected_backward_events = expected_backward_events = [
expected_backward_events = [
("unshard", "norm, output", TrainingState.PRE_BACKWARD),
# root explicit prefetch layers.2
("unshard", "layers.2", TrainingState.PRE_BACKWARD),

View File

@ -67,7 +67,21 @@ class TestFullyShardMemory(FSDPTest):
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ between hardwares
lin = torch.nn.Linear(768, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
# the input shape was (1, 768), so that the forward gemm used
# cublaslt, and the backward used cublas.
# With the aforementioned PR, and with shape (1, 768),
# the cublas path is used both in forward and in backward,
# altering peak memory usage not accounting for cublaslt.
# Here we change the input shape to (2, 768), and that swaps
# the cublas/cublaslt selection in the forward/backward,
# but that does not affect the peak memory usage stored in `base_mem_mb`.
# Reasons for the flip:
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
# after PR: no Lt in addmm when either mat1 or mat2 have nrows/ncols <= 1,
# since the input preparation can swap matrices based on output
# row-/col-majorness.
inp = torch.randn(2, 768, device=device_type)
lin(inp).sum().backward()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()

View File

@ -127,8 +127,9 @@ def echo1(msg: str, exitcode: int = 0) -> str:
print(f"exit {exitcode} from {rank}", file=sys.stderr)
sys.exit(exitcode)
else:
print(f"{msg} stdout from {rank}")
print(f"{msg} stderr from {rank}", file=sys.stderr)
for m in msg.split(","):
print(f"{m} stdout from {rank}")
print(f"{m} stderr from {rank}", file=sys.stderr)
return f"{msg}_{rank}"
@ -247,6 +248,13 @@ class _StartProcessesTest(TestCase):
for line in expected:
self.assertIn(line, actual)
def assert_not_in_file(self, lines: list[str], filename: str) -> None:
lines = [f"{line.rstrip()}\n" for line in lines]
with open(filename) as fp:
actual = fp.readlines()
for line in lines:
self.assertNotIn(line, actual)
def assert_pids_noexist(self, pids: dict[int, int]):
for local_rank, pid in pids.items():
with self.assertRaises(
@ -360,8 +368,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
self.assertIsNotNone(pc.wait(period=0.1))
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
def test_pcontext_wait_on_a_child_thread(self):
asyncio.run(asyncio.to_thread(self.test_pcontext_wait))
@ -379,8 +387,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
pids = pc.pids()
pc.close()
self.assert_pids_noexist(pids)
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
def test_function_with_tensor(self):
for start_method in self._start_methods:
@ -482,8 +490,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
int(error_file_data["message"]["extraInfo"]["timestamp"]),
int(failure.timestamp),
)
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
def test_wait_for_all_child_procs_to_exit(self):
"""
@ -580,8 +588,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assert_in_file([], results.stdouts[0])
self.assertFalse(results.stderrs[1])
self.assertFalse(results.stdouts[1])
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
failure = results.failures[1]
self.assertEqual(-15, failure.exitcode)
@ -731,8 +739,37 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
self.assertFalse(pc.stdouts[1])
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
def test_binary_duplicate_log_filters(self):
pc = start_processes(
name="trainer",
entrypoint=bin("echo1.py"),
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
redirects={0: Std.ERR, 1: Std.NONE},
tee={0: Std.OUT, 1: Std.ERR},
),
log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
duplicate_stdout_filters=["helloA"],
duplicate_stderr_filters=["worldA", "B"],
start_method="spawn",
)
result = pc.wait()
self.assertFalse(result.is_failed())
self.assert_in_file(["[rank0]:helloA stdout from 0"], pc.filtered_stdout)
self.assert_not_in_file(
["[rank0]:helloB stdout from 0"], pc.filtered_stdout
)
self.assert_in_file(["[rank1]:worldA stderr from 1"], pc.filtered_stderr)
self.assert_in_file(["[rank1]:worldB stderr from 1"], pc.filtered_stderr)
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
@ -794,8 +831,44 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
self.assertFalse(pc.stdouts[1])
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
def test_function_duplicate_log_filters(self):
for start_method in self._start_methods:
with self.subTest(start_method=start_method):
pc = start_processes(
name="trainer",
entrypoint=echo1,
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
redirects={0: Std.ERR, 1: Std.NONE},
tee={0: Std.OUT, 1: Std.ERR},
),
duplicate_stdout_filters=["helloA"],
duplicate_stderr_filters=["worldA", "B"],
start_method="spawn",
)
result = pc.wait()
self.assertFalse(result.is_failed())
self.assert_in_file(
["[trainer0]:helloA stdout from 0"], pc.filtered_stdout
)
self.assert_not_in_file(
["[trainer0]:helloB stdout from 0"], pc.filtered_stdout
)
self.assert_in_file(
["[trainer1]:worldA stderr from 1"], pc.filtered_stderr
)
self.assert_in_file(
["[trainer1]:worldB stderr from 1"], pc.filtered_stderr
)
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
def test_function(self):
for start_method, redirs in product(self._start_methods, redirects_all()):
@ -880,8 +953,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assertFalse(results.stdouts[0])
self.assertFalse(results.stderrs[1])
self.assertFalse(results.stdouts[1])
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
def test_no_zombie_process_function(self):
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]

View File

@ -23,5 +23,6 @@ if __name__ == "__main__":
print(f"exit {exitcode} from {rank}", file=sys.stderr)
sys.exit(exitcode)
else:
print(f"{args.msg} stdout from {rank}")
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
for msg in args.msg.split(","):
print(f"{msg} stdout from {rank}")
print(f"{msg} stderr from {rank}", file=sys.stderr)

View File

@ -84,6 +84,53 @@ class TailLogTest(unittest.TestCase):
)
self.assertTrue(tail.stopped())
def test_tail_write_to_dst_file(self):
"""
writer() writes 0 - max (on number on each line) to a log file.
Run nprocs such writers and tail the log files into a temp file
and validate that all lines are accounted for.
"""
nprocs = 32
max = 1000
interval_sec = 0.0001
log_files = {
local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
for local_rank in range(nprocs)
}
dst = os.path.join(self.test_dir, "tailed_stdout.log")
tail = TailLog(
name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec
).start()
# sleep here is intentional to ensure that the log tail
# can gracefully handle and wait for non-existent log files
time.sleep(interval_sec * 10)
futs = []
for local_rank, file in log_files.items():
f = self.threadpool.submit(
write, max=max, sleep=interval_sec * local_rank, file=file
)
futs.append(f)
wait(futs, return_when=ALL_COMPLETED)
self.assertFalse(tail.stopped())
tail.stop()
actual: dict[int, set[int]] = {}
with open(dst) as dst_file:
for line in dst_file:
header, num = line.split(":")
nums = actual.setdefault(header, set())
nums.add(int(num))
self.assertEqual(nprocs, len(actual))
self.assertEqual(
{f"[writer{i}]": set(range(max)) for i in range(nprocs)}, actual
)
self.assertTrue(tail.stopped())
def test_tail_with_custom_prefix(self):
"""
writer() writes 0 - max (on number on each line) to a log file.
@ -131,6 +178,52 @@ class TailLogTest(unittest.TestCase):
self.assertIn(f"[worker{i}][{i}]", headers)
self.assertTrue(tail.stopped())
def test_tail_with_custom_filter(self):
"""
writer() writes 0 - max (on number on each line) to a log file.
Run nprocs such writers and tail the log files into an IOString
and validate that all lines are accounted for.
"""
nprocs = 3
max = 20
interval_sec = 0.0001
log_files = {
local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
for local_rank in range(nprocs)
}
dst = io.StringIO()
tail = TailLog(
"writer",
log_files,
dst,
interval_sec=interval_sec,
log_line_filter=lambda line: "2" in line, # only print lines containing '2'
).start()
# sleep here is intentional to ensure that the log tail
# can gracefully handle and wait for non-existent log files
time.sleep(interval_sec * 10)
futs = []
for local_rank, file in log_files.items():
f = self.threadpool.submit(
write, max=max, sleep=interval_sec * local_rank, file=file
)
futs.append(f)
wait(futs, return_when=ALL_COMPLETED)
self.assertFalse(tail.stopped())
tail.stop()
dst.seek(0)
actual: dict[int, set[int]] = {}
for line in dst.readlines():
header, num = line.split(":")
nums = actual.setdefault(header, set())
nums.add(int(num))
self.assertEqual(nprocs, len(actual))
self.assertEqual({f"[writer{i}]": {2, 12} for i in range(nprocs)}, actual)
self.assertTrue(tail.stopped())
def test_tail_no_files(self):
"""
Ensures that the log tail can gracefully handle no log files

View File

@ -55,9 +55,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Mock the stdout_tail and stderr_tail
mock_stdout_tail = MagicMock()
mock_stderr_tail = MagicMock()
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Remove environment variable if it exists to test default behavior
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
@ -84,8 +85,8 @@ class SignalHandlingTest(TestCase):
# Verify _start was called
mock_pcontext._start.assert_called_once()
# Verify _stdout_tail.start() and _stderr_tail.start() were called
mock_pcontext._stdout_tail.start.assert_called_once()
mock_pcontext._stderr_tail.start.assert_called_once()
mock_stdout_tail.start.assert_called_once()
mock_stderr_tail.start.assert_called_once()
@patch("torch.distributed.elastic.multiprocessing.api.threading")
@patch("torch.distributed.elastic.multiprocessing.api.signal")
@ -99,9 +100,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Mock the stdout_tail and stderr_tail
mock_stdout_tail = MagicMock()
mock_stderr_tail = MagicMock()
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Set custom signals in the environment variable
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGUSR1,SIGUSR2"
@ -139,9 +141,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Mock the stdout_tail and stderr_tail
mock_stdout_tail = MagicMock()
mock_stderr_tail = MagicMock()
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Set invalid signals in the environment variable
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,INVALID_SIGNAL"
@ -180,9 +183,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Mock the stdout_tail and stderr_tail
mock_stdout_tail = MagicMock()
mock_stderr_tail = MagicMock()
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Set signals including ones not supported on Windows
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGHUP,SIGUSR1"
@ -234,9 +238,10 @@ class SignalHandlingTest(TestCase):
mock_threading.current_thread.return_value = MagicMock() # Not the main thread
mock_threading.main_thread.return_value = MagicMock()
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Mock the stdout_tail and stderr_tail
mock_stdout_tail = MagicMock()
mock_stderr_tail = MagicMock()
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Call the start method
PContext.start(mock_pcontext)
@ -262,9 +267,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Mock the stdout_tail and stderr_tail
mock_stdout_tail = MagicMock()
mock_stderr_tail = MagicMock()
mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Set environment variable to include SIGUSR1 and SIGUSR2
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGUSR1,SIGUSR2"
@ -323,8 +329,8 @@ class SignalHandlingTest(TestCase):
# Verify _start was called
mock_pcontext._start.assert_called_once()
# Verify _stdout_tail.start() and _stderr_tail.start() were called
mock_pcontext._stdout_tail.start.assert_called_once()
mock_pcontext._stderr_tail.start.assert_called_once()
mock_stdout_tail.start.assert_called_once()
mock_stderr_tail.start.assert_called_once()
if __name__ == "__main__":

View File

@ -337,6 +337,70 @@ class ScheduleTest(MultiProcContinuousTest):
if self.rank == self.world_size - 1:
self.assertTrue(len(losses) > 0, "Losses should be computed during eval()")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize(
"ScheduleClass",
[
ScheduleGPipe,
Schedule1F1B,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ScheduleInterleavedZeroBubble,
],
)
def test_return_output(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ScheduleInterleavedZeroBubble,
]:
# Multi-stage schedules
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
mod, _, x, target, loss_fn = setup_models_and_data(
self.config, n_layers=n_stages
)
# Create multi-stage pipeline
stages, stage_modules, _ = create_multi_stage_pipeline(
self.config, mod, stages_per_rank, n_stages
)
schedule = ScheduleClass(
stages,
num_microbatches,
loss_fn=loss_fn,
scale_grads=False,
)
else:
# Single-stage schedules
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
# Create single-stage pipeline
stage, stage_module, _ = create_single_stage_pipeline(
self.config, mod, x, num_microbatches
)
schedule = ScheduleClass(
stage,
num_microbatches,
loss_fn=loss_fn,
scale_grads=False,
)
losses = []
if self.rank == self.world_size - 1:
output = schedule.step(target=target, losses=losses, return_outputs=False)
else:
schedule.step(x)
# Verify that output is None
if self.rank == self.world_size - 1:
self.assertTrue(output is None, "Output should be None")
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"

View File

@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import (
TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._debug_mode import DebugMode
from torch.utils._debug_mode import _OpCall, _RedistributeCall, DebugMode
from torch.utils._python_dispatch import TorchDispatchMode
@ -60,6 +60,10 @@ class TestDTensorDebugMode(TestCase):
aten::sum(t: f32[1, 32])""",
)
self.assertTrue(isinstance(debug_mode.operators[0], _OpCall))
self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall))
self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default)
def test_debug_string_inside_context(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

View File

@ -6,7 +6,10 @@ import unittest
import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import (
_dynamo_graph_capture_for_export,
dynamo_graph_capture_for_export,
)
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._guards import tracing, TracingContext
@ -96,6 +99,13 @@ def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
gm = dynamo_graph_capture_for_export(model)(inputs)
fake_mode = gm.meta.get("fake_mode", None)
with tracing(TracingContext(fake_mode)):
return aot_export_joint_with_descriptors_alone(gm, inputs)
def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
@ -288,6 +298,7 @@ class DTensorExportTest(TestCase):
@parametrize(
"export_fn",
[
graph_capture_and_aot_export_joint_with_descriptors_v2,
graph_capture_and_aot_export_joint_with_descriptors,
aot_export_joint_with_descriptors_alone,
],
@ -307,7 +318,21 @@ class DTensorExportTest(TestCase):
def test_annotate_aot_export_joint_with_descriptors_alone(self):
self._run_test(aot_export_joint_with_descriptors_alone, True)
def test_dynamic_shapes(self):
@parametrize(
"export_fn_with_answer",
[
(
graph_capture_and_aot_export_joint_with_descriptors_v2,
"[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
),
(
graph_capture_and_aot_export_joint_with_descriptors,
"[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
),
],
)
def test_dynamic_shapes(self, export_fn_with_answer):
export_fn, answer = export_fn_with_answer
dp_degree = 2
tp_degree = self.world_size // dp_degree
@ -331,7 +356,7 @@ class DTensorExportTest(TestCase):
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
joint_gm = graph_capture_and_aot_export_joint_with_descriptors(tp_model, inputs)
joint_gm = export_fn(tp_model, inputs)
res = []
for node in joint_gm.graph.nodes:
@ -341,12 +366,16 @@ class DTensorExportTest(TestCase):
if isinstance(fake_val, torch._subclasses.fake_tensor.FakeTensor):
res.append(list(fake_val.shape))
self.assertExpectedInline(
str(res),
"""[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
)
self.assertEqual(str(res), answer)
def test_einsum_dtensor_export(self):
@parametrize(
"export_fn",
[
dynamo_graph_capture_for_export,
_dynamo_graph_capture_for_export,
],
)
def test_einsum_dtensor_export(self, export_fn):
"""Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
world_size = 4
# Create device mesh
@ -366,9 +395,7 @@ class DTensorExportTest(TestCase):
output = model(x_dtensor, y_dtensor, z_dtensor)
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(
x_dtensor, y_dtensor, z_dtensor
)
gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
self.assertEqual(output, output_gm)

View File

@ -55,7 +55,7 @@ if TEST_WITH_DEV_DBG_ASAN:
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
load_tests = load_tests # noqa: PLW0127
if platform == "darwin":
LOOPBACK = "lo0"

View File

@ -1459,7 +1459,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
@requires_gloo()
def test_reduce_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
pg = pg = self._create_process_group_gloo(
pg = self._create_process_group_gloo(
store, self.rank, self.world_size, self.opts()
)

Some files were not shown because too many files have changed in this diff Show More