Compare commits

...

196 Commits

Author SHA1 Message Date
5b6cc8215f Change python doc push script to print the undocumented modules 2025-10-21 12:30:49 -07:00
1c43c9cfd0 Update 2025-10-21 12:30:49 -07:00
102e0d5437 Test 2025-10-21 12:30:49 -07:00
0bd12c1168 [CI] Extend test_transfomers to MPS (#165960)
Just skip grad_checks as they need float64
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165960
Approved by: https://github.com/Skylion007
2025-10-21 19:27:44 +00:00
ce8a7764e2 Revert "[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)"
This reverts commit 1290b077f26543a34262587137ef64ca9ca5e17d.

Reverted https://github.com/pytorch/pytorch/pull/165707 on behalf of https://github.com/clee2000 due to failing internal tests D85160820 ([comment](https://github.com/pytorch/pytorch/pull/165707#issuecomment-3429084393))
2025-10-21 19:25:03 +00:00
d1269a0434 update fr trace analysis (#165994)
Summary:
- allow empty entries from ranks
- allow not all ranks to provide dump

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/165994).
* #165638
* #165640
* #165642
* __->__ #165994
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165994
Approved by: https://github.com/fduwjj
2025-10-21 19:14:33 +00:00
c87cf1be32 Update workaround to old CUDA bug (#164354) (#165984)
The workaround cannot be removed because of BC. Here we'll
update PyTorch code base to not use the workaround.

See https://github.com/pytorch/pytorch/pull/164354 for the BC breakage issue.

Resolves https://github.com/pytorch/pytorch/issues/164348.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165984
Approved by: https://github.com/janeyx99
2025-10-21 19:09:43 +00:00
2fc5e45a41 better error message when there is no pytree impl (#165955)
Differential Revision: [D85117597](https://our.internmc.facebook.com/intern/diff/D85117597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165955
Approved by: https://github.com/avikchaudhuri
2025-10-21 18:49:22 +00:00
f9022ba93b [PyTorch] Add user_metadata display to memory visualizer (#165939)
Summary: Enhanced the PyTorch CUDA memory visualizer to display user_metadata alongside stack frames when inspecting allocations. The user_metadata field is now shown in all views (Allocator State History, Active Memory Timeline, etc.) with consistent formatting. The implementation handles both string and object metadata types, displaying strings directly and objects as key-value pairs.

Test Plan:
1. Generate a memory snapshot with user_metadata
2. Open the memory visualizer in a browser
3. Load the snapshot file
4. Verify user_metadata appears
5. Test with both string metadata ("testing") and object metadata ({"key": "value"})
6. Verify formatting shows "User Metadata:\n  <value>" for strings

 {F1982860439}

Differential Revision: D85095152

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165939
Approved by: https://github.com/yushangdi
2025-10-21 18:48:33 +00:00
ff8be889ad Remove unused exception parameter from some files, to work with -Wunused-exception-parameter (#165770)
Summary: address compiler complains that were coming up to unblock the build

Test Plan:
before the change
```
aten/src/ATen/native/LinearAlgebra.cpp:3623:36: error: unused exception parameter 'e' [-Werror,-Wunused-exception-parameter]
 3623 |     } catch (const std::exception& e) {
      |
```

after: targets build with `-Wunused-exception-parameter`

Differential Revision: D84876246

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165770
Approved by: https://github.com/Skylion007, https://github.com/cyyever

Co-authored-by: Tony Targonski <tony.targonski@meta.com>
2025-10-21 18:30:29 +00:00
292454942e [CD] Introduce windows.12xlarge runners for CD Windows build (#165287)
Follows https://github.com/pytorch/test-infra/pull/7174. Windows CD build time cost comparison as below

|Runner|cpu|cuda|xpu|
|-|-|-|-|
|windows.4xlarge|1.5h| 4.0h| 5.5h|
|windows.12xlarge|0.5h|1.5h|2.5h|

Fixes #162962
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165287
Approved by: https://github.com/zxiiro, https://github.com/malfet, https://github.com/seemethere
2025-10-21 18:28:23 +00:00
6c4412f72b Revert "[Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)"
This reverts commit e9d89734274a4a2640fa77b898c800a87d1d874e.

Reverted https://github.com/pytorch/pytorch/pull/163316 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 17:44:42 +00:00
78bf6186f2 Revert "[Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)"
This reverts commit e8cb34dd52c063a130f3e659576c313bbe4b4981.

Reverted https://github.com/pytorch/pytorch/pull/163324 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 17:44:42 +00:00
c40048472c Remove AOTI cross compilation time from internal CI (#165935)
Summary: as title

Test Plan: CI

Differential Revision: D85088451

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165935
Approved by: https://github.com/desertfire
2025-10-21 16:58:28 +00:00
3dfd0c7584 Improve PATH hints in FindvecLib.cmake (#165881)
Change  /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk to /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk in `cmake/Modules/FindvecLib.cmake` which is more general (and MacOSX10.9 is not supported now). Otherwise, vecLib can't be found on MacOS 26.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165881
Approved by: https://github.com/ezyang
2025-10-21 16:44:12 +00:00
e6ba4d0725 Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)" (#165910)
Summary:
Original commit changeset: d6d62d0c96dd

Original Phabricator Diff: D84468451 and D84613184

D84468451 caused CUDA OutOfMemoryError in model.

Test Plan:
D84468451 was found through bisect.  Also double checked on recent trunk 9866939225248c2adc307be7a804b26db0b9b555: f815887517

With this diff that backs out D84468451 and D84613184 : f816114560

Differential Revision: D85025378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165910
Approved by: https://github.com/clee2000
2025-10-21 16:36:38 +00:00
bdf7cb9d9c Revert "[torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)"
This reverts commit e20c9bf2889b9252ac45ae6af35c93c795eab701.

Reverted https://github.com/pytorch/pytorch/pull/165410 on behalf of https://github.com/clee2000 due to sorry I'm going to revert this since I want to try to back out some other things that are conflicting with this, there is nothing wrong with this PR, rebasing and resolving the merge conflicts should be enough, sorry for the churn ([comment](https://github.com/pytorch/pytorch/pull/165410#issuecomment-3427532373))
2025-10-21 16:27:54 +00:00
6aed378958 [export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)
fx.Interpreter doesn't handle kwargs... not sure how this code worked previously

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165334
Approved by: https://github.com/tugsbayasgalan, https://github.com/ezyang
2025-10-21 15:53:05 +00:00
8b3dc0d1b0 Better error handling in torch/csrc/jit/runtime/* (#165118)
Refactor error handling by using TORCH_CHECK for improved clarity in constants and scope management in some files in torch/csrc/jit/runtime/*

Fixes some parts of ISSUE https://github.com/pytorch/pytorch/issues/148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165118
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 15:22:49 +00:00
06773663b5 Implement an AOT precompile mode for standalone_compile (#165843)
This PR introduces an `aot` flag to standalone_compile that uses BundledAOTAutogradCacheEntry, and then allows regional_inductor to use this so that we can start aot compiling regional compiler graphs. The diff above this will attempt to allow GraphPickler to fully serialize graphs that have regionally compiled subgraphs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843
Approved by: https://github.com/oulgen
2025-10-21 15:02:45 +00:00
0bff65503c Move hardware_destructive_interference_size to c10/core/alignment.h (#160067)
# Motivation
Move `hardware_destructive_interference_size` to `c10/core/alignment.h`, which gives a chance to reuse it across different accelerators.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160067
Approved by: https://github.com/Skylion007, https://github.com/EikanWang
2025-10-21 14:39:46 +00:00
21131a2444 Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)"
This reverts commit ffa90d46e61650834d5f926008f48f50c6a7e87a.

Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/jeffdaily due to timeouts after merge ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3426898171))
2025-10-21 14:15:55 +00:00
1009790ad8 [pytree][dynamo] trace on native optree functions for community pytree support (#165860)
Resolves #164972

- #164972

All `torch.utils._cxx_pytree` functions are based on `optree` functions with hardcoded `none_is_leaf=True` and `namespace="torch"`. This PR changes the polyfills to generic `optree` functions with those arguments unhardcoded. This means `torch.utils._cxx_pytree` functions are still traceable while the community `optree` usages can get dynamo support additionally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860
Approved by: https://github.com/Lucaskabela
2025-10-21 14:13:08 +00:00
410e6a4321 Better error handling in torch/csrc/jit/frontend/* (#165213)
Refactor error handling by using TORCH_CHECK for improved clarity in constants and scope management in some files in torch/csrc/jit/frontend/*

Fixes some parts of ISSUE https://github.com/pytorch/pytorch/issues/148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165213
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 13:54:59 +00:00
23c55c5b66 [Code Clean]Replace assert statements with explicit if/raise patterns (#165735)
Fix part of #164878

Replace 75 assert statements with explicit if/raise patterns in `torch/ao/ns` , include:

- `torch/ao/ns/_numeric_suite_fx.py`  - 5 asserts

- `torch/ao/ns/fx/graph_matcher.py` - 6 asserts

- `torch/ao/ns/fx/graph_passes.py` -12 asserts

- `torch/ao/ns/fx/n_shadows_utils.py` - 20 asserts

- `torch/ao/ns/fx/pattern_utils.py` - 2 asserts

- `torch/ao/ns/fx/utils.py` - 21 asserts

- `torch/ao/ns/fx/weight_utils.py` - 19 asserts

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165735
Approved by: https://github.com/albanD
2025-10-21 11:21:57 +00:00
1290b077f2 [dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)
Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165707
Approved by: https://github.com/Lucaskabela
2025-10-21 09:27:41 +00:00
9f9ab881b2 [ROCm][inductor] heuristic improvements for reduction kernels (#161280)
Improvements to reduction kernel heuristics for MI350.

Contributions from several members of the AMD Inductor and Triton teams: @jataylo @iupaikov-amd @AmdSampsa @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161280
Approved by: https://github.com/jansel, https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jeffdaily
2025-10-21 07:48:54 +00:00
f2bb22ff84 [Inductor-FX] Support Tensor.item (#165599)
# Feature
This PR supports compiling `Tensor.item` with Inductor's FX backend. This maps to a custom WrapperCodeGen method called `codegen_dynamic_scalar`.

# Implementation
The implementation is fairly mechanical, following the usual flow for these types of PRs.
1. Introduce a new Wrapper IR line for this, called `DynamicScalarLine`.
2. Split `PythonWrapperCodegen.codegen_dynamic_scalar` into 2 parts: a public method which generates the Wrapper IR line, and a private one generating Python from Wrapper IR.
3. Implement an FX codegen method for the wrapper IR line. This one calls `aten.where.Scalar` to handle code like `1 if x.item() else 0`, which is a bit tricky. It also calls `aten.item.default` to convert tensors to scalars.

# Test plan
Added CI tests mirroring the AOTI ones. They test float, int and bool types, the latter taking a distinct codegen path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165599
Approved by: https://github.com/angelayi, https://github.com/jansel
2025-10-21 07:09:56 +00:00
03f3f7899c [ATen] Add reduction tag to reduction operators (#165155)
Add a new 'reduction' tag to tags.yaml and apply it to 98 reduction
operator variants across 21 operator families (sum, mean, min, max,
argmin, argmax, amin, amax, aminmax, prod, all, any, norm, var, std,
std_mean, var_mean, nansum, logsumexp, count_nonzero, linalg_vector_norm).

This tag categorizes operators that perform reduction operations,
computing aggregate values across one or more dimensions of input
tensor(s).

Based on PR #153342 - co-written with @AlonSardas.

Just as we have pointwise tag - this can be useful for compiler passes, or for opting into sharding rules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165155
Approved by: https://github.com/ezyang, https://github.com/zou3519, https://github.com/mlazos
2025-10-21 04:35:03 +00:00
771170807b [dynamo][nn_module] Replace UserFunctionVariable with VariableTracker build (#165708)
Audit: To prevent future issues with functools.partial or callable objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165708
Approved by: https://github.com/Lucaskabela
2025-10-21 04:13:12 +00:00
ffa90d46e6 [ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)
* Moving rocm.yml from using persistent non-ARC runners from the combined MI2xx (MI210 + MI250) cluster to the ARC runners from the MI250 cluster. This halves the number of nodes, but provides access to approximately 4 times the runners, since every 8-GPU MI250 node now provides 8 1-GPU runners. This should help with concurrent capacity and queueing on the MI2xx jobs.

Tested here successfully: https://github.com/pytorch/pytorch/actions/runs/18620814622/job/53092469720

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165481
Approved by: https://github.com/jeffdaily

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2025-10-21 04:02:04 +00:00
0e083942cc Enable PLW0127 in ruff (#165851)
This PR enables `PLW0127` in ruff, which checks self-assignment of variables with the form `var=var`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165851
Approved by: https://github.com/Lucaskabela
2025-10-21 03:30:57 +00:00
ce1fcff03e [ROCm] Keep amdgpu-coerce-illegal-types flag if rocm version is less than 7.2 (#165789)
The `-amdgpu-coerce-illegal-types=1` flag is for LLVM that is in ROCm 6.3, 6.4, 7.0, and 7.1. It will not be in ROCm7.2. It was added to enable performance improvements for composable kernel. ROCm7.2 and newer changed the compiler so that the flag isn't needed to achieve those performance improvements. Keeping the flag with ROCm 7.2 breaks the PyTorch build.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165789
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
2025-10-21 03:17:33 +00:00
a238a9a100 Add clang-tidy misc-definitions-in-headers check (#164959)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164959
Approved by: https://github.com/Skylion007, https://github.com/mikaylagawarecki
ghstack dependencies: #164882, #164956
2025-10-21 02:59:46 +00:00
fe69a2bbbd Move from/to to torch::stable::detail (#164956)
To not pollute the global namespace, we should move the `from`/`to` APIs into torch::stable::detail. We are also following our normal deprecation cycle and choosing to continue exposing the global `from`/`to` for the time being as people who onboard their extensions onto 2.9 would not be able to build with 2.10 otherwise.

Note that this means that within libtorch, we do not get the luxury of tacking on a `using torch::stable::detail::from` because then it leads to build time ambiguous calls --> both the global and namespace APIs are exposed, which one do I want? So that is why you see every local site is updated.

Note that the update is _not_ necessary from a custom op writer point of view. FA3 can continue to build on torch nightlies without changing any code. (Since this is a header change, this PR has no implication on runtime, a previously built FA3 ABI stable wheel will continue to work fine with newer torch versions after this PR.)

Once TORCH_BOX lands, we would be free to remove these global APIs when the deprecation cycle is up (April 2026) and encourage people to use TORCH_BOX and avoid from/to entirely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164956
Approved by: https://github.com/malfet
ghstack dependencies: #164882
2025-10-21 02:59:46 +00:00
0be0de4ffa Add type suppressions to _inductor/runtime (#165918)
Original PR that did this was reverted due to merge conflicts.

Trying it again

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165918
Approved by: https://github.com/oulgen
2025-10-21 02:54:22 +00:00
7406d2e665 [DeviceMesh] Clean up the call into mesh_resouces to get root mesh (#165787)
We moved the method to get root mesh into class in https://github.com/pytorch/pytorch/pull/164510. This is to further clean code up.

Differential Revision: [D85090191](https://our.internmc.facebook.com/intern/diff/D85090191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165787
Approved by: https://github.com/fegin
2025-10-21 02:54:04 +00:00
303c9cf048 Save Python refcount bump on each arg in maybe_handle_torch_function (#164625)
Pybind's API entails a small unnecessary overhead when working with args. (Similarly, we should probably be using vectorcall, but that's a bigger change for both us and pybind11.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164625
Approved by: https://github.com/albanD
ghstack dependencies: #164624
2025-10-21 02:40:12 +00:00
d7d4bb7c51 Add XPU part for persons_of_interest (#165920)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165920
Approved by: https://github.com/albanD
2025-10-21 01:57:17 +00:00
0b1c462979 Making Numpy depedency in Local Tensor optional to fix broken Torchao CI (#165938)
In recent change LocalTensor introduced dependency on Numpy and has broken Torchao CI.
This dependency cna be made optional and required only when Local Tensor is used.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165938
Approved by: https://github.com/atalman
2025-10-21 01:46:53 +00:00
4a6cf0a93e Fix dynamo stack trace (#165930)
Fixes #165911

- Add message to Attribute error so we see `  Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])` instead of just `Developer debug context: raised exception AttributeError([])`
- Add stack trace in `ObservedException` so we display the inner most error stack trace back to user code

Output:

```
/data/users/shangdiy/pytorch/torch/__init__.py:2641: UserWarning: You are calling torch.compile inside torch.export region. To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)
  warnings.warn(
Traceback (most recent call last):
  File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1385, in var_getattr
    subobj = self._getattr_static(name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1256, in _getattr_static
    subobj = type(self.value).__getattribute__(self.value, name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Linear' object has no attribute 'w'

During handling of the above exception, another exception occurred:

torch._dynamo.exc.ObservedAttributeError: 'Linear' object has no attribute 'w'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/shangdiy/pytorch/test.py", line 34, in <module>
    mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 481, in inner
    out = fullgraph_capture(
          ^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1053, in fullgraph_capture
    return _fullgraph_capture_frame(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1115, in _fullgraph_capture_frame
    raise e.with_traceback(None) from e.__cause__  # User compiler error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Observed exception
  Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
  Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

  Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html

from user code:
   File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 171, in forward
    res = self._export_root(*args, **kwargs)
  File "/data/users/shangdiy/pytorch/test.py", line 31, in forward
    weight = self.linear.w

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165930
Approved by: https://github.com/anijain2305
2025-10-21 01:32:23 +00:00
4c963a68d7 Use inline instead of anon namespace for stableivalue from/to (#164882)
Fixes https://github.com/pytorch/pytorch/issues/163343.

After some consideration, I propose we remove the anonymous namespace around from/to in favor of:
1. Adding inline to the function implementations, assuming that they will not change in the near future
2. If we decide to change them, we will wrap the code in inline versioned namespaces such that the implementations within any versioned namespace will be guaranteed identical.

Note that:
- We eventually intend to abstract away usage of `from`/`to` (related: @lw's TORCH_BOX work)
- The from/to implementations are now powered through class template specializations, where adding a specialization does not change the from/to signatures.

I do plan to deprecate top-level from/to in favor of torch::stable::details::from/to consequently. This way we can stop polluting the global namespace.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164882
Approved by: https://github.com/lw, https://github.com/albanD
2025-10-21 00:12:15 +00:00
b20deec3d1 [PP] Add optional argument to not save outputs (#165822)
Fix https://github.com/pytorch/pytorch/issues/159251

Add an optional argument `return_outputs` to the schedule `step`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165822
Approved by: https://github.com/wconstab
2025-10-21 00:09:31 +00:00
51d0d8ee67 [ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the 2 easier.
PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```

Differential Revision: [D85022826](https://our.internmc.facebook.com/intern/diff/D85022826)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-21 00:09:13 +00:00
70592c6819 [ROCm][CI] Move gfx1100 workflows to own yaml file (#165699)
This should allow us to move gfx1100 workflow to a lower frequency and also allow it to be triggered on PRs via a dedicated label, for any PRs that target Navi fixes such as [this](https://github.com/pytorch/pytorch/pull/165630) or [this](https://github.com/pytorch/pytorch/pull/165625).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165699
Approved by: https://github.com/jeffdaily
2025-10-20 23:52:48 +00:00
259cb945f5 [stage 2c] make autograd and inference functions (#165668)
Add final stage of aot_stage2_compile for autograd and inference.

Differential Revision: D84844699

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165668
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2025-10-20 23:50:31 +00:00
e20c9bf288 [torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)
Including:
- `torch/utils/*.py`

Fixes part of #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165410
Approved by: https://github.com/albanD
2025-10-20 23:29:17 +00:00
99c8640b5d [1/N] Change C-style casts to static_cast or reinterpret_cast (#165750)
This series of changes try to cover C style casts into C++ alternatives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750
Approved by: https://github.com/Skylion007
2025-10-20 23:27:13 +00:00
96b0e7aaa6 [Code Clean] Clean asserts in torch/ao/quantization/experimental/* and torch/ao/quantization/pt2e/* (#165317)
Replace assert statements with explicit if/raise patterns in:
- torch/ao/quantization/experimental/* (11 errors)
- torch/ao/quantization/pt2e/* (68 errors)

fix partialy #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165317
Approved by: https://github.com/albanD
2025-10-20 23:07:11 +00:00
850ba8c96d [Code Clean] Clean asserts in torch/autograd. (#165627)
Replaces 78 assert statements across 10 files in torch.autograd with explicit if-checks raising AssertionError to prevent assertions from being disabled with Python -O flag. This ensures error checking remains active in optimized builds.

fix partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165627
Approved by: https://github.com/albanD
2025-10-20 23:03:47 +00:00
1bcd736f91 fix bad merge duplicate pre pass (#165917)
fix for https://github.com/pytorch/pytorch/issues/165624 - we were applying pre pass multiple times.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165917
Approved by: https://github.com/bdhirsh
2025-10-20 22:54:36 +00:00
df64c0c464 [Code Clean] Clean asserts in torch/ao/quantization (root, quantizer, backend_config) (#165433)
Replace assert statements with explicit if/raise patterns in:

- torch/ao/quantization/~
- torch/ao/quantization/quantizer/
- torch/ao/quantization/backend_config/

fix partialy #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165433
Approved by: https://github.com/albanD
2025-10-20 22:42:51 +00:00
1891239a1d [Graph Partition] fix graph partition input signature for fallback kernels (#165815)
Scheduler relies on node.last_usage to free buffers. `last_usage` may contain a buffer that is allocated in previous graph partition AND not directly accessed in the current graph partition.

## Example
```python
def f(x):
    y = x + 1
    z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
    z_cpu = z.cpu()
    u_cuda = z_cpu.cuda()
    return u_cuda
```

In the generated code, we have
```
def partition_0(args):
    ...
    # Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.add, aten.view]
    buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn) # < ------ buf1 is a view of buf0
    buf2 = buf1 # <------- buf2 is buf1
    assert_size_stride(buf2, (8, ), (1, ), 'torch.ops.aten.view.dtype')
    assert_alignment(buf2, 16, 'torch.ops.aten.view.dtype')
    return (buf2, )

def call(self, args):
    ...
    (buf2,) = self.partitions[0](partition0_args)
    ...
    buf3.copy_(buf2, False)
    del buf0
    del buf1
    del buf2  # <---- `del buf2` leads to `del buf0`. BUT `buf0` is not returned from partition_0.
    ...
```

Note: view is treated as a fallback kernel due to its special dtype.
de09bab4b6/torch/_inductor/lowering.py (L841-L843)

## Fix

This PR fixes the issue by also returning these buffers to be freed later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165815
Approved by: https://github.com/eellison
2025-10-20 22:23:29 +00:00
cf280ca1e8 Revert "[Inductor] Naive foreach autotune support (#162053)"
This reverts commit 779296a3fce5db0829377c792f13a8eafe537b30.

Reverted https://github.com/pytorch/pytorch/pull/162053 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/162053#issuecomment-3423808492))
2025-10-20 21:36:44 +00:00
efc277cac7 [annotation] add logging for debugging annotation (#165797)
Add logging for debugging annotation bugs. Log will show with `TORCH_LOGS="+annotation" `

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165797
Approved by: https://github.com/ezyang, https://github.com/Skylion007, https://github.com/SherlockNoMad
2025-10-20 21:27:38 +00:00
4f7f43253d Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)"
This reverts commit 8700d68fef855850e2e0aa65056a77b8f80adbdb.

Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/malfet due to Broke lint somehow, see 8f06a1308f/1 ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3423642456))
2025-10-20 20:39:56 +00:00
779296a3fc [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd
2025-10-20 20:39:04 +00:00
8f06a1308f [MPS] slightly faster cholesky (#165867)
Slightly faster cholesky, removed one redundant simdgroup_multiply
<img width="721" height="593" alt="Screenshot 2025-10-19 at 22 00 19" src="https://github.com/user-attachments/assets/e3a9005b-9347-4e62-a24d-16ba5e28849a" />

Generate benchmarks with(measured on M1 Pro):
```
import torch
import numpy as np
import time
import csv

matrix_sizes = [512, 1024, 2048, 4096]
batch_sizes = [1, 2, 4, 8, 16]
num_runs = 10
warmup_runs = 3

def create_spd_matrix(n, batch_size):
    torch.manual_seed(42)
    A = torch.randn(batch_size, n, n, dtype=torch.float32)
    return A @ A.transpose(-2, -1) + n * torch.eye(n).expand(batch_size, -1, -1)

def run_cholesky_mps(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    b = torch.linalg.cholesky(A, upper=False)
    torch.mps.synchronize()
    end = time.perf_counter()
    return b, end - start

results = {
    'N': [],
    'batch_size': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    for batch_size in batch_sizes:
        print(f"\nBenchmarking N={n}, batch_size={batch_size}")

        try:
            A_cpu = create_spd_matrix(n, batch_size)
            A_mps = A_cpu.to("mps")

            for _ in range(warmup_runs):
                _, _ = run_cholesky_mps(A_mps)

            times = []
            for _ in range(num_runs):
                _, t = run_cholesky_mps(A_mps)
                times.append(t)

            mean_time = np.mean(times)
            std_time = np.std(times)

            results['N'].append(n)
            results['batch_size'].append(batch_size)
            results['mean_time'].append(mean_time)
            results['std_time'].append(std_time)

            print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")

        except RuntimeError as e:
            print(f"Error for N={n}, batch_size={batch_size}: {e}")
            continue

with open('cholesky_benchmark_times.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'batch_size', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['batch_size'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165867
Approved by: https://github.com/malfet
2025-10-20 18:56:17 +00:00
240c13394e Revert "[inductor] require shape in TritonCSEVariable (#162275)"
This reverts commit 3af2f0c12accc6bd10ef2b76fb5c51aa0f6b73a3.

Reverted https://github.com/pytorch/pytorch/pull/162275 on behalf of https://github.com/clee2000 due to still failing due to the above D84932446 ([comment](https://github.com/pytorch/pytorch/pull/162275#issuecomment-3423153819))
2025-10-20 17:55:54 +00:00
150682ba7f Revert "Remove workaround to old CUDA bug (#164354)"
This reverts commit 26f38034332a99f2bdcc67ce1f4ba9403d420e52.

Reverted https://github.com/pytorch/pytorch/pull/164354 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164354#issuecomment-3423132083))
2025-10-20 17:48:08 +00:00
ca7360e996 Revert "Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405)"
This reverts commit ca8bd5dbedb5b46f78026e0378b0f47500ddba38.

Reverted https://github.com/pytorch/pytorch/pull/164405 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164354#issuecomment-3423132083))
2025-10-20 17:48:08 +00:00
0bf604320f Revert "[dynamo][user_defined] Replace UserFunctionVariable with VariableTracker build (#165706)"
This reverts commit 1dc9a05d0323ee3c7a20945c62463959d40f1a51.

Reverted https://github.com/pytorch/pytorch/pull/165706 on behalf of https://github.com/clee2000 due to breaking internal tests D84961097 ([comment](https://github.com/pytorch/pytorch/pull/165706#issuecomment-3423059867))
2025-10-20 17:28:58 +00:00
9875e70da8 Revert "[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)"
This reverts commit 630520b346b8883db7821562e589ccde7d12687a.

Reverted https://github.com/pytorch/pytorch/pull/165707 on behalf of https://github.com/clee2000 due to breaking internal tests D84961097 ([comment](https://github.com/pytorch/pytorch/pull/165706#issuecomment-3423059867))
2025-10-20 17:28:58 +00:00
69a4bfe8bb Revert "Refactor out headeronly ArrayRef (#164991)"
This reverts commit 3806e9767b03d06edc317cb90a3a996abdf192a0.

Reverted https://github.com/pytorch/pytorch/pull/164991 on behalf of https://github.com/clee2000 due to breaking internal tests D84961075 ([comment](https://github.com/pytorch/pytorch/pull/164991#issuecomment-3423058017))
2025-10-20 17:26:42 +00:00
62a263b8d4 Revert "Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)"
This reverts commit e4454947e2c692db1a249591121f8583fefe7df1.

Reverted https://github.com/pytorch/pytorch/pull/165152 on behalf of https://github.com/clee2000 due to breaking internal tests D84961075 ([comment](https://github.com/pytorch/pytorch/pull/164991#issuecomment-3423058017))
2025-10-20 17:26:42 +00:00
0da1f911dc Revert "[Submodule] Bump FBGEMM to latest (#165544)"
This reverts commit 23417ae50f5d9bc02e988d916c103ff3a03c5903.

Reverted https://github.com/pytorch/pytorch/pull/165544 on behalf of https://github.com/clee2000 due to failing in internal D84996252, probably needs some sort of update to fbgemm internally? ([comment](https://github.com/pytorch/pytorch/pull/165544#issuecomment-3422993703))
2025-10-20 17:06:07 +00:00
8700d68fef [ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)
* Moving rocm.yml from using persistent non-ARC runners from the combined MI2xx (MI210 + MI250) cluster to the ARC runners from the MI250 cluster. This halves the number of nodes, but provides access to approximately 4 times the runners, since every 8-GPU MI250 node now provides 8 1-GPU runners. This should help with concurrent capacity and queueing on the MI2xx jobs.

Tested here successfully: https://github.com/pytorch/pytorch/actions/runs/18620814622/job/53092469720

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165481
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony, https://github.com/albanD

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2025-10-20 16:06:37 +00:00
ab82456c16 Revert "[1/N] Change C-style casts to static_cast or reinterpret_cast (#165750)"
This reverts commit e1e8491b316df810388d9fa24f135cdba27ab40e.

Reverted https://github.com/pytorch/pytorch/pull/165750 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165750#issuecomment-3422413890))
2025-10-20 14:51:58 +00:00
b23f4687fd [Inductor][CuTeDSL] Move load_template up two directories (#165868)
Summary:
This is a reland of https://github.com/pytorch/pytorch/pull/165347

Moves the function used to load CuTeDSL Jinja templates up one level out of the flex attention folder. This way it can be used for more generate Inductor templates in the future.

Test Plan: test/inductor/test_flex_flash

Differential Revision: D85013024

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165868
Approved by: https://github.com/jananisriram
2025-10-20 12:14:38 +00:00
2705937080 [CI] Add rocm CI back to trunk for pre-submit/PR jobs (#165674)
Only adding single-GPU shards for now, to observe how current capacity handles it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165674
Approved by: https://github.com/jeffdaily
2025-10-20 12:14:06 +00:00
c1eda348be [cuda] fix triu/tril int32 overflow for large matrices (#164705)
Fixes #136611

Cast blockIdx.x to int64_t before multiplication to prevent overflow when computing linear_idx for matrices larger than 2^31 elements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164705
Approved by: https://github.com/eqy, https://github.com/ngimel
2025-10-20 07:17:41 +00:00
ba93d5636e [cuda] fix nll_loss2d backward bounds check with reduction=none (#165247)
Fixes #49882

Add missing bounds check in nll_loss2d backward kernel with reduction=none. Forward kernel already had CUDA_KERNEL_ASSERT for target bounds, now backward kernel matches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165247
Approved by: https://github.com/ngimel
2025-10-20 06:25:11 +00:00
722b2b86c9 [dynamo] Remove duplicated guards (#165806)
This is by looking at a tlparse of an internal job. We will need deeper audit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165806
Approved by: https://github.com/jansel
2025-10-20 05:50:33 +00:00
e1e8491b31 [1/N] Change C-style casts to static_cast or reinterpret_cast (#165750)
This series of changes try to cover C style casts into C++ alternatives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750
Approved by: https://github.com/Skylion007
2025-10-20 04:36:19 +00:00
767199fd9b [flex_attention] replace sliced BlockMask noop with helpful error (#164702)
Fixes part of #163314

After slicing BlockMask with `[]`, mask_mod was silently replaced with noop_mask. This caused silent incorrect results when users applied transformations to `sliced_mask.mask_mod`.

Replace noop with `_sliced_mask_mod_error` that raises RuntimeError with guidance to use `base_mask.mask_mod` instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702
Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng
2025-10-20 03:46:16 +00:00
602ace5eb4 Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit 36371b8ec7a1baed255c18451b2c716386a54c95.

Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/clee2000 due to was reverted due to failing internal tests after merge D84992607 ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3420373755))
2025-10-20 03:06:52 +00:00
47804ce467 Revert "12/n : Remove fbandroid_compiler_flags (#165558)"
This reverts commit aead9270f56ebc7302c7f5fa7e5dff959f26608e.

Reverted https://github.com/pytorch/pytorch/pull/165558 on behalf of https://github.com/clee2000 due to Diff was actually reverted internally D84832629 ([comment](https://github.com/pytorch/pytorch/pull/165558#issuecomment-3420367955))
2025-10-20 03:03:13 +00:00
e8cb34dd52 [Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)
**Summary:**
Support masked vectorization for the tail_loop for fp8 datatype.

**Example:**
```
import torch

def fn(
    x,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x,
        scale,
        zero_point,
        quant_min,
        quant_max,
        dtype,
    )
    x = torch.relu(x)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    return x

quant_min = -128
quant_max = 127
dtype = torch.float8_e4m3fn
x = torch.clamp(torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, quant_min, quant_max).to(dtype)
zero_point = 100
scale = 0.01

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x, scale, zero_point, quant_min, quant_max, dtype)
```

**Generated code:**

- Before
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(432L);x0_tail < static_cast<int64_t>(441L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = c10::convert<float>(tmp0);
                        auto tmp2 = static_cast<float>(100.0);
                        auto tmp3 = float(tmp1 - tmp2);
                        auto tmp4 = static_cast<float>(0.01);
                        auto tmp5 = float(tmp3 * tmp4);
                        auto tmp6 = c10::convert<float>(tmp5);
                        auto tmp7 = std::max(tmp6, decltype(tmp6)(0));
                        auto tmp8 = float(tmp7 * tmp2);
                        auto tmp9 = std::nearbyint(tmp8);
                        auto tmp10 = float(tmp9 + tmp2);
                        auto tmp11 = static_cast<float>(-128.0);
                        auto tmp12 = max_propagate_nan(tmp10, tmp11);
                        auto tmp13 = static_cast<float>(127.0);
                        auto tmp14 = min_propagate_nan(tmp12, tmp13);
                        auto tmp15 = c10::convert<at::Float8_e4m3fn>(tmp14);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp15;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163324
Approved by: https://github.com/Xia-Weiwen, https://github.com/mingfeima, https://github.com/jansel
ghstack dependencies: #163316
2025-10-20 01:56:00 +00:00
e9d8973427 [Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)
**Summary:**
Support masked vectorization for the tail_loop for float64 datatype.

**Example:**
```
import torch

def fn(x):
    return x * x

x = torch.randn((22, 22), dtype=torch.double)
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x)
```

**Generated code:**

- Before
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(480L);x0_tail < static_cast<int64_t>(484L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = double(tmp0 * tmp0);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp1;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163316
Approved by: https://github.com/mingfeima, https://github.com/jansel
2025-10-20 01:41:38 +00:00
61d9a5180e [Fix XPU CI] [Inductor UT] Fix test cases broken by community. (#165714)
Fixes #165719, Fixes #165771

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165714
Approved by: https://github.com/jansel
2025-10-19 23:59:04 +00:00
8a8329b51f [ATen] Switch order of blocked reduce when vectorize loads (#165178)
Performance benchmarking, perf neutral:
```
================================================================================================================================================================================================================================================
Tensor Shape         Operation    Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full diff %     Non-Contig diff %    Contig diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.015684             0.017056               0.008287             0.016015             0.016929               0.008170                      -2.07%               +0.75%          +1.43%
(256, 256)           sum          0.015774             0.016638               0.007926             0.015811             0.016935               0.008330                      -0.23%               -1.75%          -4.85%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013385             0.025742               0.008629             0.013046             0.026005               0.008924                      +2.60%               -1.01%          -3.31%
(512, 512)           sum          0.013390             0.026059               0.009116             0.013054             0.025696               0.008952                      +2.57%               +1.41%          +1.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014213             0.015467               0.010334             0.013862             0.015082               0.010318                      +2.53%               +2.55%          +0.16%
(1024, 1024)         sum          0.014179             0.015446               0.010774             0.014132             0.015073               0.010350                      +0.33%               +2.47%          +4.10%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018234             0.019487               0.014812             0.018482             0.019397               0.014802                      -1.34%               +0.46%          +0.07%
(2048, 2048)         sum          0.018202             0.019529               0.015195             0.018122             0.019485               0.015129                      +0.44%               +0.23%          +0.44%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.033582             0.039378               0.030751             0.033810             0.039673               0.031019                      -0.67%               -0.74%          -0.86%
(4096, 4096)         sum          0.033604             0.039777               0.030809             0.033530             0.039386               0.031113                      +0.22%               +0.99%          -0.98%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.085824             0.091133               0.084200             0.085431             0.091364               0.084303                      +0.46%               -0.25%          -0.12%
(8192, 8192)         sum          0.085763             0.091442               0.084180             0.085508             0.091419               0.084595                      +0.30%               +0.03%          -0.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.146480             0.147666               0.138807             0.146515             0.147987               0.138930                      -0.02%               -0.22%          -0.09%
(8192, 16384)        sum          0.146446             0.147593               0.138559             0.146151             0.147982               0.139120                      +0.20%               -0.26%          -0.40%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266047             0.265386               0.253837             0.265648             0.265885               0.253652                      +0.15%               -0.19%          +0.07%
(8192, 32768)        sum          0.266093             0.265421               0.253890             0.265458             0.265591               0.253567                      +0.24%               -0.06%          +0.13%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.498632             0.508976               0.481865             0.498237             0.508777               0.481476                      +0.08%               +0.04%          +0.08%
(8192, 65536)        sum          0.498917             0.508202               0.481883             0.498104             0.508016               0.481972                      +0.16%               +0.04%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.957633             0.968519               0.938172             0.956766             0.968267               0.938196                      +0.09%               +0.03%          -0.00%
(8192, 131072)       sum          0.956972             0.968140               0.937741             0.957365             0.968404               0.938056                      -0.04%               -0.03%          -0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.906661             1.928377               1.861846             1.907327             1.928811               1.862083                      -0.03%               -0.02%          -0.01%
(8192, 262144)       sum          1.905976             1.928362               1.862399             1.907098             1.928844               1.861782                      -0.06%               -0.02%          +0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.956852             0.970101               0.936524             0.957263             0.969809               0.936965                      -0.04%               +0.03%          -0.05%
(4096, 262144)       sum          0.957117             0.969933               0.936247             0.956675             0.969451               0.936395                      +0.05%               +0.05%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.498813             0.511299               0.483415             0.498567             0.511482               0.483376                      +0.05%               -0.04%          +0.01%
(2048, 262144)       sum          0.498813             0.510834               0.483641             0.498875             0.511036               0.483338                      -0.01%               -0.04%          +0.06%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266157             0.276751               0.255192             0.265966             0.276808               0.255544                      +0.07%               -0.02%          -0.14%
(1024, 262144)       sum          0.266133             0.276709               0.255528             0.265658             0.276685               0.255287                      +0.18%               +0.01%          +0.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.085941             0.081184               0.087931             0.085591             0.080832               0.088008                      +0.41%               +0.44%          -0.09%
(512, 131072)        sum          0.085962             0.081107               0.088045             0.085882             0.081160               0.088024                      +0.09%               -0.07%          +0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014203             0.045859               0.010310             0.013885             0.046132               0.010621                      +2.29%               -0.59%          -2.93%
(1000, 1000)         sum          0.014180             0.046165               0.010756             0.013893             0.046109               0.010338                      +2.07%               +0.12%          +4.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.012953             0.016751               0.008536             0.012977             0.016714               0.008916                      -0.18%               +0.22%          -4.26%
(1024, 129)          sum          0.013356             0.016806               0.008722             0.013003             0.017071               0.008611                      +2.71%               -1.55%          +1.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.013075             0.016787               0.009102             0.013116             0.016769               0.008679                      -0.31%               +0.11%          +4.87%
(1024, 257)          sum          0.013092             0.016842               0.008786             0.013126             0.017128               0.008771                      -0.26%               -1.67%          +0.17%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.013662             0.017412               0.010055             0.013659             0.017019               0.010033                      +0.02%               +2.31%          +0.22%
(1024, 587)          sum          0.013636             0.017473               0.010163             0.013642             0.017363               0.010101                      -0.04%               +0.63%          +0.61%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015276             0.027873               0.012531             0.015241             0.027783               0.012467                      +0.23%               +0.32%          +0.51%
(2048, 977)          sum          0.015345             0.027949               0.012192             0.015255             0.027839               0.012485                      +0.59%               +0.40%          -2.35%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.012806             0.014020               0.008291             0.013137             0.014309               0.007908                      -2.52%               -2.02%          +4.84%
(1024, 128)          sum          0.012769             0.014308               0.007924             0.012788             0.014236               0.008038                      -0.15%               +0.51%          -1.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014145             0.023049               0.009143             0.014104             0.023298               0.009501                      +0.29%               -1.07%          -3.77%
(8192, 128)          sum          0.014132             0.023082               0.009638             0.014107             0.023331               0.009244                      +0.18%               -1.07%          +4.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.013420             0.025834               0.008949             0.013368             0.025724               0.008918                      +0.39%               +0.43%          +0.35%
(1024, 130)          sum          0.013300             0.025940               0.009113             0.013266             0.025419               0.008922                      +0.26%               +2.05%          +2.14%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.013993             0.017883               0.009661             0.014275             0.018220               0.009596                      -1.98%               -1.85%          +0.68%
(8192, 130)          sum          0.014026             0.018297               0.010066             0.014326             0.018257               0.009659                      -2.09%               +0.22%          +4.21%
================================================================================================================================================================================================================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165178
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790, #165055
2025-10-19 23:39:05 +00:00
6b80c94901 [FlexAttention] Fix dynamic shaped heads flex_flash check (#165866)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165866
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #165729
2025-10-19 23:10:16 +00:00
8951df03de test_scaled_matmul_cuda: fix infer_scale_swizzle (#165788)
Extend #165747 fix to other cases.
Add parentheses to clarify operator precedence.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165788
Approved by: https://github.com/jeffdaily, https://github.com/slayton58
2025-10-19 21:42:01 +00:00
8139f33fa5 [dynamo] Add recompile reason for set_stance fail_on_recompile (#165445)
Fixes #163500

### Summary:
For `set_stance("fail_on_recompile")` failures will provide the reason why the recompilation occurred

### Impacts:
module: dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165445
Approved by: https://github.com/williamwen42
2025-10-19 21:12:19 +00:00
a88587348b [dynamo] Clean up assert in dynamo [1/N] (#165430)
Fixes some part of #162852 and #164878. These two issues have some relationship though.

* __->__ #165430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165430
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42

Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
2025-10-19 21:00:05 +00:00
633a3b7f67 Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit fa0db212e717b6cb225159cb32ea3d83baa52381.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3419893217))
2025-10-19 19:20:45 +00:00
fa0db212e7 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-19 18:00:08 +00:00
15ff1cd28b Remove E721 suppression in flake8 (#165855)
Currently all files pass the E721 check.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165855
Approved by: https://github.com/albanD
2025-10-19 17:51:12 +00:00
c73f5080de Migrating some more callsites (#163580)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163580
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #165582
2025-10-19 15:52:17 +00:00
22ae059d32 AOTI util deprecated flow using the new tracer (#165582)
Reapply of https://github.com/pytorch/pytorch/pull/163260

AOTI utils expect free function sometimes so adjust export API to handle that, haven't seen any methods getting exported. Some AOTI flows also require we populate dynamo_flat_name_to_original_fqn so i just copy how it is done in eval_frame.py. I also cleaned up how we get rid of export_root and fixed some overcomplicated nn_module_stack handling in export code. The logic is simpler now thanks to @anijain2305 .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165582
Approved by: https://github.com/anijain2305
2025-10-19 15:52:16 +00:00
1b121d636e Fix AllocatorConfig parse roundup division bug (#165304)
* #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165304
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289, #165291, #165298
2025-10-19 15:34:44 +00:00
1ba808dd97 Refine CUDA BackendStaticInitializer for allocator select (#165298)
* #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165298
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289, #165291
2025-10-19 15:34:44 +00:00
b2f5c25b27 Introduce a generic API torch._C._accelerator_setAllocatorSettings (#165291)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165291
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289
2025-10-19 15:34:36 +00:00
a1114beed2 Deprecate overlapped functions in CUDAAllocatorConfig (#165289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165289
Approved by: https://github.com/albanD
ghstack dependencies: #165288
2025-10-19 15:34:26 +00:00
4888ed440e Refine Allocator Config error message friendly (#165288)
* __->__ #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165288
Approved by: https://github.com/albanD
2025-10-19 15:34:17 +00:00
5d62b63a76 [BE] Use Python-3.14 GE build (#165804)
3.14 reached general availability on Oct 7th 2025, so we can remove all pre-release workarounds
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165804
Approved by: https://github.com/yangw-dev, https://github.com/Skylion007, https://github.com/cyyever
2025-10-19 11:45:10 +00:00
57ba575242 [BE][Ez]: Update torch.is_tensor documentation (#165841)
TypeIs propogates the isinstance check with the typing system. They are now equivalent.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165841
Approved by: https://github.com/albanD
2025-10-19 09:24:11 +00:00
ceb11a584d [BE]: Update kleidai submodule to v1.15.0 (#165842)
This mostly just adds a few new kernels and fixes some IMA and performance improvement of prev kernels. Also improves compiler support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165842
Approved by: https://github.com/albanD
2025-10-19 08:25:03 +00:00
33adb276fe [BE][Ez]: Update Eigen to 5.0.0. C++14 support and more! (#165840)
Update Eigen pin to 5.0.0 . Tons of new features and perf improvements. Most importantly updates minimum from C++03 to C++14 giving a ton of performance optimizations like properly implemented move operators, simplified code, etc. Also improved vectorization particularily on ARM. We really only use this library as a fallback for sparse operators, but still useful to update it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165840
Approved by: https://github.com/albanD
2025-10-19 08:00:06 +00:00
e939651972 [audio hash update] update the pinned audio hash (#165807)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165807
Approved by: https://github.com/pytorchbot
2025-10-19 04:45:20 +00:00
3255e7872b Enable all flake8-logging-format rules (#164655)
These rules are enabled by removing existing suppressions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164655
Approved by: https://github.com/janeyx99, https://github.com/mlazos
2025-10-19 00:59:28 +00:00
c4f6619330 Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)
- During op dispatch local tensor is supposed to collect rng state from CPU and CUDA
devices so that it can be reset before execution of the op for each such that ops
with randomness produces the same result for all ranks (note that we are planning a
separate change to add support of per rank rng state). Previously we relied on
op input arguments to deduce which devices to get rng state from. Which doesn't work
for factory functions such torch.randn. Hence this changes switches to uncondionally
collecting rng state from all devices.

- Fixing per rank specific computations in _MaskedPartial and Shard placements discovered
during test enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
2025-10-18 23:33:24 +00:00
f18041cca8 Fix missing closing quote in __init__.py documentation (#165827)
Title says it all.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165827
Approved by: https://github.com/Skylion007
2025-10-18 22:09:18 +00:00
35e51893bd Remove CUDA 11 workarounds for CUB_SUPPORTS_SCAN_BY_KEY and CUB_SUPPORTS_UNIQUE_BY_KEY (#164637)
`CUB_SUPPORTS_SCAN_BY_KEY` and `CUB_SUPPORTS_UNIQUE_BY_KEY` are true since CUDA 12. This PR removes the old branches and source files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164637
Approved by: https://github.com/ezyang
2025-10-18 20:05:54 +00:00
1f43d17ce6 Fix self assignment (#165816)
This PR removes assignments of the form `var=var`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165816
Approved by: https://github.com/jansel
2025-10-18 18:51:52 +00:00
032bed95cd Various C++ code fixes in LSAN integration (#165818)
This PR extracts the C++ code fixes from #154584, which are fixes in enabling LSAN.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165818
Approved by: https://github.com/ezyang
2025-10-18 17:59:23 +00:00
d14cbb4476 Add NVFP4 two-level scaling to scaled_mm (#165774)
Summary:

* Add second-level scaling dispatch to scaled_mm, tying into optional `alpha` passing
* Add two-level tests

Test Plan:

```
pytest -svv -k "nvfp4_global_scale" test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165774
Approved by: https://github.com/drisspg
2025-10-18 13:06:04 +00:00
f510d0dbc0 Clarrifying input output angle unit in the docs for trigonometric fun… (#161248)
…ctions

Fixes #[160995](https://github.com/pytorch/pytorch/issues/160995)

Modified the docs to clarify that input tensor  values for torch.sin, torch.cos and torch.tan should be in radians and the output tensor  values for torch.acos, torch.asin and torch.atan is in radians.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161248
Approved by: https://github.com/isuruf

Co-authored-by: Isuru Fernando <isuruf@gmail.com>
2025-10-18 11:53:48 +00:00
beb6b62e8c Revert "Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)"
This reverts commit 1b397420f22b22f90a1093233ecd9167656e50cb.

Reverted https://github.com/pytorch/pytorch/pull/165716 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165716#issuecomment-3418083391))
2025-10-18 09:15:49 +00:00
4740ce7787 [CP] Fix load balancer incorrectly assuming batch dimension exists (#165792)
https://github.com/pytorch/pytorch/pull/163617 removes the if/else statement to check if the input buffers have the batch dimension.

This PR fixes the issue and also adds a test.

In the future, we should explicitly ask users to unsqueeze the batch dimension. This is a BC of the existing contract but implicitly infers the batch dimension existence is not safe.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792
Approved by: https://github.com/XilunWu
2025-10-18 09:11:16 +00:00
ad67170c8b [MPS] sparse matmuls (#165232)
Implements matmuls for sparse tensors. With this commit most of the core sparse operations should be implemented. Fixes:
https://github.com/pytorch/pytorch/issues/156540
https://github.com/pytorch/pytorch/issues/129842

Should be merged after:
https://github.com/pytorch/pytorch/pull/165102

To compare MPS and CPU, you can use this script:
```python
import torch
import time
import matplotlib.pyplot as plt

B, I, J, K = 8, 20000, 20000, 20000
num_iterations = 500

nnz_values = [10, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 100000]
speedups = []

for nnz in nnz_values:
    indices = torch.stack([
        torch.randint(0, B, (nnz,)),
        torch.randint(0, I, (nnz,)),
        torch.randint(0, J, (nnz,)),
    ])
    values = torch.rand(nnz)

    sparse = torch.sparse_coo_tensor(indices, values, size=(B, I, J), device="mps").coalesce()
    dense = torch.randn(B, J, 200, device="mps")

    t1 = time.time()
    for _ in range(num_iterations):
        result = torch.bmm(sparse, dense)
    torch.mps.synchronize()
    t2 = time.time()
    mps_time = (t2 - t1) / num_iterations

    sparse_cpu = sparse.cpu()
    dense_cpu = dense.cpu()
    t1 = time.time()
    for _ in range(num_iterations):
        result_cpu = torch.bmm(sparse_cpu, dense_cpu)
    t2 = time.time()
    cpu_time = (t2 - t1) / num_iterations

    speedup = cpu_time / mps_time
    speedups.append(speedup)
    print(f"nnz={nnz}: MPS={mps_time:.6f}s, CPU={cpu_time:.6f}s, Speedup={speedup:.2f}x")

plt.figure(figsize=(10, 6))
plt.plot(nnz_values, speedups, marker='o', linewidth=2, markersize=8)
plt.xlabel('Number of Non-Zero Elements (nnz)', fontsize=12)
plt.ylabel('Speedup (CPU time / MPS time)', fontsize=12)
plt.title('MPS vs CPU Speedup for Sparse-Dense BMM', fontsize=14)
plt.grid(True, alpha=0.3)
plt.axhline(y=1, color='r', linestyle='--', alpha=0.5)
plt.xscale('log')
plt.tight_layout()
plt.show()
```

## Tested on M1 Pro
<img width="1000" height="600" alt="Figure_1" src="https://github.com/user-attachments/assets/4a2402ec-3dc4-402d-8196-a0426906ca3d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165232
Approved by: https://github.com/malfet
2025-10-18 09:04:42 +00:00
fdab48a7c1 Enable all PIE rules on ruff (#165814)
This PR enables all PIE rules on ruff, there are already some enabled rules from this family, the new added rules are
```
PIE796  Enum contains duplicate value: {value}
PIE808  Unnecessary start argument in range
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
2025-10-18 07:36:18 +00:00
a0948d4d23 [ROCm][inductor] autotune support for persistent reduction kernels (#163908)
After the removal of want_no_x_dim for persistent reduction kernels, we can improve the autotuning setup for persistent reduction kernels.

Currently even with tuning enable, filtering will only try a single config in many cases. Avoid filtering with autotune mode, and override MAX_BLOCK limit. Also we always include tiny_config when autotuning is enabled.

Contributions from several members of the AMD Inductor and Triton teams: @jataylo @iupaikov-amd @AmdSampsa @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163908
Approved by: https://github.com/jansel, https://github.com/PaulZhang12
2025-10-18 07:33:24 +00:00
0bbdd6b8db [ROCm][inductor] heuristic improvements for pointwise kernels (#163197)
Heuristic improvements for pointwise kernels for MI350.

Contributions from several members of the AMD Inductor and Triton teams:
@jataylo @AmdSampsa @iupaikov-amd @@xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163197
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jansel

Co-authored-by: AmdSampsa <sampsa.riikonen@amd.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
2025-10-18 07:23:41 +00:00
24520b8386 Revert "Enable all PIE rules on ruff (#165814)"
This reverts commit c79dfdc6550e872783aa5cb5fc9e86589bf18872.

Reverted https://github.com/pytorch/pytorch/pull/165814 on behalf of https://github.com/cyyever due to Need to cover more files ([comment](https://github.com/pytorch/pytorch/pull/165814#issuecomment-3417931863))
2025-10-18 07:21:08 +00:00
c79dfdc655 Enable all PIE rules on ruff (#165814)
This PR enables all PIE rules on ruff, there are already some enabled rules from this family, the new added rules are
```
PIE796  Enum contains duplicate value: {value}
PIE808  Unnecessary start argument in range
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
2025-10-18 06:40:12 +00:00
e595136187 Enable PLC1802 on ruff (#165813)
This PR enables ruff check `PLC1802`, which detects len calls on sequences in a boolean test context.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165813
Approved by: https://github.com/ezyang
2025-10-18 05:44:14 +00:00
aaac8cb0f5 [1/N] Add strict parameter to Python zip calls (#165531)
Add `strict=True/False` to zip calls in test utils. `strict=True` is passed when possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165531
Approved by: https://github.com/Skylion007
2025-10-18 05:26:33 +00:00
0f0b4bf029 [1/N] Remove unused header inclusion (#165763)
This PR removes unused header inclusion in C++ files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165763
Approved by: https://github.com/Skylion007
2025-10-18 05:23:11 +00:00
b8194268a6 Remove unnecessary noqa suppressions (#164106)
This PR removes unused `noqa` suppressions in Python code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164106
Approved by: https://github.com/albanD
2025-10-18 04:52:41 +00:00
f02e3947f6 Expand type checking to mypy strict files (#165697)
Expands Pyrefly type checking to check the files outlined in the mypy-strict.ini configuration file:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697
Approved by: https://github.com/ezyang
2025-10-18 04:34:45 +00:00
9095a9dfae [CD] Apply the fix from #162455 to aarch64+cu129 build (#165794)
When trying to bring cu129 back in https://github.com/pytorch/pytorch/pull/163029, I mainly looked at https://github.com/pytorch/pytorch/pull/163029 and missed another tweak coming from https://github.com/pytorch/pytorch/pull/162455

I discover this issue when testing aarch64+cu129 builds in https://github.com/pytorch/test-infra/actions/runs/18603342105/job/53046883322?pr=7373.  Surprisingly, there is no test running for aarch64 CUDA build from what I see in 79a37055e7.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165794
Approved by: https://github.com/malfet
2025-10-18 04:16:24 +00:00
d9f94e0d7d [dynamo] Support fx.traceback.annotate as decorator (#165805)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165805
Approved by: https://github.com/Lucaskabela, https://github.com/SherlockNoMad, https://github.com/yushangdi
2025-10-18 03:58:11 +00:00
23417ae50f [Submodule] Bump FBGEMM to latest (#165544)
Summary:

* FBGEMM submodule updated to main
* CMake updated to reflect necessary changes
* Notably pulls in NVFP4 grouped gemm kernels

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165544
Approved by: https://github.com/cyyever, https://github.com/jeffdaily
2025-10-18 03:58:08 +00:00
e4d6c56ffb Improve dynamo graph capture stack trace for custom ops (#165693)
For a custom op
```
@torch.library.custom_op("my_lib::foo", mutates_args={})
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y
```
ppl could call `torch.ops.my_lib.foo()` or directly call `foo()` in the `forward` of an `nn.Module`

These two calling conventions will lead to the same node in the output graph, but different stack traces.

When directly calling `foo()`, the displayed stack_trace in the graph will be
```
# File: .../pytorch/torch/_library/custom_ops.py:687 in __call__, code: return self._opoverload(*args, **kwargs)
```
This is not useful so we filter it out.

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_custom_op_stack_trace
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165693
Approved by: https://github.com/SherlockNoMad, https://github.com/williamwen42
2025-10-18 03:48:18 +00:00
017d2985f3 set unbacked bindings in reinplace pass for newly created nodes during generalize_scatter decomp (#164948)
Two fixes:
1. in rein_place pass, set unbacked bindings for newly created nodes.
2. In inductor, ComputeBuffer used to miss detecting some used symbols, fixed that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164948
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #164341
2025-10-18 03:20:30 +00:00
c6a8db0b9a Fix issues with generalized_scatter and setitem allocated unbacked symbols. (#164341)
Three fixes:
1. When doing t[u0] +=1  if u0 is unbacked we could allocate a new unbacked symbol during the the indexing of t[u0] (when we fake trace setitem), namely because meta_select does allocate a new unbacked symbol for the storage offset when we do not know if u0>=0 or u0<0.  but the output size/stride of setitem(), does not depend on that new symbol. it's self consumed in setitem so we shall ignore it.

2. Also when we trace through generalized_scatter the applications of the views could allocate unbacked symints
but those do not effect final output, we also shall ignore them.

3.Before accessing strides in lowering we shall materialize.

Address  https://github.com/pytorch/pytorch/issues/114293 and https://github.com/pytorch/pytorch/issues/131911

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164341
Approved by: https://github.com/bobrenjc93
2025-10-18 03:20:30 +00:00
de09bab4b6 [BE]: Update cudnn frontend submodule to 1.15.0 (#165776)
Update cudnn frontend submodule to 1.15.0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165776
Approved by: https://github.com/eqy
2025-10-18 02:23:27 +00:00
c137e222d4 .venv/ in .gitignore (#165418)
`uv venv` creates venv in `.venv/` directory. So, it's useful to have `.venv/` in `.gitignore`, since perhaps more people are using `uv` in their work. As per comment 3592f5f4e5 (diff-bc37d034bad564583790a46f19d807abfe519c5671395fd494d8cce506c42947)

uv docs  that confirms it: https://docs.astral.sh/uv/pip/environments/#using-arbitrary-python-environments
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165418
Approved by: https://github.com/ezyang
2025-10-18 02:00:52 +00:00
cf3a787bbc [annotate] Annotate bw nodes before eliminate dead code (#165782)
Fixes https://github.com/pytorch/torchtitan/pull/1907

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165782
Approved by: https://github.com/SherlockNoMad
2025-10-18 01:54:31 +00:00
de3da77cf7 Thread deterministic config vars to subproc compilation (#165729)
# Summary

TIL (AFTER WAYYYY TOO MUCH INSANITY), that we do not serialize the full set of configs for the subproc compilation.

I found this while working on Flex-attention determinism: https://github.com/meta-pytorch/attention-gym/pull/168

might be good to audit if we need to thread through any more

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165729
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-10-18 01:25:50 +00:00
543ddbf44c [ONNX] Support renaming in dynamic axes to shapes conversion (#165769)
Discovered in ##165748

This PR also deprecates the conversion. ONNX exporter team does not intend to maintain the conversion in long term.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165769
Approved by: https://github.com/justinchuby
2025-10-18 01:11:20 +00:00
e9f4999985 [Code Clean] Replace std::runtime_error with TORCH_CHECK (#165305)
Fixes part of #148114

Including:

- torch/csrc/distributed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165305
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-18 01:08:44 +00:00
29b029648e Fixed issue with GradTrackingTensor not properly propagating sparse layout (#165765)
Fixes #164286

Fixed issue with GradTrackingTensor not properly propagating sparse layout.

@ezyang @jcaip
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165765
Approved by: https://github.com/ezyang
2025-10-18 01:00:53 +00:00
a25a649e70 [Mem Snapshot] Add Metadata Field (#165490)
Summary:
The implementation adds the ability to:

Set custom metadata strings that will be attached to all subsequent allocations
Clear or change the metadata at any point
View the metadata in memory snapshots via _dump_snapshot()

Test Plan: Added test in test_cuda.py and check manually in snapshot to see that metadata was added.

Differential Revision: D84654933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165490
Approved by: https://github.com/yushangdi
2025-10-17 23:46:02 +00:00
69c33898fa Revert "[Inductor][CuTeDSL] Move load_template up two directories (#165347) (#165576)"
This reverts commit febb60323018948b2b9d2cff35b3cc4e0d0c55c8.

Reverted https://github.com/pytorch/pytorch/pull/165576 on behalf of https://github.com/seemethere due to This was actually reverted internally, current PR is linked to a stale diff so diff train tools think that this is landed via co-dev when it was actually reverted ([comment](https://github.com/pytorch/pytorch/pull/165576#issuecomment-3417510146))
2025-10-17 23:33:17 +00:00
1b397420f2 Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)
- During op dispatch local tensor is supposed to collect rng state from CPU and CUDA
devices so that it can be reset before execution of the op for each such that ops
with randomness produces the same result for all ranks (note that we are planning a
separate change to add support of per rank rng state). Previously we relied on
op input arguments to deduce which devices to get rng state from. Which doesn't work
for factory functions such torch.randn. Hence this changes switches to uncondionally
collecting rng state from all devices.

- Fixing per rank specific computations in _MaskedPartial and Shard placements discovered
during test enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
2025-10-17 23:28:22 +00:00
fe80f03726 Add B200 files to labeler and update codeowners (#165767)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165767
Approved by: https://github.com/slayton58
2025-10-17 23:24:17 +00:00
e50dc40d28 Revert "Update gm.print_readable to include Annotation (#165397)"
This reverts commit 7a657700131f31577544e93587eb339618677e97.

Reverted https://github.com/pytorch/pytorch/pull/165397 on behalf of https://github.com/malfet due to I don't know how/why, but it breaks windows tests, see 2e22b1a61e/1 ([comment](https://github.com/pytorch/pytorch/pull/165397#issuecomment-3417428128))
2025-10-17 22:35:50 +00:00
2e22b1a61e [pytorch] Composite backend potential fix for is_backend_available (#165061)
Summary: `is_backend_available` takes in a string and expects it to only be backend, if its given a composite (device:backend) string, it fails.

Reviewed By: prashrock

Differential Revision: D81886736

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165061
Approved by: https://github.com/H-Huang
2025-10-17 22:06:36 +00:00
616c6bdf8f [dynamo][ac] Config flag to allow eager and compile AC divergence for side-effects (#165775)
Eager AC/SAC reapplies the mutations (like global dict mutations) in the backward during the recomputation of forward. torch.compile has no easy way to reapply python mutations in the backward. But many users might be ok to skip reapplication of side effects in the backward. They can set this config flag to accept this eager and compile divergence.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775
Approved by: https://github.com/zou3519
ghstack dependencies: #165734
2025-10-17 22:04:19 +00:00
c18ddfc572 [dynamo][easy] Support torch.accelerator.current_accelerator (#165734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165734
Approved by: https://github.com/Skylion007
2025-10-17 22:04:19 +00:00
86ebce1766 [precompile] Pass tensor_to_context to backend. (#165702)
Summary:

Fixing a VLLM issue https://github.com/vllm-project/vllm/issues/27040 where
aot precompile fails on some models using symbolic shapes in inductor.

Test Plan:
pp HF_HUB_DISABLE_XET=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_USE_AOT_COMPILE=1 vllm bench latency --model microsoft/DialoGPT-small --input-len 128 --output-len 256 --num-iters 50 --dtype float16

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165702
Approved by: https://github.com/tugsbayasgalan
2025-10-17 21:52:04 +00:00
8cb2fb44f2 [Inductor] Support fallback for all gemm like ops (#165755)
Summary: Fill op_override field for bmm aten ops so they can be converted properly in the wrapper_fxir backend

Reviewed By: StellarrZ

Differential Revision: D84840948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165755
Approved by: https://github.com/blaine-rister
2025-10-17 21:08:29 +00:00
ab65498d71 Fix _StridedShard incorrect split (#165533)
https://github.com/pytorch/pytorch/pull/164820 introduced a bug that `_StridedShard` will call parent class `Shard`'s `split_tensor` method, thus results in incorrect data locality. (I think @ezyang spotted this issue, but we have no test to capture this)

Meanwhile, I notice another bug that when we normalize a `_StridedShard`'s placement, it will also trigger parent class `Shard`'s `split_tensor` method because it will create a Shard class [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I think we never test `distribute_tensor` for `_StridedShard` before. So I added a test here to compare against ordered shard.

Using classmethod because the _split_tensor logic is different between `Shard` and `_StridedShard`. Basically I want to shard on local tensors without initializing the Shard object:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
2025-10-17 20:54:46 +00:00
06d324365c Revert "Escaped html tags name and target to appear as strings (#165543)"
This reverts commit 080365b7d82a3c99c995cab6dc912b7dfe22aa41.

Reverted https://github.com/pytorch/pytorch/pull/165543 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165543#issuecomment-3417102048))
2025-10-17 20:45:48 +00:00
6c9c6e0936 Enable C407 of flake8 (#165046)
This PR enables C407 on flake8. The description is `C407` is `Unnecessary list comprehension - ‘<builtin>’ can take a generator`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165046
Approved by: https://github.com/albanD
2025-10-17 20:15:39 +00:00
2bcd892c86 [distributed] Replace assert statements in distributed checkpoint with explicit checks (#165256)
Fixes partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165256
Approved by: https://github.com/albanD
2025-10-17 20:14:35 +00:00
75e2a9fae3 [annotate] add annotate_fn function decorator (#165703)
Example usage:

```
        @fx_traceback.annotate_fn({"pp_stage": 1})
        def example_function(x):
            return x * x

        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                y = example_function(y)
                return y - 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
2025-10-17 20:10:53 +00:00
a16fd6b488 [NVSHMEM][Triton] Fix NVSHMEM triton test for wacky world sizes (#165704)
Currently assumes divisible by 4? world size

Not as slick as the old setup code but more general

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165704
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
2025-10-17 19:33:26 +00:00
382b0150de [docs] Add usage examples to ConvTranspose1d docstring (#165618)
Fixes #165615

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165618
Approved by: https://github.com/mikaylagawarecki
2025-10-17 19:11:57 +00:00
a664b299ac Update docs for torch.mode (#165614)
Currently the docs for `torch.mode` include a note:

`This function is not defined for torch.cuda.Tensor yet.`

However with `torch==2.7.1+cu126` when I try to get the mode of a Tensor that is in cuda memory, I do not face any issues:

```
>>> a = torch.tensor([0, 2, 1, 1, 1, 3, 3])
>>> a.mode()
torch.return_types.mode(
values=tensor(1),
indices=tensor(4))
>>> a.cuda().mode()
torch.return_types.mode(
values=tensor(1, device='cuda:0'),
indices=tensor(4, device='cuda:0'))
```

Am I misunderstanding the note? If not, I suggest removing it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165614
Approved by: https://github.com/mikaylagawarecki
2025-10-17 19:06:33 +00:00
9c12651417 Improve error message for non-positive groups in convolution (#165669)
Prevents from segmentation fault for invalid groups value in convolution.

Fixes #142835

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165669
Approved by: https://github.com/mikaylagawarecki
2025-10-17 19:06:05 +00:00
08c97b4a1f Don't run compile inside kernel invocation (#165687)
When we call torch.compile during fake tensor prop, we shouldn't actually compile because we can't guarantee that the compiled artifact can be fake tensor prop-d. (for example, inductor backend). Instead we should just skip compiling. However, the inner compile will be triggered when being executed in runtime.

Fixes: https://github.com/pytorch/pytorch/issues/151328

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165687
Approved by: https://github.com/zou3519
2025-10-17 19:03:57 +00:00
fae74cd52f Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit a032510db38e8331afa08f7635d146f9cefdd0ab.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3416718767))
2025-10-17 18:55:53 +00:00
7a65770013 Update gm.print_readable to include Annotation (#165397)
Sample output
```
[rank0]:        # Annotation: {'compile_with_inductor': 'flex_attention'} File: /data/users/bahuang/pytorch/torch/nn/attention/flex_attention.py:1490 in flex_attention, code: out, lse, max_scores = flex_attention_hop(
[rank0]:        score_mod_2 = self.score_mod_2
[rank0]:        mask_fn_2 = self.mask_fn_2
[rank0]:        flex_attention_1 = torch.ops.higher_order.flex_attention(xq_5, xk_5, xv_3, score_mod_2, (2048, 2048, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_indices, 128, 128, mask_fn_2), 0.25, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___mask_mod___closure___0_cell_contents,));  xq_5 = xk_5 = xv_3 = score_mod_2 = mask_fn_2 = None
[rank0]:        out_2: "bf16[8, 4, 2048, 16]" = flex_attention_1[0];  flex_attention_1 = None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165397
Approved by: https://github.com/yushangdi, https://github.com/anijain2305
2025-10-17 18:35:18 +00:00
e4454947e2 Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991
2025-10-17 18:32:39 +00:00
3806e9767b Refactor out headeronly ArrayRef (#164991)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164991
Approved by: https://github.com/swolchok
2025-10-17 18:32:39 +00:00
b08d8c2e50 Revert "[DebugMode][2/N] add nn.Module tracking (#165498)"
This reverts commit 45afaf08a14ab760d86ea80dea6d50cec8626513.

Reverted https://github.com/pytorch/pytorch/pull/165498 on behalf of https://github.com/seemethere due to First part of the stack was reverted so will need to revert this too ([comment](https://github.com/pytorch/pytorch/pull/165498#issuecomment-3416618198))
2025-10-17 18:22:48 +00:00
ca5b7f8ded torch.compile: populate compiler_config (#165581)
Summary: This starts writing the compiler_config metadata into logger

Test Plan:
Modified existing test case to make sure this is not null.
(Also eyeballed what we're logging tomake sure it's reasonable

Reviewed By: masnesral

Differential Revision: D84014636

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165581
Approved by: https://github.com/masnesral
2025-10-17 18:21:18 +00:00
9a71d96256 Revert "[DebugMode][1/N] refactor logs into _DebugCalls (#165376)"
This reverts commit 556fc09a9f67f24ca5591ec049c5d0c347c5f62a.

Reverted https://github.com/pytorch/pytorch/pull/165376 on behalf of https://github.com/seemethere due to This is failing for internal tests, see D84877379 for more context ([comment](https://github.com/pytorch/pytorch/pull/165376#issuecomment-3416570407))
2025-10-17 18:08:59 +00:00
0d4c2b71e8 [DeviceMesh] Simplify unflatten method (#165556)
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
2025-10-17 17:57:51 +00:00
d659bbde62 [DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)
The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor.

In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it.

This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`.

With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
2025-10-17 17:57:51 +00:00
58879bfafa [DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
2025-10-17 17:57:51 +00:00
a032510db3 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/Skylion007, https://github.com/syed-ahmed, https://github.com/kwen2501
2025-10-17 17:55:03 +00:00
39e0a832c9 Fix B200 test fails in scaled_mm (#165747)
Summary:

PR #165528 changes some scale/swizzle inference behavior in scaled_mm
tests - mxfp8 tests on Blackwell can get incorrectly classified,
resulting in failures.

Fix the scale/swizzle inference code to prevent this.

Fixes https://github.com/pytorch/pytorch/issues/165743

Test Plan:

```
pytest -svv test/test_scaled_matmul_cuda.py
```

Reviewers:

@jagadish-amd @jeffdaily @drisspg

Subscribers:

@Aidyn-A

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165747
Approved by: https://github.com/eqy, https://github.com/drisspg, https://github.com/jeffdaily
2025-10-17 17:52:19 +00:00
dd3b48e85d Fix bug with serialization after AOTAutogradCache hit (#165474)
Fixes #165447

On AOTAutogradCache load, the serialization function we pick is just lambda: self, because the object itself is an AOTAutogradCacheEntry. However, this isn't safe, because `wrap_post_compile` will make `self` unserializable, since it needs to load triton kernels and stuff!

So instead, on AOTAutogradCache load, we preserve the bytes that were used to load the object to begin with, and return that object on a call to serialize(). This effectively makes it so that we save a copy of the pre-hydrated artifact, without needing to do an eager copy until someone actually calls `serialize`.

Test Plan:

Run

```py
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 4)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(4, 8)
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

device = "cuda"
m = M().to(device)
sample_inputs = (torch.randn(2, 2, device=device),)
eager_out = m(*sample_inputs)

with torch._dynamo.config.patch("enable_aot_compile", True):
    compiled_fn_path = "./m.pt"
    compiled_fn = torch.compile(
        m,
        fullgraph=True
    ).forward.aot_compile((sample_inputs, {}))

    compiled_fn.save_compiled_function(compiled_fn_path)
    torch._dynamo.reset()
    with torch.compiler.set_stance("fail_on_recompile"):
        with open(compiled_fn_path, "rb") as f:
            loaded_fn = torch.compiler.load_compiled_function(f)

assert loaded_fn is not None

compiled_out = loaded_fn(m, *sample_inputs)

assert torch.allclose(eager_out, compiled_out)
```

twice, see that it succeeds.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165474
Approved by: https://github.com/yiming0416, https://github.com/zhxchen17
2025-10-17 17:47:24 +00:00
cff1b20771 Patch the flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for flex attention graph break on create_block_mask) (#164923)
The initial fix for inspect.signature uses not a right approach (https://github.com/pytorch/pytorch/pull/164349#pullrequestreview-3306614010). As @williamwen42 suggests (https://github.com/pytorch/pytorch/pull/164349#issuecomment-3379222885) we can just for now get rid of `inspect.signature` call in flex_attention to resolve this high priority issue (https://github.com/pytorch/pytorch/issues/164247#issuecomment-3378673179). In this PR I did exactly this - limited the scope of fix to just computing `num_positional_args` in `flex_attention._get_mod_type` based on properties returned by `NestedUserFunctionVariable.const_getattr` (some were missing so I added them)

Fixes #164247

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164923
Approved by: https://github.com/williamwen42
2025-10-17 17:44:45 +00:00
da8517fa63 [ROCm][CI] upgrade wheels to 7.0.2 and 6.4.4 patch release (#165756)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165756
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-17 17:41:19 +00:00
45afaf08a1 [DebugMode][2/N] add nn.Module tracking (#165498)
Uses ModTracker to record nn.Module entries, much like CommDebugMode.

Can be switched on with `DebugMode(record_nn_module=True)`:
```
    [nn.Mod] Bar
      [nn.Mod] Bar.abc
        [nn.Mod] Bar.abc.l1
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
        [nn.Mod] Bar.abc.l2
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
      [nn.Mod] Bar.xyz
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])"""
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165498
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165376
2025-10-17 17:39:48 +00:00
080365b7d8 Escaped html tags name and target to appear as strings (#165543)
Fixes small typo in markdown documentation file - Added escape characters to precede tag pattern.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165543
Approved by: https://github.com/mikaylagawarecki
2025-10-17 17:35:18 +00:00
2928c5c572 Revert "Pyrefly suppressions 2 (#165692)"
This reverts commit 43d78423ac224cce432bf34ed9627035169d5433.

Reverted https://github.com/pytorch/pytorch/pull/165692 on behalf of https://github.com/seemethere due to This is causing merge conflicts when attempting to land internally, see D84890919 for more details ([comment](https://github.com/pytorch/pytorch/pull/165692#issuecomment-3416397240))
2025-10-17 17:13:04 +00:00
630520b346 [dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)
Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165707
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #165683, #165706
2025-10-17 17:02:18 +00:00
1dc9a05d03 [dynamo][user_defined] Replace UserFunctionVariable with VariableTracker build (#165706)
Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165706
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #165683
2025-10-17 17:02:18 +00:00
bfcdbd0a97 fix wrong accuracy_status when exception. (#165731)
When I debug `XPU` accruacy issue, I found the script output wrong accuracy_status.
When the `try` block raise an exception, we should process the exception, but not return the `fail_accuracy`.

Before fixing, it returned as `fail_accuracy`:
<img width="1109" height="216" alt="image" src="https://github.com/user-attachments/assets/385c354f-fbf6-48e4-a1be-3e37e987341b" />

After fixing, it returned the exception message:
<img width="1101" height="292" alt="image" src="https://github.com/user-attachments/assets/f18c0e3c-8358-4ec7-a6bb-c2e01b69d27f" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165731
Approved by: https://github.com/Stonepia, https://github.com/chuanqi129, https://github.com/Lucaskabela
2025-10-17 16:37:06 +00:00
faff826a46 Revert "[ROCm] new implementation of upsample_bilinear2d_backward (#164572)"
This reverts commit 53f9ae0e50d4dcc47f2ca4bf854803f9d4f875ae.

Reverted https://github.com/pytorch/pytorch/pull/164572 on behalf of https://github.com/seemethere due to Looks like this is failing in our internal builds, will post a suggestion for a fix but want you to double verify that this behavior is correct ([comment](https://github.com/pytorch/pytorch/pull/164572#issuecomment-3416262676))
2025-10-17 16:27:59 +00:00
85c5433d38 Revert "Fix _StridedShard incorrect split (#165533)"
This reverts commit dfc8a1c5ddc8401197e9ab546e03b0f745edc27b.

Reverted https://github.com/pytorch/pytorch/pull/165533 on behalf of https://github.com/seemethere due to Causing a merge conflict internally, see D84829161 ([comment](https://github.com/pytorch/pytorch/pull/165533#issuecomment-3416143176))
2025-10-17 15:57:01 +00:00
935ccdbe75 [MPS] Fix internal assertion in torch.linalg.solve for singular matrices (#165254)
Fixes #163962 by special casing MPS in the negative status code branch in `_linalg_check_errors`.

Checks if info is [`MPSMatrixDecompositionStatus.singular`](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus/singular) (which has a raw value of -2). I didn't find an official Apple source with this raw value (besides printing the enum value), so I'm not sure if we can (or should) depend on it? Is there a way to directly get the Objective-C enum value in C++?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165254
Approved by: https://github.com/malfet
2025-10-17 15:35:49 +00:00
3af2f0c12a [inductor] require shape in TritonCSEVariable (#162275)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162275
Approved by: https://github.com/mlazos
ghstack dependencies: #164158
2025-10-17 14:47:45 +00:00
6ece527fc5 [CI] Add aarch64 operator benchmark (#165585)
Running on Graviton4
Skip ConvTranspose1d benchmarks if PyTorch is compiled with ACL, due to https://github.com/pytorch/pytorch/issues/165654
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165585
Approved by: https://github.com/huydhn
2025-10-17 14:42:14 +00:00
ce29d0d796 [ATen] Vectorize 8 elements on 16 bit data types for sum/mean (#165055)
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce:
**BF16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================
```

**FP16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022804             0.008298             0.015888             0.007848                          +43.53%               +5.73%
(256, 256)           sum          0.023215             0.008328             0.015677             0.007850                          +48.08%               +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013777             0.009988             0.012884             0.008512                           +6.93%              +17.34%
(512, 512)           sum          0.013775             0.009622             0.012870             0.009028                           +7.03%               +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014740             0.012322             0.013708             0.010239                           +7.53%              +20.34%
(1024, 1024)         sum          0.014762             0.012756             0.013722             0.010307                           +7.58%              +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018700             0.018364             0.018135             0.015078                           +3.12%              +21.79%
(2048, 2048)         sum          0.018276             0.018415             0.018471             0.015127                           -1.06%              +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034518             0.037000             0.033838             0.030617                           +2.01%              +20.85%
(4096, 4096)         sum          0.034569             0.037448             0.033842             0.031100                           +2.15%              +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087675             0.095176             0.085328             0.084105                           +2.75%              +13.16%
(8192, 8192)         sum          0.088102             0.095211             0.085707             0.084090                           +2.79%              +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.147800             0.149263             0.146388             0.138390                           +0.96%               +7.86%
(8192, 16384)        sum          0.148147             0.148957             0.146439             0.138801                           +1.17%               +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266316             0.260294             0.265829             0.253411                           +0.18%               +2.72%
(8192, 32768)        sum          0.266562             0.260717             0.265744             0.253308                           +0.31%               +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502035             0.486077             0.498139             0.481374                           +0.78%               +0.98%
(8192, 65536)        sum          0.501571             0.485733             0.498353             0.481350                           +0.65%               +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971343             0.943016             0.956600             0.938622                           +1.54%               +0.47%
(8192, 131072)       sum          0.971463             0.942991             0.957352             0.938334                           +1.47%               +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.952722             1.877165             1.906406             1.861455                           +2.43%               +0.84%
(8192, 262144)       sum          1.952634             1.876388             1.904677             1.861282                           +2.52%               +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970697             0.941298             0.956964             0.936160                           +1.44%               +0.55%
(4096, 262144)       sum          0.969981             0.941078             0.957016             0.936260                           +1.35%               +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501577             0.487208             0.498422             0.483493                           +0.63%               +0.77%
(2048, 262144)       sum          0.502029             0.487124             0.497854             0.483643                           +0.84%               +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266416             0.257383             0.265928             0.255140                           +0.18%               +0.88%
(1024, 262144)       sum          0.266434             0.257081             0.265817             0.255143                           +0.23%               +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087858             0.091296             0.085816             0.087745                           +2.38%               +4.05%
(512, 131072)        sum          0.088144             0.091314             0.085664             0.087864                           +2.90%               +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014977             0.012393             0.014141             0.010614                           +5.91%              +16.76%
(1000, 1000)         sum          0.014589             0.012804             0.014118             0.010320                           +3.34%              +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014208             0.008383             0.013273             0.008440                           +7.04%               -0.68%
(1024, 129)          sum          0.013804             0.008863             0.013265             0.009003                           +4.06%               -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014378             0.009109             0.013037             0.009038                          +10.29%               +0.79%
(1024, 257)          sum          0.014387             0.009113             0.013396             0.008698                           +7.40%               +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014207             0.011037             0.013182             0.010391                           +7.78%               +6.22%
(1024, 587)          sum          0.014588             0.011453             0.013539             0.010049                           +7.75%              +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.016024             0.013614             0.015448             0.011845                           +3.73%              +14.93%
(2048, 977)          sum          0.015990             0.014033             0.015406             0.012278                           +3.79%              +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.014037             0.007804             0.013143             0.008242                           +6.80%               -5.31%
(1024, 128)          sum          0.014041             0.007847             0.012759             0.007850                          +10.05%               -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014361             0.009644             0.014075             0.009061                           +2.03%               +6.43%
(8192, 128)          sum          0.014366             0.010032             0.013702             0.009181                           +4.85%               +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014226             0.008696             0.012894             0.008835                          +10.33%               -1.57%
(1024, 130)          sum          0.013830             0.008740             0.013288             0.008989                           +4.08%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.015036             0.010019             0.013917             0.009538                           +8.04%               +5.04%
(8192, 130)          sum          0.014652             0.010403             0.013900             0.009565                           +5.41%               +8.76%
====================================================================================================================================================================================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165055
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790
2025-10-17 13:39:36 +00:00
7231118db3 Turn some const variables into constexpr in C++ code (#165401)
This PR checks the C++ code and turns some const variables into constexpr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165401
Approved by: https://github.com/Skylion007
2025-10-17 13:24:46 +00:00
5d4da26ed0 Revert "[export] preserve_node_meta by default (#165524)"
This reverts commit fdd560afd1d413a9f814cbf7cc2a72e0d39b0117.

Reverted https://github.com/pytorch/pytorch/pull/165524 on behalf of https://github.com/lw due to test/functorch/test_control_flow.py::TestControlFlowTraced::test_cond_symint_closure [GH job link](https://github.com/pytorch/pytorch/actions/runs/18586312291/job/52991654051) [HUD commit link](fdd560afd1) ([comment](https://github.com/pytorch/pytorch/pull/165524#issuecomment-3415352522))
2025-10-17 12:27:17 +00:00
574c9fc950 Revert "Remove torch.serialization entries from the doc ignore list (#160224)"
This reverts commit 9fe3b2afbeff12080b483af1ee23e1c9d9fb0421.

Reverted https://github.com/pytorch/pytorch/pull/160224 on behalf of https://github.com/lw due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/18588004962/job/52997748336) [HUD commit link](9fe3b2afbe) ([comment](https://github.com/pytorch/pytorch/pull/160224#issuecomment-3415345175))
2025-10-17 12:24:08 +00:00
80d2ca7566 Revert "[annotate] add annotate_fn function decorator (#165703)"
This reverts commit f1d882212afc3a73ce1e319d80b6406f9dc4a0c8.

Reverted https://github.com/pytorch/pytorch/pull/165703 on behalf of https://github.com/lw due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/18585518705/job/52989521797) [HUD commit link](f1d882212a) ([comment](https://github.com/pytorch/pytorch/pull/165703#issuecomment-3415073467))
2025-10-17 11:23:13 +00:00
4a22139eea [MPS][BE] Fix unused variable warning (#165726)
Namely this one
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:19:18: warning: unused variable 'output_sizes' [-Wunused-variable]
  constant auto& output_sizes = shared_params.output_sizes;
                 ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:85:1: note: in instantiation of function template specialization 'cat<long, float, float>' requested here
REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:69:3: note: expanded from macro 'REGISTER_CAT_FOR_INDEX_TYPE'
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float);  \
  ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:55:3: note: expanded from macro 'REGISTER_CAT_OP_ALL_INPUT_TYPES'
  REGISTER_CAT_OP(I, float, T_out);               \
  ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:47:15: note: expanded from macro 'REGISTER_CAT_OP'
  kernel void cat<I, T_in, T_out>(                               \
```

Repeated about 20-30 times
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165726
Approved by: https://github.com/Skylion007
2025-10-17 11:16:21 +00:00
cb6e4d7d82 User-passed alpha to scaled_gemm (#165563)
Summary:

Add optional user-passed `alpha` argument to
`at::cuda::blas::scaled_gemm`, necessary for two-level-scaled NVFP4 gemm
calls (where the global de-scales are folded into the `alpha` argument.

Global de-scales are naturally device tensors, but using cublas'
device-pointer mode for `alpha`/`beta` has an interesting lifetime
implication - the `alpha` tensor must be valid & correct until the end
of the matmul call, *not* just the launch (as for host values). To
enable this, I added device-constant memory for `one` and `zero`, along
with a statically-held single-fp32-value tensor, which is valid from the
first passed-`alpha` invocation of `scaled_gemm` to the end of the
program. User-passed values are copied into this perpetual buffer to
ensure lifetime requirements are met.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165563
Approved by: https://github.com/drisspg, https://github.com/eqy
2025-10-17 09:42:33 +00:00
202f83dc4e [ROCm][layer_norm] Use __builtin_amdgcn_rcpf(x) instead of 1.f/x (#165589)
Replace (more) exact calculation with hardware approximation.

Benefits:
Reduced code size.
Improved performance for certain scenarios.

Experiments show low reduction in precision.
Experiments show no significant performance regressions. bfloat16 as well as float16 related calculations may benefit largely from this change.

Co-author: @mhalk @amd-hhashemi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165589
Approved by: https://github.com/jeffdaily
2025-10-17 09:12:30 +00:00
9fe3b2afbe Remove torch.serialization entries from the doc ignore list (#160224)
Follows the approach done in #158581
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160224
Approved by: https://github.com/janeyx99
2025-10-17 09:06:09 +00:00
d0c24b392c [APF Logging][Error Trait] To fill the errorTraits for ChildFailedError with signal abort (re-attempt of #165476) (#165688)
**Summary**
Land @guoding83128 's PR https://github.com/pytorch/pytorch/pull/165476 on his behalf due to EasyCLA blocking.
Refer his original PR for detail. But in short, elastic leaves 'errorTraits' as unknown when the error dump file is missing,
this PR adds a "system terminated error" to such case so the internal scuba table can correctly aggregate.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165688
Approved by: https://github.com/fduwjj
2025-10-17 08:23:27 +00:00
b44fb14906 Remove unused parameter when query extension attribute (#165623)
# Motivation
This code is no longer needed since SYCL compiler 2025.0. We are now using compiler 2025.2 (two tool uplifts later), so it can be safely removed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165623
Approved by: https://github.com/EikanWang
ghstack dependencies: #165622
2025-10-17 08:16:13 +00:00
51348c0219 Give a friendly message for older Intel GPU (#165622)
# Motivation
Notify the user if the GPU is older than officially supported. This provides a friendly warning that the GPU may work, but the experience could be unstable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165622
Approved by: https://github.com/EikanWang
2025-10-17 08:16:13 +00:00
fdd560afd1 [export] preserve_node_meta by default (#165524)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165524
Approved by: https://github.com/malaybag
2025-10-17 07:55:28 +00:00
e925dfcc6b Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-17 07:27:11 +00:00
f1d882212a [annotate] add annotate_fn function decorator (#165703)
Example usage:

```
        @fx_traceback.annotate_fn({"pp_stage": 1})
        def example_function(x):
            return x * x

        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                y = example_function(y)
                return y - 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
2025-10-17 07:18:47 +00:00
24879f0de9 [dynamo] Use Variable Builder to build the property fget object (#165683)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165683
Approved by: https://github.com/ezyang, https://github.com/williamwen42
2025-10-17 06:29:24 +00:00
658 changed files with 9082 additions and 4180 deletions

View File

@ -83,10 +83,6 @@ function build_cpython {
py_suffix=${py_ver::-1}
py_folder=$py_suffix
fi
# Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4
if [ "$py_suffix" == "3.14.0" ]; then
py_suffix="3.14.0rc2"
fi
wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz
do_cpython_build $py_ver Python-$py_suffix

View File

@ -39,9 +39,13 @@ case ${DOCKER_TAG_PREFIX} in
DOCKER_GPU_BUILD_ARG=""
;;
rocm*)
# we want the patch version of 7.0 instead
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
fi
# we want the patch version of 6.4 instead
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
fi
BASE_TARGET=rocm
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete

View File

@ -75,9 +75,13 @@ case ${image} in
DOCKERFILE_SUFFIX="_cuda_aarch64"
;;
manylinux2_28-builder:rocm*)
# we want the patch version of 7.0 instead
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
fi
# we want the patch version of 6.4 instead
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
fi
TARGET=rocm_final
MANY_LINUX_VERSION="2_28"

View File

@ -1,15 +1,11 @@
sphinx==5.3.0
sphinx==7.2.6
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
#Pinned versions: 7.2.6
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
pytorch_sphinx_theme2==0.1.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.1.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.34.0
breathe==4.36.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.34.0
#Pinned versions: 4.36.0
exhale==0.2.3
exhale==0.3.7
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.2.3
#Pinned versions: 0.3.7
docutils==0.16
docutils==0.20
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.16
#Pinned versions: 0.20
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@ -56,13 +52,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==0.17.2
myst-nb==1.3.0
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
#Pinned versions: 1.3.0
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinx-design==0.6.1
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-parser==4.0.1

View File

@ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules
logger.info("Successfully cloned %s", target)
return r, commit
except GitCommandError as e:
logger.error("Git operation failed: %s", e)
except GitCommandError:
logger.exception("Git operation failed")
raise

View File

@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
echo coverage output not found
exit 1
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
echo "======================================"
echo "ERROR: $undocumented undocumented objects found!"
echo "======================================"
echo ""
echo "Full coverage report:"
cat build/coverage/python.txt
echo ""
echo "======================================"
echo "Undocumented modules/objects (lines after TOTAL):"
tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
echo "======================================"
echo ""
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1

View File

@ -163,8 +163,13 @@ if [[ "$(uname)" != Darwin ]]; then
MEMORY_LIMIT_MAX_JOBS=12
NUM_CPUS=$(( $(nproc) - 2 ))
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
if [[ "$(uname)" == Linux ]]; then
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
else
# For other builds
export MAX_JOBS=${NUM_CPUS}
fi
cat >>"$envfile" <<EOL
export MAX_JOBS="${MAX_JOBS}"

View File

@ -7,16 +7,12 @@ max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
# these ignores are from flake8-comprehensions; please fix!
C407,
# these ignores are from flake8-logging-format; please fix!
G100,G101,G200
# these ignores are from flake8-simplify. please fix or ignore with commented reason
SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
# SIM104 is already covered by pyupgrade ruff

View File

@ -1 +1 @@
1b013f5b5a87a1882eb143c26d79d091150d6a37
69bbe7363897764f9e758d851cd0340147d27f94

29
.github/labeler.yml vendored
View File

@ -133,3 +133,32 @@
"ciflow/vllm":
- .github/ci_commit_pins/vllm.txt
"ciflow/b200":
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
"ciflow/h100":
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
"ciflow/rocm":
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm

View File

@ -33,6 +33,7 @@ ciflow_push_tags:
- ciflow/rocm
- ciflow/rocm-mi300
- ciflow/rocm-mi355
- ciflow/rocm-navi31
- ciflow/s390
- ciflow/slow
- ciflow/torchbench

View File

@ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
"nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'"
),
"12.9": (
"nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'"
"nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | "
"nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | "
"nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | "
"nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | "
"nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | "
"nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | "
"nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | "
"nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | "
"nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | "
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | "
"nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | "
"nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | "
"nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | "
"nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | "
"nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'"
),
"13.0": (
"nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "

View File

@ -1092,7 +1092,7 @@ class GitHubPR:
editor = node["editor"]
return GitHubComment(
body_text=node["bodyText"],
created_at=node["createdAt"] if "createdAt" in node else "",
created_at=node.get("createdAt", ""),
author_login=node["author"]["login"],
author_url=node["author"].get("url", None),
author_association=node["authorAssociation"],

View File

@ -26,9 +26,8 @@ name: !{{ build_environment }}
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}"
python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}"
freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }}
{%- endmacro %}

View File

@ -79,9 +79,9 @@ jobs:
runs-on: "windows-11-arm64-preview"
{%- else %}
{%- if branches == "nightly" %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
{%- else %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
{%- endif %}
{%- endif %}
timeout-minutes: !{{ common.timeout_minutes_windows_binary }}

View File

@ -224,7 +224,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_10-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -473,7 +473,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_11-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -722,7 +722,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_12-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -971,7 +971,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_13-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -1220,7 +1220,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_13t-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -1469,7 +1469,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_14-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -1718,7 +1718,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_14t-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -259,7 +259,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_10-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_10-cuda12_9-test: # Testing
@ -925,7 +925,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_11-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_11-cuda12_9-test: # Testing
@ -1591,7 +1591,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_12-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_12-cuda12_9-test: # Testing
@ -2257,7 +2257,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_13-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_13-cuda12_9-test: # Testing
@ -2923,7 +2923,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_13t-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_13t-cuda12_9-test: # Testing
@ -3589,7 +3589,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_14-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_14-cuda12_9-test: # Testing
@ -4255,7 +4255,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_14t-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_14t-cuda12_9-test: # Testing

View File

@ -63,7 +63,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.10.4"
freethreaded: false

View File

@ -59,7 +59,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.10.4"
freethreaded: false
@ -169,7 +168,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.11.4"
freethreaded: false
@ -279,7 +277,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.12.4"
freethreaded: false
@ -389,7 +386,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.13.4"
freethreaded: false
@ -499,7 +495,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.13.4"
freethreaded: true
@ -609,9 +604,8 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.14.0-rc.2"
python-version: "3.14.0"
freethreaded: false
- name: Checkout PyTorch
uses: actions/checkout@v4
@ -719,9 +713,8 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.14.0-rc.2"
python-version: "3.14.0"
freethreaded: true
- name: Checkout PyTorch
uses: actions/checkout@v4

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
wheel-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -279,7 +279,7 @@ jobs:
wheel-py3_10-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -517,7 +517,7 @@ jobs:
wheel-py3_10-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -755,7 +755,7 @@ jobs:
wheel-py3_10-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -993,7 +993,7 @@ jobs:
wheel-py3_10-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1229,7 +1229,7 @@ jobs:
wheel-py3_11-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1464,7 +1464,7 @@ jobs:
wheel-py3_11-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1702,7 +1702,7 @@ jobs:
wheel-py3_11-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1940,7 +1940,7 @@ jobs:
wheel-py3_11-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2178,7 +2178,7 @@ jobs:
wheel-py3_11-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2414,7 +2414,7 @@ jobs:
wheel-py3_12-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2649,7 +2649,7 @@ jobs:
wheel-py3_12-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2887,7 +2887,7 @@ jobs:
wheel-py3_12-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3125,7 +3125,7 @@ jobs:
wheel-py3_12-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3363,7 +3363,7 @@ jobs:
wheel-py3_12-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3599,7 +3599,7 @@ jobs:
wheel-py3_13-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3834,7 +3834,7 @@ jobs:
wheel-py3_13-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4072,7 +4072,7 @@ jobs:
wheel-py3_13-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4310,7 +4310,7 @@ jobs:
wheel-py3_13-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4548,7 +4548,7 @@ jobs:
wheel-py3_13-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4784,7 +4784,7 @@ jobs:
wheel-py3_13t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5019,7 +5019,7 @@ jobs:
wheel-py3_13t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5257,7 +5257,7 @@ jobs:
wheel-py3_13t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5495,7 +5495,7 @@ jobs:
wheel-py3_13t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5733,7 +5733,7 @@ jobs:
wheel-py3_13t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5969,7 +5969,7 @@ jobs:
wheel-py3_14-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6204,7 +6204,7 @@ jobs:
wheel-py3_14-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6442,7 +6442,7 @@ jobs:
wheel-py3_14-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6680,7 +6680,7 @@ jobs:
wheel-py3_14-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6918,7 +6918,7 @@ jobs:
wheel-py3_14-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7154,7 +7154,7 @@ jobs:
wheel-py3_14t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7389,7 +7389,7 @@ jobs:
wheel-py3_14t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7627,7 +7627,7 @@ jobs:
wheel-py3_14t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7865,7 +7865,7 @@ jobs:
wheel-py3_14t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -8103,7 +8103,7 @@ jobs:
wheel-py3_14t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -52,3 +52,27 @@ jobs:
docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
secrets: inherit
aarch64-opbenchmark-build:
if: github.repository_owner == 'pytorch'
name: aarch64-opbenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-aarch64-py3.10
runner: linux.arm64.m7g.4xlarge
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
]}
secrets: inherit
aarch64-opbenchmark-test:
name: aarch64-opbenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: aarch64-opbenchmark-build
with:
build-environment: linux-jammy-aarch64-py3.10
docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
secrets: inherit

63
.github/workflows/rocm-navi31.yml vendored Normal file
View File

@ -0,0 +1,63 @@
name: rocm-navi31
on:
push:
tags:
- ciflow/rocm-navi31/*
workflow_dispatch:
schedule:
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
# Also run less frequently on weekends.
- cron: 45 */2 * * 1-5
- cron: 45 4,12 * * 0,6
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
target-determination:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/target_determination.yml
permissions:
id-token: write
contents: read
linux-jammy-rocm-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3_10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: >-
${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
inductor/test_flex_attention inductor/test_max_autotune' || '' }}
secrets: inherit

View File

@ -59,29 +59,3 @@ jobs:
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-rocm-py3_10-gfx1100-test:
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3_10-gfx1100
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
]}
tests-to-include: >
test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
inductor/test_flex_attention inductor/test_max_autotune
secrets: inherit

View File

@ -190,6 +190,40 @@ jobs:
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
secrets: inherit
linux-jammy-rocm-py3_10-build:
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
secrets: inherit
inductor-build:
name: inductor-build
uses: ./.github/workflows/_linux-build.yml

1
.gitignore vendored
View File

@ -374,6 +374,7 @@ third_party/ruy/
third_party/glog/
# Virtualenv
.venv/
venv/
# Log files

View File

@ -201,3 +201,17 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99
# FlexAttention
/torch/nn/attention/flex_attention.py @drisspg
/torch/_higher_order_ops/flex_attention.py @drisspg
/torch/_inductor/kernel/flex/ @drisspg
/torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg
# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

View File

@ -313,13 +313,14 @@ IF(USE_FBGEMM_GENAI)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0")
list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
endif()
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(

View File

@ -39,7 +39,7 @@ struct HostBlock {
};
template <typename B>
struct alignas(64) FreeBlockList {
struct alignas(hardware_destructive_interference_size) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};
@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(64) HostStatsStaged {
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(64) std::mutex blocks_mutex_;
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(64) std::mutex events_mutex_;
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(64) HostStatsStaged stats_;
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
struct TORCH_API HostAllocator : public at::Allocator {

View File

@ -229,10 +229,10 @@ private:
}
static const uint32_t kPhilox10A = 0x9E3779B9;
static const uint32_t kPhilox10B = 0xBB67AE85;
static const uint32_t kPhiloxSA = 0xD2511F53;
static const uint32_t kPhiloxSB = 0xCD9E8D57;
static constexpr uint32_t kPhilox10A = 0x9E3779B9;
static constexpr uint32_t kPhilox10B = 0xBB67AE85;
static constexpr uint32_t kPhiloxSA = 0xD2511F53;
static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
};
typedef philox_engine Philox4_32;

View File

@ -16,6 +16,8 @@
#include <c10/util/irange.h>
#include <c10/core/ScalarType.h>
#include <ATen/cuda/detail/BLASConstants.h>
#ifdef USE_ROCM
#include <c10/cuda/CUDAStream.h>
#include <hipblaslt/hipblaslt-ext.hpp>
@ -1954,13 +1956,15 @@ void scaled_gemm(
const void *result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum) {
bool use_fast_accum,
const std::optional<Tensor>& alpha) {
// Note: see `cublasCommonArgs` for various non-intuitive manupulations
// of input arguments to this function.
const auto computeType = CUBLAS_COMPUTE_32F;
const auto scaleType = CUDA_R_32F;
const float alpha_val = 1.0;
const float beta_val = 0.0;
// Note: alpha_val may change later depending on user-passed argument
float alpha_val = 1.0;
float beta_val = 0.0;
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
@ -2031,6 +2035,33 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
// Handle user-passed alpha
float *alpha_ptr = &alpha_val;
float *beta_ptr = &beta_val;
if (alpha.has_value()) {
auto& a = alpha.value();
// if device-tensor
if (a.is_cuda()) {
// NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be
// valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus
// we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically
// managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory
// for beta respectively.
float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr();
at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat));
user_alpha.copy_(a);
// Tell cublasLt we're using device-side pointers for alpha/beta
auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode);
alpha_ptr = user_alpha.data_ptr<float>();
beta_ptr = at::cuda::detail::get_cublas_device_zero();
} else {
alpha_val = a.item<float>();
}
}
// For other data types, use the get_scale_mode function based on scaling type
// The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
// but we must invoke get_scale_mode anyways to trigger the version checks.
@ -2048,6 +2079,7 @@ void scaled_gemm(
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
@ -2088,10 +2120,10 @@ void scaled_gemm(
auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
alpha_ptr,
Adesc.descriptor(),
Bdesc.descriptor(),
&beta_val,
beta_ptr,
Cdesc.descriptor(),
Ddesc.descriptor(),
all_algos[i].algo,
@ -2110,17 +2142,14 @@ void scaled_gemm(
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
alpha_ptr,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
#ifdef USE_ROCM
beta_ptr,
// NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either
result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
#else
nullptr,
#endif // ifdef USE_ROCM
Cdesc.descriptor(),
result_ptr,
Ddesc.descriptor(),

View File

@ -161,7 +161,8 @@ void scaled_gemm(
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum);
bool use_fast_accum,
const std::optional<Tensor>& alpha);
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)

View File

@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
*/
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
* and size of the internal state.
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;
detail::check_rng_state(new_state);

View File

@ -183,11 +183,6 @@ struct CUDACachingHostAllocatorImpl
return true;
}
bool pinned_use_background_threads() override {
return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
pinned_use_background_threads();
}
EventPool::Event create_event_internal(DeviceIndex idx) {
// Leak the event pool to avoid shutdown issue.
static auto* event_pool = new EventPool();

View File

@ -177,7 +177,6 @@ inline void segmented_sort_pairs(
}
}
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
@ -193,7 +192,6 @@ inline void unique_by_key(
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif
namespace impl {
@ -579,7 +577,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
#endif
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
@ -607,7 +604,6 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT
#endif
}
#endif
template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,

View File

@ -28,22 +28,6 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif
// cub support for UniqueByKey is added to cub 1.16 in:
// https://github.com/NVIDIA/cub/pull/405
#if CUB_VERSION >= 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#endif
// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500

View File

@ -0,0 +1,54 @@
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/cuda/Exceptions.h>
#include <mutex>
namespace at {
namespace cuda {
namespace detail {
__device__ __constant__ float cublas_one_device;
__device__ __constant__ float cublas_zero_device;
float *get_cublas_device_one() {
static c10::once_flag init_flag;
c10::call_once(init_flag, []() {
const float one = 1.f;
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
});
float *ptr;
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
return ptr;
}
float *get_cublas_device_zero() {
static c10::once_flag init_flag;
c10::call_once(init_flag, []() {
const float zero = 0.f;
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
});
float *ptr;
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
return ptr;
}
float *get_user_alpha_ptr() {
static float *alpha_ptr;
static c10::once_flag init_flag;
c10::call_once(init_flag, []() {
AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
});
return alpha_ptr;
}
} // namespace detail
} // namespace cuda
} // namespace at

View File

@ -0,0 +1,11 @@
#pragma once
#include <ATen/core/TensorBase.h>
namespace at::cuda::detail {
float *get_cublas_device_one();
float *get_cublas_device_zero();
float *get_user_alpha_ptr();
} // namespace at::cuda::detail

View File

@ -109,7 +109,8 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->use_fast_accum);
params->use_fast_accum,
std::nullopt /* alpha */);
return OK;
}
};

View File

@ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
DispatchKey::CUDA,
DispatchKey::CPU,
DispatchKey::PrivateUse1,
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
});
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {

View File

@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (
namespace at::native {
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;
static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);

View File

@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return n <= intmax && incx <= intmax;
}
@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
int64_t incx,
[[maybe_unused]] float beta,
int64_t incy) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
(incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

View File

@ -658,6 +658,7 @@ static void check_shape_forward(const at::Tensor& input,
TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups);
TORCH_CHECK(weight_dim == k,
"Expected ", weight_dim, "-dimensional input for ", weight_dim,

View File

@ -1,5 +1,6 @@
#pragma once
#include <array>
#include <ATen/native/Math.h>
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>
@ -127,7 +128,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
const static scalar_t kTailValues[] = {
constexpr static scalar_t kTailValues[] = {
0.0810614667953272,
0.0413406959554092,
0.0276779256849983,
@ -139,7 +140,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
0.00925546218271273,
0.00833056343336287
};
if (k <= 9) {
if (k < std::size(kTailValues)) {
return kTailValues[static_cast<size_t>(k)];
}
scalar_t kp1sq = (k + 1) * (k + 1);

View File

@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
try {
mkldnn_matmul_i8i8i32(self, mat2, result);
dispatched = true;
} catch (const std::exception& e) {
} catch ([[maybe_unused]] const std::exception& e) {
TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
}
}

View File

@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
template <typename scalar_t>
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
static const scalar_t lanczos_sum_expg_scaled_num[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
template <typename scalar_t>
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
static const scalar_t d[25][25] =
static constexpr scalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,

View File

@ -62,7 +62,7 @@
#include <utility>
#include <vector>
static const int MIOPEN_DIM_MAX = 5;
static constexpr int MIOPEN_DIM_MAX = 5;
namespace at::meta {

View File

@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
// We keep this structure for BC and consider as deprecated.
// See HelperInterpNearestExact as replacement
static const int interp_size = 1;
static constexpr int interp_size = 1;
static inline void init_indices_weights(
at::ScalarType output_type,
@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {
struct HelperInterpLinear : public HelperInterpBase {
static const int interp_size = 2;
static constexpr int interp_size = 2;
// Compute indices and weights for each interpolated dimension
// indices_weights = {
@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {
struct HelperInterpCubic : public HelperInterpBase {
static const int interp_size = 4;
static constexpr int interp_size = 4;
// Compute indices and weights for each interpolated dimension
// indices_weights = {

View File

@ -1359,7 +1359,8 @@ _scaled_gemm(
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
Tensor& out,
const std::optional<Tensor>& alpha = std::nullopt) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@ -1410,7 +1411,8 @@ _scaled_gemm(
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
use_fast_accum,
alpha);
return out;
}
}
@ -2320,12 +2322,23 @@ _scaled_nvfp4_nvfp4(
const Tensor& scale_b, const SwizzleType swizzle_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
const bool single_scale,
Tensor& out) {
Tensor& out,
const std::optional<Tensor>& global_scale_a = std::nullopt,
const std::optional<Tensor>& global_scale_b = std::nullopt) {
#ifdef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM");
#endif
TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported");
std::optional<Tensor> alpha = std::nullopt;
// Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise,
// if this is "And" we would silently do nothing in the case where one global scale is
// passed and not the other.
if (global_scale_a.has_value() || global_scale_b.has_value()) {
TORCH_CHECK_VALUE(global_scale_a.has_value(),
"For two-level-scaled NVFP4, global_scale_a must have a value");
TORCH_CHECK_VALUE(global_scale_b.has_value(),
"For two-level-scaled NVFP4, global_scale_b must have a value");
alpha = global_scale_a.value().mul(global_scale_b.value());
}
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
// Scales must be swizzled
@ -2347,7 +2360,7 @@ _scaled_nvfp4_nvfp4(
auto scaling_choice_a = ScalingType::BlockWise1x16;
auto scaling_choice_b = ScalingType::BlockWise1x16;
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha);
}
@ -2553,9 +2566,10 @@ _scaled_mm_cuda_v2_out(
} else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) {
return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
} else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out,
scale_a[1], scale_b[1]);
} else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
} else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
} else {

View File

@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
}
static const int BLOCK_THREADS = 256;
static constexpr int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)

View File

@ -15,9 +15,7 @@
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/thread_constants.h>
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -36,9 +34,9 @@ namespace at::native {
namespace {
#if defined(USE_ROCM)
static const int BLOCKDIMY = 16;
static constexpr int BLOCKDIMY = 16;
#else
static const int BLOCKDIMY = 32;
static constexpr int BLOCKDIMY = 32;
#endif
template
@ -240,10 +238,6 @@ __global__ void renorm_kernel(
} // anonymous namespace
#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif
Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
int64_t num_weights, int64_t padding_idx,
@ -306,7 +300,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -333,11 +326,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
num_indices
);
});
#else
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
});
#endif
}
return embedding_backward_cuda_kernel(grad, orig_indices,

View File

@ -10,9 +10,7 @@
#include <c10/macros/Macros.h>
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
#include <thrust/iterator/counting_iterator.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -196,18 +194,9 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm
partials_per_segment_offset[num_of_segments-1];
}
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
*num_of_segments_ptr = num_of_segments;
}
#endif
} // anon namespace
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
#endif
Tensor embedding_backward_cuda_kernel(
const Tensor &grad,
@ -234,20 +223,12 @@ Tensor embedding_backward_cuda_kernel(
auto segment_offsets = at::empty({numel}, orig_indices.options());
auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>();
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
#else
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
cuda::cub::unique_by_key(
sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0),
segment_offsets.mutable_data_ptr<index_t>(),
num_of_segments_ptr, sorted_indices.numel());
});
#endif
int64_t max_segments = std::min<int64_t>(numel, num_weights);

View File

@ -31,16 +31,10 @@
#include <c10/macros/Macros.h>
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif
namespace at::native {
#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif
namespace {
@ -199,7 +193,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -226,11 +219,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
num_indices
);
});
#else
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
});
#endif
}
return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,

View File

@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t ax, fac, res, num, numfac;
static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
7.09782712893383996843E2 : 88.72283905206835;
static const accscalar_t EXP1 = 2.718281828459045;
static const accscalar_t lanczos_g = 6.024680040776729583740234375;
constexpr accscalar_t EXP1 = 2.718281828459045;
constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
if (::fabs(a - x) > 0.4 * ::fabs(a)) {
ax = a * ::log(x) - x - ::lgamma(a);
@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const int MAXITER = 2000;
constexpr int MAXITER = 2000;
int i;
accscalar_t ans, ax, c, r;
@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
accscalar_t fac = 1;
accscalar_t sum = 0;
accscalar_t term, logx;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) {
@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t d[25][25] =
constexpr accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
int k, n, sgn;
int maxpow = 0;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
accscalar_t lambda = x / a;
accscalar_t sigma = (x - a) / a;
@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
int i;
accscalar_t ans, ax, c, yc, r, t, y, z;
accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
4.503599627370496e15 : 16777216.;
static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x);
@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
// boundary values following SciPy
if ((x < 0) || (a < 0)) {

View File

@ -1,90 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/cuda/cub_definitions.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif
#include <ATen/cuda/ThrustAllocator.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/constant_iterator.h>
namespace at::native {
#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
auto policy = thrust::cuda::par(allocator).on(stream);
auto num_indices = count.numel();
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,
sorted_data + num_indices,
thrust::make_constant_iterator(1),
count_data
);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy,
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<index_t>(),
thrust::maximum<index_t>()
);
}
template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);
#endif
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
auto stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
auto policy = thrust::cuda::par(allocator).on(stream);
const ptrdiff_t numel = sorted_indices.numel();
auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
auto ends = thrust::unique_by_key_copy(
policy,
sorted_indices_dev,
sorted_indices_dev + numel,
thrust::make_counting_iterator(0),
dummy_dev,
thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
return thrust::get<0>(ends) - dummy_dev;
}
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);
} // namespace at::native

View File

@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
const auto digamma_string = jiterator_stringify(
template <typename T>
T digamma(T x) {
static const double PI_f64 = 3.14159265358979323846;
static constexpr double PI_f64 = 3.14159265358979323846;
// Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
if (x == 0) {
@ -3072,9 +3072,9 @@ template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const double PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
static constexpr double PI_f64 = 3.14159265358979323846;
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
constexpr accscalar_t A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,

View File

@ -146,6 +146,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
int64_t batch_size = target.size(0);
int64_t H = target.size(1);
int64_t W = target.size(2);
int64_t n_classes = grad_input.size(1);
CUDA_KERNEL_LOOP(index, n_threads) {
const int64_t b = index % batch_size;
@ -156,6 +157,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
if (cur_target == ignore_index) {
continue;
}
CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
}

View File

@ -413,14 +413,12 @@ struct ReduceOp {
value = thread_reduce<output_vec_size>(input_slice);
}
if (config.should_block_y_reduce()) {
value = block_y_reduce<output_vec_size>(value, shared_memory);
}
__syncthreads();
if (config.should_block_x_reduce()) {
value = block_x_reduce<output_vec_size>(value, shared_memory);
}
if (config.should_block_y_reduce()) {
value = block_y_reduce<output_vec_size>(value, shared_memory);
}
using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>;
using offset_vec_t = std::array<index_t, output_vec_size>;
offset_vec_t base_offsets;
@ -657,8 +655,8 @@ struct ReduceOp {
__syncthreads();
// Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
// matching Triton, etc.
// todo for AMD
#ifdef USE_ROCM
// TODO(PaulZhang12): AMD and internal
#if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
@ -1097,11 +1095,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
// threads with different threadIdx.x are independent and will produce results for different outputs.
// In such case, values in each loaded vector always correspond to different outputs.
if (fastest_moving_stride == sizeof(scalar_t)) {
#ifdef USE_ROCM
if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
#else
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
#endif
// Case 1: "vectorize along input"
// Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case,
// we should avoid vectorization.

View File

@ -39,9 +39,14 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) {
// returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
using factor_t = typename c10::scalar_value_type<acc_t>::type;
factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
} else {
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
}
}
static void mean_kernel_cuda(TensorIterator& iter) {

View File

@ -13,24 +13,19 @@ namespace at::native {
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct sum_functor {
void operator()(TensorIterator& iter) {
#ifdef USE_ROCM
// Half and BFloat16 can be packed in groups of up to 8 elements and
// can use *_DWORDX4 instructions to achieve that.
const bool is_16_bits =
( (std::is_same<at::Half, scalar_t>::value) ||
(std::is_same<at::BFloat16, scalar_t>::value) );
if (is_16_bits) {
const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
};
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
return;
iter, func_wrapper<out_t>(sum_combine)
);
} else {
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>(sum_combine)
);
}
#endif
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
}
};

View File

@ -19,7 +19,6 @@
namespace at::native {
// TODO: remove this when CUDA <11.6 is no longer supported
void topk_out_with_sort(
const Tensor& self,
int64_t k, int64_t dim, bool largest,
@ -31,21 +30,12 @@ void topk_out_with_sort(
indices.copy_(sorted_indices.narrow(dim, 0, k));
}
// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk();
bool should_use_sort(const Tensor& self, int64_t dim) {
#if defined(USE_ROCM)
if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972
return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387
#else
if (disable_sort_for_topk()) return false;
// This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
if (self.dim() == 0) return false;
if (self.dtype() == kBool) return false; // Bool is not support by topk
int64_t slice_size = self.size(dim);
if (slice_size == 0) return false;
int64_t num_slices = self.numel() / slice_size;
return num_slices <= 10 && slice_size >= 100000;
return false;
#endif
}

View File

@ -21,11 +21,6 @@ using namespace at::native;
namespace at::native {
// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk() {
return CUB_SUPPORTS_SCAN_BY_KEY();
}
namespace sbtopk { // single_block_topk
template <typename T>
@ -418,10 +413,6 @@ __global__ void computeBlockwiseWithinKCounts(
}
__syncthreads();
#if !CUB_SUPPORTS_SCAN_BY_KEY()
return;
#endif
Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS);
// if largest, then only threads that has tidx > desired_digit are active
@ -477,7 +468,6 @@ __global__ void computeBlockwiseWithinKCounts(
}
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
// Assumption: slice_size can not be larger than UINT32_MAX
template <typename Bitwise>
__global__ void computeBlockwiseKthCounts(
@ -609,7 +599,6 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> inpu
}
}
}
#endif
int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) {
// occupancy of this kernel is limited by registers per threads
@ -687,16 +676,12 @@ void launch(
uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get());
AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream));
#if CUB_SUPPORTS_SCAN_BY_KEY()
auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get());
AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream));
auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get());
#else
uint32_t* withinKCounts = nullptr;
#endif
Bitwise desiredMask = 0;
dim3 grid;
@ -743,7 +728,6 @@ void launch(
}
desired = desired_in;
#if CUB_SUPPORTS_SCAN_BY_KEY()
computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>(
desired, counts, num_blocks, blocks_per_slice, kthCounts);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@ -759,28 +743,6 @@ void launch(
topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread,
blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
// Find topk values based on kth values
{
dim3 grid;
TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
int warp_size = at::cuda::warp_size();
dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>(
input,
inputSliceSize,
outputSliceSize,
largest,
numInputSlices,
inputWithinSliceStride,
topK,
topKWithinSliceStride,
indices,
indicesWithinSliceStride,
kthValues);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
}
} // namespace mbtopk
@ -788,7 +750,6 @@ void launch(
bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
if (num_slices > std::numeric_limits<uint32_t>::max() ||
slice_size > std::numeric_limits<uint32_t>::max()) return false;
#if CUB_SUPPORTS_SCAN_BY_KEY()
// This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267
return (num_slices <= 20 && slice_size >= 20000) ||
(num_slices > 20 && num_slices <= 40 && slice_size >= 10000) ||
@ -797,12 +758,6 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
(num_slices >= 200 && num_slices < 800 && slice_size >= 3000) ||
(num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) ||
(num_slices > 4000 && slice_size >= 400);
#else
// This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081
return (num_slices <= 400 && slice_size >= 5000) ||
(num_slices > 400 && num_slices < 4000 && slice_size >= 1000) ||
(num_slices >= 4000 && slice_size >= 300);
#endif
}
void launch_gather_topk_kernel(

View File

@ -44,7 +44,7 @@ __global__ void triu_tril_kernel(
const int64_t k,
const int64_t N_padded,
const IndexType last_dim_padded) {
int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread;
if (linear_idx >= N_padded) {
return;
}

View File

@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
return 0;
}
static const int size = 2;
static constexpr int size = 2;
};
// taken from
@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
return 0;
}
static const int size = 4;
static constexpr int size = 4;
};
template <typename accscalar_t>

View File

@ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
}
}
#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
int input_pos,
accscalar_t scale,
int output_size,
bool align_corners,
int& min_output,
int& max_output) {
accscalar_t lo, hi;
if (align_corners) {
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
} else {
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
}
min_output = max(0, static_cast<int>(ceil(lo)));
max_output = min(output_size - 1, static_cast<int>(floor(hi)));
}
#endif
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame(
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
index += blockDim.x * gridDim.x) {
// Decode input pixel coordinates
size_t index_temp = index;
const int w1 = index_temp % width1;
index_temp /= width1;
const int h1 = index_temp % height1;
const size_t nc_idx = index_temp / height1;
accscalar_t grad_sum = 0;
// Find range of output pixels that could interpolate from this input pixel
int h2_min, h2_max, w2_min, w2_max;
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
// Iterate over potential output pixels
for (int h2 = h2_min; h2 <= h2_max; h2++) {
for (int w2 = w2_min; w2 <= w2_max; w2++) {
// Compute source coordinates for this output pixel
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
rheight, h2, align_corners, /*cubic=*/false);
const int h1_base = (int)h1r;
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
const accscalar_t h1lambda = h1r - h1_base;
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
rwidth, w2, align_corners, /*cubic=*/false);
const int w1_base = (int)w1r;
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
const accscalar_t w1lambda = w1r - w1_base;
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
// Check if our input pixel participates in this interpolation and accumulate all weights
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
// to the same pixel, so we need to accumulate weights from all matching positions
accscalar_t weight = 0;
// Check all four interpolation positions and accumulate weights
if (h1 == h1_base && w1 == w1_base) {
weight += h0lambda * w0lambda; // top-left
}
if (h1 == h1_base && w1 == w1_base + w1p) {
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base) {
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
}
if (weight > 0) {
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
}
}
}
// Write accumulated gradient (no atomics needed)
idata[index] = static_cast<scalar_t>(grad_sum);
}
#else
const size_t o_numel = nc * width2 * height2;
const size_t i_numel = nc * width1 * height1;
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
index += blockDim.x * gridDim.x) {
size_t index_temp = index;
@ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame(
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
true);
}
#endif
}
template <typename scalar_t, typename accscalar_t>
@ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template(
// threads are not covering the whole input tensor.
grad_input.zero_();
const size_t num_kernels = nbatch * channels * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
return;
}
#ifdef USE_ROCM
constexpr bool use_input = true;
#else
constexpr bool use_input = false;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * output_height * output_width;
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
input_height,
@ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
num_threads,

View File

@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum(
if constexpr (!rms_norm){
U delta = val - curr_sum.mean;
U new_count = curr_sum.count + 1.f;
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
#else
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
#endif
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
} else{
return {0.f, curr_sum.sigma2 + val * val, 0};
@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine(
U count = dataA.count + dataB.count;
U mean, sigma2;
if (count > decltype(dataB.count){0}) {
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
auto coef = __builtin_amdgcn_rcpf(count);
#else
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
#endif
auto nA = dataA.count * coef;
auto nB = dataB.count * coef;
mean = nA*dataA.mean + nB*dataB.mean;

View File

@ -466,7 +466,7 @@ struct ReduceJitOp {
__syncthreads();
#ifdef USE_ROCM
#if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {

View File

@ -487,9 +487,7 @@ std::unique_ptr<fe::graph::Graph> build_graph(
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("CUDNN_SDPA")
.set_is_inference(return_softmaxstats == false)
// TODO(eqy): switch to this API once cuDNN FE is upgraded
// .set_generate_stats(return_softmaxstats)
.set_generate_stats(return_softmaxstats)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
@ -707,9 +705,7 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("CUDNN_SDPA_NESTEDTENSOR")
.set_is_inference(return_softmaxstats == false)
// TODO(eqy): switch to this API once cuDNN FE is upgraded
// .set_generate_stats(return_softmaxstats)
.set_generate_stats(return_softmaxstats)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
.set_seq_len_q(SEQ_LEN_Q_)

View File

@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
// else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
// only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
if (mat1.dim() == 1 && mat2.dim() == 1) {
// aten::dot
return mat1.size(0) > mkldnn_gemm_min_size;

View File

@ -441,7 +441,7 @@ kernel void applySYRK(
uint3 tid [[thread_position_in_threadgroup]],
uint3 tgid [[threadgroup_position_in_grid]],
uint3 tpg [[threads_per_threadgroup]],
uint sgitg [[simdgroup_index_in_threadgroup]]) {
uint warp_id [[simdgroup_index_in_threadgroup]]) {
const uint tx = tid.x;
const uint ty = tid.y;
const uint simdGroupsPerThreadgroup = (tpg.x * tpg.y + 31) / 32;
@ -474,11 +474,8 @@ kernel void applySYRK(
(actSize_j % 8 == 0) && (actSize_h % 8 == 0) && (actSize_k % 8 == 0);
if (use_simdgroup) {
uint warp_id = sgitg;
simdgroup_matrix<float, 8, 8> negative_identity =
simdgroup_matrix<float, 8, 8>(-1.0);
simdgroup_matrix<float, 8, 8> identity = simdgroup_matrix<float, 8, 8>(1.0);
simdgroup_matrix<float, 8, 8> Prod;
simdgroup_matrix<float, 8, 8> Afrag;
simdgroup_matrix<float, 8, 8> Bfrag;
@ -521,8 +518,7 @@ kernel void applySYRK(
/* transpose = */ upper);
simdgroup_multiply(Prod, Afrag, Bfrag);
simdgroup_multiply(Prod, Prod, negative_identity);
simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod);
simdgroup_multiply_accumulate(Cfrag, Prod, negative_identity, Cfrag);
}
simdgroup_store(

View File

@ -16,7 +16,6 @@ kernel void cat(
auto ndim = shared_params.ndim;
auto cat_dim = shared_params.cat_dim;
constant auto& output_strides = shared_params.output_strides;
constant auto& output_sizes = shared_params.output_sizes;
auto cat_dim_offset = input_params.cat_dim_offset;
auto input_element_offset = input_params.input_element_offset;

View File

@ -196,6 +196,28 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)
other.size(0) > max_stride_size || other.size(1) > max_stride_size);
}
void map_mps_decomposition_error_code_to_blas(const Tensor& status) {
const auto& status_flat = status.view(-1);
for (const auto i : c10::irange(status_flat.size(0))) {
int code = status_flat[i].item<int>();
switch (code) {
case MPSMatrixDecompositionStatusSuccess:
status_flat[i] = 0;
break;
case MPSMatrixDecompositionStatusNonPositiveDefinite:
case MPSMatrixDecompositionStatusSingular:
status_flat[i] = 2;
break;
case MPSMatrixDecompositionStatusFailure:
status_flat[i] = -1;
break;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown MPSMatrixDecompositionStatus enum value: ", code);
}
}
}
} // anonymous namespace
static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
@ -487,6 +509,9 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
"mpsmatrixdecompositionstatus for details.");
}
}
map_mps_decomposition_error_code_to_blas(info);
if (!left) {
// If this was a right solve, transpose the result back
result.copy_(result_t.transpose(-2, -1).contiguous());

View File

@ -706,6 +706,7 @@
variants: function, method
dispatch:
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_all
tags: reduction
- func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
@ -715,6 +716,7 @@
cpp_no_default_args: ['dim']
dispatch:
CompositeExplicitAutograd: all_dims_default
tags: reduction
- func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -723,6 +725,7 @@
CPU, CUDA: all_out
MPS: all_out_mps
MTIA: all_out_mtia
tags: reduction
- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -731,13 +734,16 @@
CPU, CUDA: all_dims_out
CompositeExplicitAutograd: all_dims_out_default
cpp_no_default_args: ['dim']
tags: reduction
- func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool
variants: function, method
@ -749,14 +755,14 @@
device_check: NoCheck # TensorIterator
structured_delegate: any.out
variants: function, method
tags: core
tags: [core, reduction]
- func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: any.dims_out
variants: function, method
cpp_no_default_args: ['dim']
tags: core
tags: [core, reduction]
dispatch:
CompositeExplicitAutograd: any_dims_default
@ -766,6 +772,7 @@
dispatch:
CPU, CUDA: any_out
MPS: any_out_mps
tags: reduction
- func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -774,13 +781,16 @@
CPU, CUDA: any_dims_out
CompositeExplicitAutograd: any_dims_out_default
cpp_no_default_args: ['dim']
tags: reduction
- func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch:
@ -826,25 +836,27 @@
structured_delegate: argmax.out
device_check: NoCheck # TensorIterator
variants: function, method
tags: core
tags: [core, reduction]
- func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: argmax_out
MPS: argmax_out_mps
tags: reduction
- func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
structured_delegate: argmin.out
device_check: NoCheck # TensorIterator
variants: function, method
tags: core
tags: [core, reduction]
- func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: argmin_out
MPS: argmin_out_mps
tags: reduction
- func: acosh(Tensor self) -> Tensor
variants: function, method
@ -1370,6 +1382,7 @@
dispatch:
SparseCPU: bmm_sparse_cpu
SparseCUDA: bmm_sparse_cuda
SparseMPS: bmm_sparse_mps
NestedTensorCPU: bmm_nested
NestedTensorCUDA: bmm_nested_cuda
tags: core
@ -1385,6 +1398,7 @@
MTIA: bmm_out_mtia
SparseCPU: bmm_out_sparse_cpu
SparseCUDA: bmm_out_sparse_cuda
SparseMPS: bmm_out_sparse_mps
SparseCsrCUDA: bmm_out_sparse_csr_cuda
- func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
@ -1867,12 +1881,14 @@
CUDA: count_nonzero_cuda
MPS: count_nonzero_mps
autogen: count_nonzero.dim_IntList_out
tags: reduction
- func: count_nonzero(Tensor self, int? dim=None) -> Tensor
variants: function, method
dispatch:
CompositeExplicitAutograd: count_nonzero
autogen: count_nonzero.out
tags: reduction
- func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor
variants: function, method
@ -3793,19 +3809,23 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: logsumexp
tags: reduction
- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
# calls squeeze
CompositeExplicitAutogradNonFunctional: logsumexp_out
tags: reduction
- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
@ -3855,6 +3875,7 @@
device_check: NoCheck # TensorIterator
structured_delegate: aminmax.out
variants: function, method
tags: reduction
- func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)
device_check: NoCheck # TensorIterator
@ -3862,6 +3883,7 @@
dispatch:
CPU, CUDA, MTIA: aminmax_out
MPS: aminmax_out_mps
tags: reduction
- func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor
dispatch:
@ -3877,7 +3899,7 @@
variants: function, method
dispatch:
QuantizedCPU, QuantizedCUDA: qmax
tags: core
tags: [core, reduction]
- func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
@ -3887,13 +3909,16 @@
dispatch:
CPU, CUDA, MTIA: max_out
MPS: max_out_mps
tags: reduction
- func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
tags: reduction
- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor
variants: function
@ -3906,13 +3931,14 @@
- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
variants: function, method
structured_delegate: amax.out
tags: core
tags: [core, reduction]
- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA, MTIA: amax_out
MPS: amax_out_mps
tags: reduction
# Return: (Tensor output, Tensor indices)
- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
@ -3974,13 +4000,14 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: mean
tags: core
tags: [core, reduction]
# For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this.
- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CompositeExplicitAutograd: mean_dtype_out
tags: reduction
- func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: mean.out
@ -3988,7 +4015,7 @@
variants: function, method
dispatch:
QuantizedCPU: mean_quantized_cpu
tags: core
tags: [core, reduction]
- func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -3997,13 +4024,16 @@
CPU, CUDA: mean_out
MPS: mean_out_mps
QuantizedCPU: mean_out_quantized_cpu
tags: reduction
- func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # Composite
@ -4066,7 +4096,7 @@
variants: function, method
dispatch:
QuantizedCPU, QuantizedCUDA: qmin
tags: core
tags: [core, reduction]
- func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
@ -4076,24 +4106,28 @@
dispatch:
CPU, CUDA, MTIA: min_out
MPS: min_out_mps
tags: reduction
- func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
tags: reduction
- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
variants: function, method
structured_delegate: amin.out
tags: core
tags: [core, reduction]
- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA, MTIA: amin_out
MPS: amin_out_mps
tags: reduction
# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
# native_functions.yaml
@ -4173,7 +4207,7 @@
structured_delegate: mm.out
variants: function, method
dispatch:
SparseCPU, SparseCUDA: _sparse_mm
SparseCPU, SparseCUDA, SparseMPS: _sparse_mm
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm
tags: core
@ -5858,6 +5892,7 @@
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sum_coo
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr
autogen: sum.out
tags: reduction
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
# TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype
@ -5868,11 +5903,12 @@
NestedTensorCPU: NestedTensor_sum_dim_CPU
SparseCPU, SparseCUDA, SparseMPS: sum_sparse_coo
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed
tags: core
tags: [core, reduction]
- func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -5880,9 +5916,11 @@
dispatch:
CPU, CUDA: sum_out
MPS: sum_out_mps
tags: reduction
- func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
# TODO: this function will be replaced once nested expand semantics have been settled on
- func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor
@ -5894,11 +5932,13 @@
dispatch:
CPU, CUDA: nansum
MPS: nansum_mps
tags: reduction
- func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: nansum_out
MPS: nansum_out_mps
tags: reduction
- func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor
variants: function, method
@ -5962,11 +6002,13 @@
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
@ -5975,16 +6017,19 @@
CPU, CUDA: std
MPS: std_mps
QuantizedCPU: std_quantized_cpu
tags: reduction
- func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
@ -5993,42 +6038,51 @@
CPU, CUDA: std_mean
MPS: std_mean_mps
autogen: std_mean.correction_out
tags: reduction
- func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: std_out
QuantizedCPU: std_out_quantized_cpu
tags: reduction
- func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
@ -6037,13 +6091,13 @@
CPU, CUDA: prod
MPS: prod_mps
autogen: prod.out
tags: core
tags: [core, reduction]
- func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: prod.int_out
device_check: NoCheck # TensorIterator
variants: function, method
tags: core
tags: [core, reduction]
- func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6051,13 +6105,16 @@
dispatch:
CPU, CUDA: prod_out
MPS: prod_out_mps
tags: reduction
- func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: t(Tensor(a) self) -> Tensor(a)
device_check: NoCheck
@ -6518,11 +6575,12 @@
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: core
tags: [core, reduction]
cpp_no_default_args: ["unbiased"]
- func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
@ -6532,43 +6590,51 @@
CPU, CUDA: var
MPS: var_mps
MTIA: var_mtia
tags: core
tags: [core, reduction]
- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: var_out
tags: reduction
- func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
@ -6577,15 +6643,18 @@
CPU, CUDA: var_mean
MPS: var_mean_mps
autogen: var_mean.correction_out
tags: reduction
- func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: view_as(Tensor(a) self, Tensor other) -> Tensor(a)
variants: method
@ -6845,6 +6914,7 @@
dispatch:
CompositeExplicitAutograd: norm
autogen: norm.ScalarOpt_dtype_out
tags: reduction
- func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor
device_check: NoCheck # TensorIterator
@ -6852,6 +6922,7 @@
dispatch:
CompositeExplicitAutograd: norm
autogen: norm.Scalar_out
tags: reduction
- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
structured_delegate: norm.dtype_out
@ -6859,6 +6930,7 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_dtype_norm
tags: reduction
- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
structured_delegate: norm.out
@ -6866,6 +6938,7 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_norm
tags: reduction
- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6873,6 +6946,7 @@
dispatch:
CPU, CUDA: norm_dtype_out
MPS: norm_dtype_out_mps
tags: reduction
- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6880,21 +6954,26 @@
dispatch:
CPU, CUDA: norm_out
MPS: norm_out_mps
tags: reduction
# These four redispatch in their implementation, so OK to be CompositeImplicitAutograd
- func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
variants: method, function
@ -7112,6 +7191,7 @@
MTIA: addmm_out_mtia
SparseCPU: addmm_out_sparse_dense_cpu
SparseCUDA: addmm_out_sparse_dense_cuda
SparseMPS: addmm_out_sparse_dense_mps
SparseCsrCPU: addmm_out_sparse_compressed_cpu
SparseCsrCUDA: addmm_out_sparse_compressed_cuda
@ -7121,6 +7201,7 @@
dispatch:
SparseCPU: addmm_sparse_dense_cpu
SparseCUDA: addmm_sparse_dense_cuda
SparseMPS: addmm_sparse_dense_mps
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense
tags: core
@ -10078,12 +10159,14 @@
CPU, CUDA: min
MPS: min_mps
QuantizedCPU: min_quantized_cpu
tags: [reduction]
- func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: min_unary_out
QuantizedCPU: min_quantized_unary_out
tags: [reduction]
- func: fmin(Tensor self, Tensor other) -> Tensor
structured_delegate: fmin.out
@ -10106,6 +10189,7 @@
CPU, CUDA: max
MPS: max_mps
QuantizedCPU: max_quantized_cpu
tags: [reduction]
- func: fmax(Tensor self, Tensor other) -> Tensor
structured_delegate: fmax.out
@ -10152,6 +10236,7 @@
dispatch:
CPU, CUDA: max_unary_out
QuantizedCPU: max_quantized_unary_out
tags: [reduction]
- func: minimum(Tensor self, Tensor other) -> Tensor
structured_delegate: minimum.out
@ -10271,6 +10356,7 @@
device_check: NoCheck # TensorIterator
structured_delegate: all.all_out
variants: method, function
tags: reduction
- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
@ -10279,6 +10365,7 @@
CPU, CUDA: all_all_out
MTIA: all_all_out_mtia
MPS: all_all_out_mps
tags: reduction
- func: any(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@ -10286,7 +10373,7 @@
variants: method, function
dispatch:
SparseCPU, SparseCUDA, SparseMPS: any_sparse
tags: core
tags: [core, reduction]
- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
@ -10294,6 +10381,7 @@
dispatch:
CPU, CUDA: any_all_out
MPS: any_all_out_mps
tags: reduction
- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -14345,6 +14433,7 @@
python_module: linalg
variants: function
structured_delegate: linalg_vector_norm.out
tags: reduction
- func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
@ -14352,6 +14441,7 @@
dispatch:
CPU, CUDA: linalg_vector_norm_out
MPS: linalg_vector_norm_out_mps
tags: reduction
- func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
python_module: linalg

View File

@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
#if defined(__ARM_NEON__) || defined(__aarch64__)
const static int PARALLEL_THRESHOLD = 1 << 20;
constexpr static int PARALLEL_THRESHOLD = 1 << 20;
// Generic template defaults to naive quantize implementation
template <typename T>

View File

@ -1388,7 +1388,7 @@ namespace at::native {
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
static std::optional<at::Tensor> other = std::nullopt;
static const std::string_view binary_post_op = "none";
constexpr std::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp,

View File

@ -16,8 +16,8 @@ namespace {
#ifdef USE_PYTORCH_QNNPACK
const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
const static int qnnpack_softmax_output_zero_point = 0;
constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
constexpr static int qnnpack_softmax_output_zero_point = 0;
bool is_qnnpack_compatible(
const Tensor& qx,

View File

@ -1,5 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
@ -18,6 +19,8 @@
#include <ATen/ops/ones_like.h>
#include <ATen/ops/argsort.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/copy_sparse_to_sparse.h>
#include <ATen/ops/mul.h>
#endif
@ -33,6 +36,305 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/Mul_metallib.h>
#endif
static Tensor& s_addmm_out_sparse_dense_mps(
Tensor& r,
const Tensor& t,
const SparseTensor& sparse_,
const Tensor& dense,
const Scalar& beta,
const Scalar& alpha) {
TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: sparse_dim must be 2, got ", sparse_.sparse_dim());
TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: sparse values must be 0-dense-dim, got ", sparse_.dense_dim());
TORCH_CHECK(dense.dim() == 2, "addmm: 'dense' must be 2D, got ", dense.dim());
TORCH_CHECK(t.dim() == 2, "addmm: 't' must be 2D, got ", t.dim());
const int64_t I = sparse_.size(0);
const int64_t J = sparse_.size(1);
const int64_t K = dense.size(1);
TORCH_CHECK(dense.size(0) == J,
"addmm: dense (mat2) dim0 must be ", J, ", got ", dense.size(0));
TORCH_CHECK(t.size(0) == I && t.size(1) == K,
"addmm: 't' shape must be (", I, ", ", K, "), got (", t.size(0), ", ", t.size(1), ")");
r.resize_({I, K});
auto sparse = sparse_.coalesce();
const int64_t nnz = sparse._nnz();
if (nnz == 0 || I == 0 || K == 0) {
at::mul_out(r, t, beta);
return r;
}
const auto v_dtype = sparse._values().scalar_type();
const auto d_dtype = dense.scalar_type();
const auto t_dtype = t.scalar_type();
auto compute_dtype = c10::promoteTypes(c10::promoteTypes(v_dtype, d_dtype), t_dtype);
TORCH_CHECK(canCast(compute_dtype, r.scalar_type()),
"Can't convert computed type ", compute_dtype, " to output ", r.scalar_type());
auto indices2d = sparse._indices().contiguous();
auto values = sparse._values().to(compute_dtype);
auto dense_c = dense.to(compute_dtype).contiguous();
auto t_c = t.to(compute_dtype).contiguous();
const bool out_needs_cast = (r.scalar_type() != compute_dtype) || !r.is_contiguous();
Tensor out_buf = out_needs_cast
? at::empty({I, K}, r.options().dtype(compute_dtype))
: r;
auto out_contig = out_buf.contiguous();
auto device = r.device();
auto stream = getCurrentMPSStream();
const float alpha_f = alpha.to<float>();
const float beta_f = beta.to<float>();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
const std::string func = "spmm_addmm_coo_" + mps::scalarToMetalTypeString(values);
auto pso = lib.getPipelineStateForFunc(func);
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t gridX = static_cast<uint32_t>(K);
const uint32_t gridZ = static_cast<uint32_t>(I);
const uint32_t tgW = std::min<uint32_t>(gridX, tew);
MTLSize grid = MTLSizeMake(gridX, 1, gridZ);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
indices2d,
values,
dense_c,
t_c,
out_contig,
std::array<uint32_t, 3>{static_cast<uint32_t>(I),
static_cast<uint32_t>(J),
static_cast<uint32_t>(K)},
std::array<float, 2>{alpha_f, beta_f},
static_cast<uint32_t>(nnz));
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
if (out_needs_cast) {
r.copy_(out_contig.to(r.scalar_type()));
}
return r;
}
static void build_batch_ptr_mps(
const Tensor& indices_dim0,
int64_t B,
Tensor& batch_ptr
) {
// Builds an array of pointers which point to each batches elements. Example:
// idx_b = [0, 0, 0, 1, 1, 2, 2, 2, 2] // 9 non-zero elements
// └─────┘ └──┘ └─────────┘
// batch 0 batch 1 batch 2
// batch_ptr = [0, 3, 5, 9]
// │ │ │ └─ end of batch 2 (total nnz)
// │ │ └──── batch 2 starts at index 5
// │ └─────── batch 1 starts at index 3
// └────────── batch 0 starts at index 0
TORCH_CHECK(indices_dim0.is_mps() && batch_ptr.is_mps(), "MPS device expected");
auto device = indices_dim0.device();
auto stream = getCurrentMPSStream();
const int64_t nnz = indices_dim0.numel();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("build_batch_ptr_from_sorted_batches");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t Q = static_cast<uint32_t>(B + 1);
const uint32_t tgW = std::min<uint32_t>(Q, tew);
MTLSize grid = MTLSizeMake(Q, 1, 1);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
indices_dim0,
batch_ptr,
std::array<uint32_t, 2>{static_cast<uint32_t>(nnz),
static_cast<uint32_t>(B)});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
}
static void build_row_ptr_per_batch_mps(
const Tensor& rows,
const Tensor& batch_ptr,
int64_t B,
int64_t I,
Tensor& row_ptr
) {
// Build per-batch CSR-style row pointer arrays from row indices sorted by batch
// Given:
// rows: 1-D array of length nnz with row ids in [0, I), sorted within each batch
// batch_ptr: length B+1, where [batch_ptr[b], batch_ptr[b+1]) is the subrange for batch b
// Produces:
// - row_ptr: shape [B, I+1]
//
// Example (B = 2, I = 4):
// rows = [0, 0, 1, 3, 0, 2, 2] // 7 non-zero elements
// └─── batch 0 ──┘ └─ batch 1 ─┘
// batch_ptr = [0, 4, 7]
// │ │ └─ end of batch 1 (total nnz)
// │ └──── end of batch 0/start of batch 1
// └─────── start of batch 0
//
// per-batch row pointers (I+1 entries each):
// row_ptr[0] = [0, 2, 3, 3, 4]
// row_ptr[1] = [0, 1, 1, 3, 3]
// laid out in memory: [0, 2, 3, 3, 4, 0, 1, 1, 3, 3]
TORCH_CHECK(rows.is_mps() && batch_ptr.is_mps() && row_ptr.is_mps(), "MPS device expected");
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("build_row_ptr_from_sorted_rows_by_batch");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t Qx = static_cast<uint32_t>(I + 1);
const uint32_t Qy = static_cast<uint32_t>(B);
const uint32_t tgW = std::min<uint32_t>(Qx, tew);
MTLSize grid = MTLSizeMake(Qx, Qy, 1);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
rows,
batch_ptr,
row_ptr,
std::array<uint32_t, 2>{static_cast<uint32_t>(I),
static_cast<uint32_t>(B)});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
}
Tensor& bmm_out_sparse_mps(const SparseTensor& self_, const Tensor& mat2_, Tensor& result_) {
TORCH_CHECK(result_.is_mps(), "bmm_sparse: expected 'out' to be MPS, got ", result_.device());
TORCH_CHECK(self_.is_mps(), "bmm_sparse: expected 'self' to be MPS, got ", self_.device());
TORCH_CHECK(mat2_.is_mps(), "bmm_sparse: expected 'mat2' to be MPS, got ", mat2_.device());
TORCH_CHECK(self_.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self_.dense_dim());
TORCH_CHECK(self_.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self_.sparse_dim());
TORCH_CHECK(mat2_.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2_.dim());
TORCH_CHECK(self_.size(0) == mat2_.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match");
TORCH_CHECK(self_.size(2) == mat2_.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match");
const int64_t B = self_.size(0);
const int64_t I = self_.size(1);
const int64_t J = self_.size(2);
const int64_t K = mat2_.size(2);
auto self = self_.coalesce();
const int64_t nnz = self._nnz();
if (nnz == 0) {
return result_.zero_();
}
const auto computeDtype = at::kFloat;
auto indices = self._indices();
auto values = self._values();
auto values_c = values.scalar_type() == computeDtype ? values : values.to(computeDtype);
auto mat2_c = mat2_.scalar_type() == computeDtype ? mat2_ : mat2_.to(computeDtype);
auto mat2_contig = mat2_c.contiguous();
auto idx_b = indices.select(0, 0).contiguous();
auto idx_i = indices.select(0, 1).contiguous();
auto idx_j = indices.select(0, 2).contiguous();
// builds an array of pointers of where the batch_idx's pointer starts and ends
// look in function for better explanation
auto batch_ptr = at::empty({B + 1}, at::device(result_.device()).dtype(kLong));
build_batch_ptr_mps(idx_b, B, batch_ptr);
// build row_ptr per batch: for each (b, i) get [start, end) into rows/cols/vals
auto row_ptr = at::empty({B * (I + 1)}, at::device(result_.device()).dtype(kLong));
build_row_ptr_per_batch_mps(idx_i, batch_ptr, B, I, row_ptr);
const bool out_needs_cast = (result_.scalar_type() != computeDtype) || !result_.is_contiguous();
Tensor out_buf = out_needs_cast
? at::empty({B, I, K}, result_.options().dtype(computeDtype))
: result_;
auto out_contig = out_buf.contiguous();
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("spmm_bmm_coo_rows_grouped_" + mps::scalarToMetalTypeString(values));
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t tgW = std::min<uint32_t>((uint32_t)K, tew);
// One threadgroup per (row i, batch b), lanes cover K
MTLSize grid = MTLSizeMake(tgW, (uint32_t)I, (uint32_t)B);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
idx_i,
idx_j,
values_c,
mat2_contig,
out_contig,
row_ptr,
std::array<uint32_t, 4>{(uint32_t)B, (uint32_t)I, (uint32_t)J, (uint32_t)K});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
if (out_needs_cast) {
result_.copy_(out_contig.to(result_.scalar_type()));
}
return result_;
}
Tensor bmm_sparse_mps(const Tensor& self, const Tensor& mat2) {
Tensor result = at::zeros({self.size(0), self.size(1), mat2.size(2)}, mat2.options());
return bmm_out_sparse_mps(self, mat2, result);
}
Tensor& addmm_out_sparse_dense_mps(
const Tensor& self,
const SparseTensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha);
}
Tensor addmm_sparse_dense_mps(
const Tensor& self,
const SparseTensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha
) {
c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
Tensor result = at::empty({0}, self.options());
return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha);
}
static SparseTensor& mul_out_dense_sparse_mps(
const Tensor& dense,
const Tensor& sparse,

View File

@ -1,10 +1,105 @@
#include <metal_stdlib>
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
using namespace c10::metal;
using namespace metal;
inline uint lower_bound_i64(device const long* arr, uint lo, uint hi, long key) {
uint l = lo, r = hi;
while (l < r) {
uint m = (l + r) >> 1;
long v = arr[m];
if (v < key) {
l = m + 1;
} else {
r = m;
}
}
return l;
}
template <typename T> struct MulAccum { using type = float; };
template <> struct MulAccum<float2> { using type = float2; };
inline uint upper_bound_i64(device const long* arr, uint lo, uint hi, long key) {
uint l = lo, r = hi;
while (l < r) {
uint m = (l + r) >> 1;
long v = arr[m];
if (v <= key) {
l = m + 1;
} else {
r = m;
}
}
return l;
}
kernel void build_row_ptr_from_sorted_rows_by_batch(
device const long* rows [[buffer(0)]],
device const long* batch_ptr [[buffer(1)]],
device long* row_ptr [[buffer(2)]],
constant uint2& dims [[buffer(3)]],
uint3 tid [[thread_position_in_grid]])
{
const uint I = dims.x;
const uint B = dims.y;
const uint i = tid.x;
const uint b = tid.y;
if (b >= B || i > I) return;
const uint base = (uint)batch_ptr[b];
const uint lim = (uint)batch_ptr[b + 1];
const ulong out_base = (ulong)b * (ulong)(I + 1);
if (i == I) {
row_ptr[out_base + (ulong)I] = (long)lim;
} else {
const long key = (long)i;
const uint pos = lower_bound_i64(rows, base, lim, key);
row_ptr[out_base + (ulong)i] = (long)pos;
}
}
template <typename T>
kernel void spmm_bmm_coo_rows_grouped(
device const long* rows [[buffer(0)]],
device const long* cols [[buffer(1)]],
device const T* vals [[buffer(2)]],
device const T* dense [[buffer(3)]],
device T* out [[buffer(4)]],
device const long* row_ptr [[buffer(5)]],
constant uint4& dims [[buffer(6)]],
uint3 tid [[thread_position_in_grid]],
uint3 ltid [[thread_position_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]])
{
const uint B = dims.x;
const uint I = dims.y;
const uint J = dims.z;
const uint K = dims.w;
const uint b = tid.z;
const uint i = tid.y;
const uint lane = ltid.x;
const uint tgW = tptg.x;
const ulong rp_base = (ulong)b * (ulong)(I + 1);
const uint start = (uint)row_ptr[rp_base + (ulong)i];
const uint end = (uint)row_ptr[rp_base + (ulong)i + 1];
for (uint k = lane; k < K; k += tgW) {
auto acc = static_cast<accum_t<T>>(T(0));
for (uint p = start; p < end; ++p) {
const uint c = (uint)cols[p];
const auto v = static_cast<accum_t<T>>(vals[p]);
const uint d_off = ((b * J) + c) * K + k;
const auto d = static_cast<accum_t<T>>(dense[d_off]);
acc += mul(v, d);
}
const uint y_off = ((b * I) + i) * K + k;
out[y_off] = static_cast<T>(acc);
}
}
template <typename T>
kernel void dense_sparse_mul_kernel(
@ -32,10 +127,9 @@ kernel void dense_sparse_mul_kernel(
ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
using accum_t = typename MulAccum<T>::type;
const accum_t a = static_cast<accum_t>(values[val_idx]);
const accum_t b = static_cast<accum_t>(dense[dense_idx]);
out_values[val_idx] = static_cast<T>(a * b);
const auto a = static_cast<accum_t<T>>(values[val_idx]);
const auto b = static_cast<accum_t<T>>(dense[dense_idx]);
out_values[val_idx] = static_cast<T>(mul(a, b));
}
kernel void intersect_binary_search(
@ -120,6 +214,76 @@ kernel void fused_gather_mul_kernel(
}
}
kernel void build_batch_ptr_from_sorted_batches(
device const long* batches [[buffer(0)]],
device long* batch_ptr [[buffer(1)]],
constant uint2& nnz_B [[buffer(2)]],
uint3 tid [[thread_position_in_grid]])
{
uint b = tid.x;
uint nnz = nnz_B.x;
uint batch = nnz_B.y;
if (b == batch) {
batch_ptr[b] = (long)nnz;
return;
}
uint lo = 0;
uint hi = nnz;
long key = (long)b;
while (lo < hi) {
uint mid = (lo + hi) >> 1;
long v = batches[mid];
if (v < key) lo = mid + 1;
else hi = mid;
}
batch_ptr[b] = (long)lo;
}
template <typename T>
kernel void spmm_addmm_coo(
device const long* indices2d [[buffer(0)]],
device const T* vals [[buffer(1)]],
device const T* dense [[buffer(2)]],
device const T* t_in [[buffer(3)]],
device T* out [[buffer(4)]],
constant uint3& dims [[buffer(5)]],
constant float2& alpha_beta [[buffer(6)]],
constant uint& nnz [[buffer(7)]],
uint3 tid [[thread_position_in_grid]])
{
const uint K = dims.z;
const uint k = tid.x;
const uint i = tid.z;
const float alpha = alpha_beta.x;
const float beta = alpha_beta.y;
device const long* rows = indices2d;
device const long* cols = indices2d + nnz;
const uint start = lower_bound_i64(rows, 0u, nnz, (long)i);
const uint end = upper_bound_i64(rows, 0u, nnz, (long)i);
// accumulator is float for scalar/half/bfloat and float2 for float2
auto acc = static_cast<accum_t<T>>(T(0));
for (uint p = start; p < end; ++p) {
const uint c = (uint)cols[p];
const auto v = static_cast<accum_t<T>>(vals[p]);
const uint dense_off = c * K + k;
const auto d = static_cast<accum_t<T>>(dense[dense_off]);
acc += mul(v, d);
}
const uint off = i * K + k;
const auto base = (beta != 0.0f) ? (static_cast<accum_t<T>>(t_in[off]) * beta) : static_cast<accum_t<T>>(T(0));
const auto y = base + alpha * acc;
out[off] = static_cast<T>(y);
}
#define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE) \
template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void \
dense_sparse_mul_kernel<DTYPE>( \
@ -151,6 +315,36 @@ INSTANTIATE_DENSE_SPARSE_MUL(float2);
constant uint2& dims_output [[buffer(8)]], \
uint3 gid [[thread_position_in_grid]]);
INSTANTIATE_FUSED_GATHER_MUL(float);
INSTANTIATE_FUSED_GATHER_MUL(half);
INSTANTIATE_FUSED_GATHER_MUL(bfloat);
INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \
template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void \
spmm_bmm_coo_rows_grouped<DTYPE>( \
device const long* rows [[buffer(0)]], \
device const long* cols [[buffer(1)]], \
device const DTYPE* vals [[buffer(2)]], \
device const DTYPE* dense [[buffer(3)]], \
device DTYPE* out [[buffer(4)]], \
device const long* row_ptr [[buffer(5)]], \
constant uint4& dims [[buffer(6)]], \
uint3 tid [[thread_position_in_grid]], \
uint3 ltid [[thread_position_in_threadgroup]], \
uint3 tptg [[threads_per_threadgroup]]);
INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED);
#define INSTANTIATE_SPMM_ADDMM_COO(DTYPE) \
template [[host_name("spmm_addmm_coo_" #DTYPE)]] kernel void \
spmm_addmm_coo<DTYPE>( \
device const long* indices2d [[buffer(0)]], \
device const DTYPE* vals [[buffer(1)]], \
device const DTYPE* dense [[buffer(2)]], \
device const DTYPE* t_in [[buffer(3)]], \
device DTYPE* out [[buffer(4)]], \
constant uint3& dims [[buffer(5)]], \
constant float2& alpha_beta [[buffer(6)]], \
constant uint& nnz [[buffer(7)]], \
uint3 tid [[thread_position_in_grid]]);
INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_ADDMM_COO);

View File

@ -93,3 +93,7 @@
This operator does not support cudagraphs. The presence of this tag on an operator will cause
Inductor to split the graph around this operator. Note that operators without this tag may still
not support CUDAGraphs. Inductor may have other hardcoded lists around that.
- tag: reduction
desc: |
This tag indicates that an operator performs a reduction operation, computing aggregate values
(sum, mean, max, min, etc.) across one or more dimensions of the input tensor(s).

View File

@ -110,9 +110,9 @@ class ApplyLogSumExp {
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale =
static int constexpr kElementsPerAccess = ElementsPerAccess;
static int constexpr kCount = kElementsPerAccess;
static constexpr ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;

View File

@ -202,7 +202,6 @@ supported:
- select_backward
- _trilinear
- linalg_pinv.atol_rtol_tensor
- svd
- logsumexp.out
symint:
- empty.memory_format

View File

@ -14,16 +14,16 @@ using namespace at;
namespace {
const auto int_min = std::numeric_limits<int>::min();
const auto int_max = std::numeric_limits<int>::max();
const auto long_min = std::numeric_limits<int64_t>::min();
const auto long_max = std::numeric_limits<int64_t>::max();
const auto float_lowest = std::numeric_limits<float>::lowest();
const auto float_min = std::numeric_limits<float>::min();
const auto float_max = std::numeric_limits<float>::max();
const auto double_lowest = std::numeric_limits<double>::lowest();
const auto double_min = std::numeric_limits<double>::min();
const auto double_max = std::numeric_limits<double>::max();
constexpr auto int_min = std::numeric_limits<int>::min();
constexpr auto int_max = std::numeric_limits<int>::max();
constexpr auto long_min = std::numeric_limits<int64_t>::min();
constexpr auto long_max = std::numeric_limits<int64_t>::max();
constexpr auto float_lowest = std::numeric_limits<float>::lowest();
constexpr auto float_min = std::numeric_limits<float>::min();
constexpr auto float_max = std::numeric_limits<float>::max();
constexpr auto double_lowest = std::numeric_limits<double>::lowest();
constexpr auto double_min = std::numeric_limits<double>::min();
constexpr auto double_max = std::numeric_limits<double>::max();
const std::vector<int> ints {
int_min,

View File

@ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() {
c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(uint64_t);
constexpr size_t total_size = seed_size + offset_size;
// The internal state is returned as a CPU byte tensor.
auto state_tensor = at::detail::empty_cpu(
@ -170,9 +170,9 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::xpu::assertNotCapturing(
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(uint64_t);
constexpr size_t total_size = seed_size + offset_size;
at::detail::check_rng_state(new_state);

View File

@ -1751,8 +1751,8 @@ def maybe_snapshot_memory(should_snapshot_memory, suffix):
f"{output_filename.rstrip('.csv')}_{suffix}.pickle",
)
)
except Exception as e:
log.error("Failed to save memory snapshot, %s", e)
except Exception:
log.exception("Failed to save memory snapshot")
torch.cuda.memory._record_memory_history(enabled=None)
@ -2284,9 +2284,11 @@ class BenchmarkRunner:
)
):
is_same = False
except Exception:
except Exception as e:
# Sometimes torch.allclose may throw RuntimeError
is_same = False
exception_string = str(e)
accuracy_status = f"fail_exception: {exception_string}"
return record_status(accuracy_status, dynamo_start_stats=start_stats)
if not is_same:
accuracy_status = "eager_two_runs_differ"
@ -2403,9 +2405,11 @@ class BenchmarkRunner:
force_max_multiplier=force_max_multiplier,
):
is_same = False
except Exception:
except Exception as e:
# Sometimes torch.allclose may throw RuntimeError
is_same = False
exception_string = str(e)
accuracy_status = f"fail_exception: {exception_string}"
return record_status(accuracy_status, dynamo_start_stats=start_stats)
if not is_same:
if self.args.skip_accuracy_check:
@ -4060,7 +4064,7 @@ def run(runner, args, original_dir=None):
else:
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
experiment = (
speedup_experiment if not args.backend == "torchao" else latency_experiment
speedup_experiment if args.backend != "torchao" else latency_experiment
)
if args.accuracy:
output_filename = f"accuracy_{args.backend}.csv"

View File

@ -124,7 +124,7 @@ with open(MODELS_FILENAME) as fh:
continue
batch_size = int(batch_size)
BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
assert len(BATCH_SIZE_KNOWN_MODELS)
assert BATCH_SIZE_KNOWN_MODELS
try:

View File

@ -296,8 +296,8 @@ class OperatorInputsLoader:
for key in self.operator_db.keys():
try:
op = eval(key)
except AttributeError as ae:
log.warning("Evaluating an op name into an OpOverload: %s", ae)
except AttributeError:
log.warning("Evaluating an op name into an OpOverload", exc_info=True)
continue
yield op

View File

@ -3,6 +3,7 @@ import sys
from benchmark_base import BenchmarkBase
import torch
from torch._dynamo.utils import CompileTimeInstructionCounter
class Benchmark(BenchmarkBase):
@ -32,7 +33,11 @@ class Benchmark(BenchmarkBase):
def _work(self):
# enable_cpp_symbolic_shape_guards has impact on this benchmark
# Keep using False value for consistency.
with torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False):
with (
torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False),
torch._export.config.patch(use_new_tracer_experimental=True),
CompileTimeInstructionCounter.record(),
):
torch.export.export(self.m, (self.input,), strict=True)

View File

@ -38,7 +38,7 @@ update_hint_regression,compile_time_instruction_count,1719000000,0.1
sum_floordiv_regression,compile_time_instruction_count,966100000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1

1 add_loop_eager compile_time_instruction_count 3070000000 0.1
38
39
40
41
42
43
44

View File

@ -127,7 +127,7 @@ def trainbench(
bwd_time = bwd_start_event.elapsed_time(bwd_end_event)
return fwd_time, bwd_time
creator_args = creator_args = {
creator_args = {
"seqLength": seqLength,
"numLayers": numLayers,
"inputSize": inputSize,

View File

@ -12,7 +12,7 @@ def modeldef(request, net_name, executor, fuser):
# Given a 'net_name' provided by generate_tests, build the thing
name, rnn_creator, context = get_nn_runners(net_name)[0]
creator_args = creator_args = {
creator_args = {
"seqLength": 100,
"numLayers": 1,
"inputSize": 512,

View File

@ -85,7 +85,7 @@ class WeightOnlyInt8QuantHandler:
cur_state_dict[f"{fqn}.weight"] = int8_weight
cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
elif isinstance(mod, ConditionalFeedForward):
for weight_idx in range(0, 3):
for weight_idx in range(3):
weight_name = f"w{weight_idx + 1}"
scales_name = f"scales{weight_idx + 1}"
weight = getattr(mod, weight_name)

View File

@ -38,12 +38,16 @@ class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(
configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
)
op_bench.generate_pt_test(
configs.convtranspose_1d_configs_short
+ configs.conv_1d_configs_short
+ configs.conv_1d_configs_long,
ConvTranspose1dBenchmark,
)
if not torch.backends.mkldnn.is_acl_available():
# convtranpose1d crashes with ACL, see https://github.com/pytorch/pytorch/issues/165654
op_bench.generate_pt_test(
configs.convtranspose_1d_configs_short
+ configs.conv_1d_configs_short
+ configs.conv_1d_configs_long,
ConvTranspose1dBenchmark,
)
"""

View File

@ -271,7 +271,7 @@ def run_single_backend_sdpa(
if config.calculate_bwd_time:
# TODO: debug backward pass for njt
if eager_sdpa and not config.attn_type == "document_mask":
if eager_sdpa and config.attn_type != "document_mask":
d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2)
backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, d_out, retain_graph=True

View File

@ -1729,10 +1729,8 @@ def define_buck_targets(
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/backends/backend_interface.cpp",
],
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
}),
compiler_flags = get_pt_compiler_flags(),
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = get_no_as_needed_linker_flag(),
@ -2025,9 +2023,6 @@ def define_buck_targets(
"ovr_config//os:android-x86_64": [
"-mssse3",
],
}) + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
}),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [

View File

@ -9,6 +9,7 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

View File

@ -1,5 +1,4 @@
#include <c10/core/AllocatorConfig.h>
#include <c10/core/DeviceType.h>
#include <c10/util/env.h>
namespace c10::CachingAllocator {
@ -47,7 +46,7 @@ size_t AcceleratorAllocatorConfig::roundup_power2_divisions(size_t size) {
63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoStart);
const size_t interval_end =
63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoEnd);
TORCH_CHECK(
TORCH_CHECK_VALUE(
interval_end - interval_start == kRoundUpPowerOfTwoIntervals,
"kRoundUpPowerOfTwoIntervals mismatch");
@ -66,7 +65,7 @@ size_t AcceleratorAllocatorConfig::parseMaxSplitSize(
std::numeric_limits<size_t>::max() / kMB;
size_t val_env = tokenizer.toSizeT(++i);
TORCH_CHECK(
TORCH_CHECK_VALUE(
val_env >= min_allowed_split_size_mb,
"CachingAllocator option max_split_size_mb too small, must be >= ",
min_allowed_split_size_mb);
@ -85,7 +84,7 @@ size_t AcceleratorAllocatorConfig::parseMaxNonSplitRoundingSize(
std::numeric_limits<size_t>::max() / kMB;
size_t val_env = tokenizer.toSizeT(++i);
TORCH_CHECK(
TORCH_CHECK_VALUE(
val_env >= min_allowed_split_size_mb,
"CachingAllocator option max_non_split_rounding_mb too small, must be >= ",
min_allowed_split_size_mb);
@ -100,7 +99,7 @@ size_t AcceleratorAllocatorConfig::parseGarbageCollectionThreshold(
size_t i) {
tokenizer.checkToken(++i, ":");
double val_env = tokenizer.toDouble(++i);
TORCH_CHECK(
TORCH_CHECK_VALUE(
val_env > 0 && val_env < 1.0,
"garbage_collect_threshold is invalid, set it in (0.0, 1.0)");
garbage_collection_threshold_ = val_env;
@ -121,7 +120,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
size_t value_index = i;
tokenizer.checkToken(++i, ":");
size_t value = tokenizer.toSizeT(++i);
TORCH_CHECK(
TORCH_CHECK_VALUE(
value == 0 || llvm::isPowerOf2_64(value),
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ");
@ -129,12 +128,13 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
std::fill(
std::next(
roundup_power2_divisions_.begin(),
static_cast<std::vector<size_t>::difference_type>(last_index)),
static_cast<std::vector<size_t>::difference_type>(
last_index + 1)),
roundup_power2_divisions_.end(),
value);
} else {
size_t boundary = tokenizer.toSizeT(value_index);
TORCH_CHECK(
TORCH_CHECK_VALUE(
llvm::isPowerOf2_64(boundary),
"For roundups, the intervals have to be power of 2 ");
@ -164,7 +164,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
"Expected closing bracket ']' in ConfigTokenizer but reached end of config");
} else { // Keep this for backwards compatibility
size_t value = tokenizer.toSizeT(i);
TORCH_CHECK(
TORCH_CHECK_VALUE(
llvm::isPowerOf2_64(value),
"For roundups, the divisions has to be power of 2 ");
std::fill(
@ -224,7 +224,7 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
// If a device-specific configuration parser hook is registered, it will
// check if the key is unrecognized.
if (device_config_parser_hook_) {
TORCH_CHECK(
TORCH_CHECK_VALUE(
getKeys().find(key) != getKeys().end(),
"Unrecognized key '",
key,

View File

@ -76,7 +76,7 @@ class ConfigTokenizer {
} else if (token == "False") {
return false;
} else {
TORCH_CHECK(
TORCH_CHECK_VALUE(
false,
"Expected 'True' or 'False' at index ",
i,
@ -253,7 +253,7 @@ class C10_API AcceleratorAllocatorConfig {
device_config_parser_hook_ = std::move(hook);
auto& mutable_keys = getMutableKeys();
for (auto& key : keys) {
TORCH_CHECK(
TORCH_CHECK_VALUE(
mutable_keys.insert(key).second,
"Duplicated key '",
key,

View File

@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(

Some files were not shown because too many files have changed in this diff Show More