Compare commits

...

22 Commits

Author SHA1 Message Date
5b6cc8215f Change python doc push script to print the undocumented modules 2025-10-21 12:30:49 -07:00
1c43c9cfd0 Update 2025-10-21 12:30:49 -07:00
102e0d5437 Test 2025-10-21 12:30:49 -07:00
0bd12c1168 [CI] Extend test_transformers to MPS (#165960)
Just skip grad_checks as they need float64
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165960
Approved by: https://github.com/Skylion007
2025-10-21 19:27:44 +00:00
ce8a7764e2 Revert "[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)"
This reverts commit 1290b077f26543a34262587137ef64ca9ca5e17d.

Reverted https://github.com/pytorch/pytorch/pull/165707 on behalf of https://github.com/clee2000 due to failing internal tests D85160820 ([comment](https://github.com/pytorch/pytorch/pull/165707#issuecomment-3429084393))
2025-10-21 19:25:03 +00:00
d1269a0434 update fr trace analysis (#165994)
Summary:
- allow empty entries from ranks
- do not require all ranks to provide a dump

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/165994).
* #165638
* #165640
* #165642
* __->__ #165994
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165994
Approved by: https://github.com/fduwjj
2025-10-21 19:14:33 +00:00
c87cf1be32 Update workaround to old CUDA bug (#164354) (#165984)
The workaround cannot be removed because of BC. Here we update the
PyTorch code base to stop using the workaround.

See https://github.com/pytorch/pytorch/pull/164354 for the BC breakage issue.

Resolves https://github.com/pytorch/pytorch/issues/164348.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165984
Approved by: https://github.com/janeyx99
2025-10-21 19:09:43 +00:00
2fc5e45a41 better error message when there is no pytree impl (#165955)
Differential Revision: [D85117597](https://our.internmc.facebook.com/intern/diff/D85117597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165955
Approved by: https://github.com/avikchaudhuri
2025-10-21 18:49:22 +00:00
f9022ba93b [PyTorch] Add user_metadata display to memory visualizer (#165939)
Summary: Enhanced the PyTorch CUDA memory visualizer to display user_metadata alongside stack frames when inspecting allocations. The user_metadata field is now shown in all views (Allocator State History, Active Memory Timeline, etc.) with consistent formatting. The implementation handles both string and object metadata types, displaying strings directly and objects as key-value pairs.

Test Plan:
1. Generate a memory snapshot with user_metadata
2. Open the memory visualizer in a browser
3. Load the snapshot file
4. Verify user_metadata appears
5. Test with both string metadata ("testing") and object metadata ({"key": "value"})
6. Verify formatting shows "User Metadata:\n  <value>" for strings

 {F1982860439}

Differential Revision: D85095152

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165939
Approved by: https://github.com/yushangdi
2025-10-21 18:48:33 +00:00
ff8be889ad Remove unused exception parameter from some files, to work with -Wunused-exception-parameter (#165770)
Summary: address compiler complaints that were coming up, to unblock the build

Test Plan:
before the change
```
aten/src/ATen/native/LinearAlgebra.cpp:3623:36: error: unused exception parameter 'e' [-Werror,-Wunused-exception-parameter]
 3623 |     } catch (const std::exception& e) {
      |
```

after: targets build with `-Wunused-exception-parameter`

Differential Revision: D84876246

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165770
Approved by: https://github.com/Skylion007, https://github.com/cyyever

Co-authored-by: Tony Targonski <tony.targonski@meta.com>
2025-10-21 18:30:29 +00:00
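A minimal, self-contained sketch (not code from this PR; `parse` is a made-up helper) of the two common ways to satisfy `-Wunused-exception-parameter`: drop the parameter name when the exception object is never used, or mark it `[[maybe_unused]]` when it is only used in code that may be compiled out, as in the `TORCH_WARN` case in the hunk below.

```
#include <exception>
#include <iostream>
#include <stdexcept>

// Hypothetical parser used only to demonstrate the warning; not from the PR.
void parse(const char* s) {
  if (s == nullptr) throw std::invalid_argument("null input");
}

int main() {
  // Option 1: the exception object is never used, so omit the parameter name.
  try {
    parse(nullptr);
  } catch (const std::exception&) {
    std::cerr << "parse failed\n";
  }

  // Option 2: the object is only used in code that may be compiled out
  // (e.g. a warning macro), so keep the name but mark it [[maybe_unused]].
  try {
    parse(nullptr);
  } catch ([[maybe_unused]] const std::exception& e) {
#ifndef NDEBUG
    std::cerr << "parse failed: " << e.what() << '\n';
#endif
  }
  return 0;
}
```

With Clang, `-Wunused-exception-parameter` only fires on a named catch parameter that is never referenced, so both forms above compile cleanly even under `-Werror`.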
292454942e [CD] Introduce windows.12xlarge runners for CD Windows build (#165287)
Follows https://github.com/pytorch/test-infra/pull/7174. Windows CD build time comparison is shown below:

|Runner|cpu|cuda|xpu|
|-|-|-|-|
|windows.4xlarge|1.5h| 4.0h| 5.5h|
|windows.12xlarge|0.5h|1.5h|2.5h|

Fixes #162962
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165287
Approved by: https://github.com/zxiiro, https://github.com/malfet, https://github.com/seemethere
2025-10-21 18:28:23 +00:00
6c4412f72b Revert "[Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)"
This reverts commit e9d89734274a4a2640fa77b898c800a87d1d874e.

Reverted https://github.com/pytorch/pytorch/pull/163316 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 17:44:42 +00:00
78bf6186f2 Revert "[Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)"
This reverts commit e8cb34dd52c063a130f3e659576c313bbe4b4981.

Reverted https://github.com/pytorch/pytorch/pull/163324 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 17:44:42 +00:00
c40048472c Remove AOTI cross compilation time from internal CI (#165935)
Summary: as title

Test Plan: CI

Differential Revision: D85088451

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165935
Approved by: https://github.com/desertfire
2025-10-21 16:58:28 +00:00
3dfd0c7584 Improve PATH hints in FindvecLib.cmake (#165881)
Change /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk to /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk in `cmake/Modules/FindvecLib.cmake`, which is more general (MacOSX10.9 is no longer supported). Without this, vecLib can't be found on macOS 26.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165881
Approved by: https://github.com/ezyang
2025-10-21 16:44:12 +00:00
e6ba4d0725 Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)" (#165910)
Summary:
Original commit changeset: d6d62d0c96dd

Original Phabricator Diff: D84468451 and D84613184

D84468451 caused CUDA OutOfMemoryError in model.

Test Plan:
D84468451 was found through bisect.  Also double checked on recent trunk 9866939225248c2adc307be7a804b26db0b9b555: f815887517

With this diff that backs out D84468451 and D84613184 : f816114560

Differential Revision: D85025378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165910
Approved by: https://github.com/clee2000
2025-10-21 16:36:38 +00:00
bdf7cb9d9c Revert "[torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)"
This reverts commit e20c9bf2889b9252ac45ae6af35c93c795eab701.

Reverted https://github.com/pytorch/pytorch/pull/165410 on behalf of https://github.com/clee2000 due to sorry I'm going to revert this since I want to try to back out some other things that are conflicting with this, there is nothing wrong with this PR, rebasing and resolving the merge conflicts should be enough, sorry for the churn ([comment](https://github.com/pytorch/pytorch/pull/165410#issuecomment-3427532373))
2025-10-21 16:27:54 +00:00
6aed378958 [export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)
fx.Interpreter doesn't handle kwargs... not sure how this code worked previously

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165334
Approved by: https://github.com/tugsbayasgalan, https://github.com/ezyang
2025-10-21 15:53:05 +00:00
8b3dc0d1b0 Better error handling in torch/csrc/jit/runtime/* (#165118)
Refactor error handling in some files under torch/csrc/jit/runtime/* by using TORCH_CHECK, improving clarity in constants and scope management.

Fixes parts of issue https://github.com/pytorch/pytorch/issues/148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165118
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 15:22:49 +00:00
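As a rough illustration of the pattern being applied (a sketch, not code from the PR; `resolve_constant` is a made-up helper), `TORCH_CHECK` from `c10/util/Exception.h` replaces bare asserts or ad-hoc throws and raises a `c10::Error` with a formatted message:

```
#include <c10/util/Exception.h>

// Hypothetical lookup used only to illustrate the TORCH_CHECK pattern.
int resolve_constant(int index, int num_constants) {
  // Before: a bare assert or an unadorned throw with little context.
  // After: TORCH_CHECK produces a c10::Error with a readable, formatted message.
  TORCH_CHECK(
      index >= 0 && index < num_constants,
      "constant index ", index,
      " is out of range (have ", num_constants, " constants)");
  return index;
}
```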
06773663b5 Implement an AOT precompile mode for standalone_compile (#165843)
This PR introduces an `aot` flag to standalone_compile that uses BundledAOTAutogradCacheEntry, and then lets regional_inductor use it so that we can start AOT-compiling regionally compiled graphs. The diff above this one will attempt to let GraphPickler fully serialize graphs that contain regionally compiled subgraphs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843
Approved by: https://github.com/oulgen
2025-10-21 15:02:45 +00:00
0bff65503c Move hardware_destructive_interference_size to c10/core/alignment.h (#160067)
# Motivation
Move `hardware_destructive_interference_size` to `c10/core/alignment.h`, which allows it to be reused across different accelerators.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160067
Approved by: https://github.com/Skylion007, https://github.com/EikanWang
2025-10-21 14:39:46 +00:00
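A minimal sketch of what this enables (illustrative names, not code from the PR): accelerator-side structures can pad with the shared `c10::hardware_destructive_interference_size` constant instead of a hard-coded `alignas(64)`.

```
#include <c10/core/alignment.h>

#include <deque>
#include <mutex>

// Hypothetical per-device event queue. Padding with the shared constant keeps
// adjacent queues on separate cache lines, so threads serving different
// devices do not false-share the mutexes.
struct alignas(c10::hardware_destructive_interference_size) PerDeviceEvents {
  std::mutex mutex_;
  std::deque<int> events_;
};

PerDeviceEvents per_device_events[8];
```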
21131a2444 Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)"
This reverts commit ffa90d46e61650834d5f926008f48f50c6a7e87a.

Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/jeffdaily due to timeouts after merge ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3426898171))
2025-10-21 14:15:55 +00:00
76 changed files with 980 additions and 1094 deletions

View File

@ -1,15 +1,11 @@
sphinx==5.3.0
sphinx==7.2.6
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
#Pinned versions: 7.2.6
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
pytorch_sphinx_theme2==0.1.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.1.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.34.0
breathe==4.36.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.34.0
#Pinned versions: 4.36.0
exhale==0.2.3
exhale==0.3.7
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.2.3
#Pinned versions: 0.3.7
docutils==0.16
docutils==0.20
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.16
#Pinned versions: 0.20
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@ -56,13 +52,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==0.17.2
myst-nb==1.3.0
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
#Pinned versions: 1.3.0
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinx-design==0.6.1
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-parser==4.0.1

View File

@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
echo coverage output not found
exit 1
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
echo "======================================"
echo "ERROR: $undocumented undocumented objects found!"
echo "======================================"
echo ""
echo "Full coverage report:"
cat build/coverage/python.txt
echo ""
echo "======================================"
echo "Undocumented modules/objects (lines after TOTAL):"
tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
echo "======================================"
echo ""
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1

View File

@ -163,8 +163,13 @@ if [[ "$(uname)" != Darwin ]]; then
MEMORY_LIMIT_MAX_JOBS=12
NUM_CPUS=$(( $(nproc) - 2 ))
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
if [[ "$(uname)" == Linux ]]; then
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
else
# For other builds
export MAX_JOBS=${NUM_CPUS}
fi
cat >>"$envfile" <<EOL
export MAX_JOBS="${MAX_JOBS}"

View File

@ -54,17 +54,12 @@ self-hosted-runner:
- windows-11-arm64
- windows-11-arm64-preview
# Organization-wide AMD-hosted runners
# MI2xx non-ARC runners
# MI2xx runners
- linux.rocm.gpu
- linux.rocm.gpu.mi250
- linux.rocm.gpu.2
- linux.rocm.gpu.4
- linux.rocm.gpu.mi250
- linux.rocm.gpu.gfx1100
# MI2xx ARC runners
- linux.rocm.gpu.mi250.1
- linux.rocm.gpu.mi250.2
- linux.rocm.gpu.mi250.4
# gfx942 ARC runners
# gfx942 runners
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4

View File

@ -79,9 +79,9 @@ jobs:
runs-on: "windows-11-arm64-preview"
{%- else %}
{%- if branches == "nightly" %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
{%- else %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
{%- endif %}
{%- endif %}
timeout-minutes: !{{ common.timeout_minutes_windows_binary }}

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
wheel-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -279,7 +279,7 @@ jobs:
wheel-py3_10-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -517,7 +517,7 @@ jobs:
wheel-py3_10-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -755,7 +755,7 @@ jobs:
wheel-py3_10-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -993,7 +993,7 @@ jobs:
wheel-py3_10-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1229,7 +1229,7 @@ jobs:
wheel-py3_11-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1464,7 +1464,7 @@ jobs:
wheel-py3_11-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1702,7 +1702,7 @@ jobs:
wheel-py3_11-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1940,7 +1940,7 @@ jobs:
wheel-py3_11-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2178,7 +2178,7 @@ jobs:
wheel-py3_11-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2414,7 +2414,7 @@ jobs:
wheel-py3_12-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2649,7 +2649,7 @@ jobs:
wheel-py3_12-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2887,7 +2887,7 @@ jobs:
wheel-py3_12-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3125,7 +3125,7 @@ jobs:
wheel-py3_12-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3363,7 +3363,7 @@ jobs:
wheel-py3_12-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3599,7 +3599,7 @@ jobs:
wheel-py3_13-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3834,7 +3834,7 @@ jobs:
wheel-py3_13-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4072,7 +4072,7 @@ jobs:
wheel-py3_13-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4310,7 +4310,7 @@ jobs:
wheel-py3_13-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4548,7 +4548,7 @@ jobs:
wheel-py3_13-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4784,7 +4784,7 @@ jobs:
wheel-py3_13t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5019,7 +5019,7 @@ jobs:
wheel-py3_13t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5257,7 +5257,7 @@ jobs:
wheel-py3_13t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5495,7 +5495,7 @@ jobs:
wheel-py3_13t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5733,7 +5733,7 @@ jobs:
wheel-py3_13t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5969,7 +5969,7 @@ jobs:
wheel-py3_14-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6204,7 +6204,7 @@ jobs:
wheel-py3_14-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6442,7 +6442,7 @@ jobs:
wheel-py3_14-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6680,7 +6680,7 @@ jobs:
wheel-py3_14-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6918,7 +6918,7 @@ jobs:
wheel-py3_14-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7154,7 +7154,7 @@ jobs:
wheel-py3_14t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7389,7 +7389,7 @@ jobs:
wheel-py3_14t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7627,7 +7627,7 @@ jobs:
wheel-py3_14t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7865,7 +7865,7 @@ jobs:
wheel-py3_14t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -8103,7 +8103,7 @@ jobs:
wheel-py3_14t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -36,12 +36,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
]}
secrets: inherit

View File

@ -39,7 +39,7 @@ struct HostBlock {
};
template <typename B>
struct alignas(64) FreeBlockList {
struct alignas(hardware_destructive_interference_size) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};
@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(64) HostStatsStaged {
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(64) std::mutex blocks_mutex_;
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(64) std::mutex events_mutex_;
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(64) HostStatsStaged stats_;
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
struct TORCH_API HostAllocator : public at::Allocator {
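The free-list comment above describes the layout this file uses: one bucket per size class, each guarded by its own mutex and padded to the destructive interference size so buckets do not share a cache line. A self-contained sketch of that layout in standard C++ (illustrative names, assuming a 64-byte fallback when the C++17 constant is unavailable):

```
#include <cstddef>
#include <deque>
#include <mutex>
#include <new>
#include <vector>

#ifdef __cpp_lib_hardware_interference_size
constexpr std::size_t kCacheLine = std::hardware_destructive_interference_size;
#else
constexpr std::size_t kCacheLine = 64;  // common fallback when the C++17 constant is unavailable
#endif

// One free list per size class; padding keeps each bucket's mutex on its own
// cache line so threads hitting different size classes do not false-share.
struct alignas(kCacheLine) SizeClassFreeList {
  std::mutex mutex;
  std::deque<void*> blocks;
};

struct FreeListTable {
  explicit FreeListTable(std::size_t num_size_classes)
      : lists_(num_size_classes) {}

  void push(std::size_t size_class, void* block) {
    auto& list = lists_[size_class];
    std::lock_guard<std::mutex> guard(list.mutex);
    list.blocks.push_back(block);
  }

  void* pop(std::size_t size_class) {
    auto& list = lists_[size_class];
    std::lock_guard<std::mutex> guard(list.mutex);
    if (list.blocks.empty()) return nullptr;
    void* block = list.blocks.back();
    list.blocks.pop_back();
    return block;
  }

 private:
  std::vector<SizeClassFreeList> lists_;
};
```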

View File

@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
try {
mkldnn_matmul_i8i8i32(self, mat2, result);
dispatched = true;
} catch (const std::exception& e) {
} catch ([[maybe_unused]] const std::exception& e) {
TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
}
}

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,
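The hunk above replaces a verbose `decltype(c10::impl::ScalarTypeToCPPType<...>::t)` expression with the `ScalarTypeToCPPTypeT` alias. A self-contained sketch of the general alias-template shorthand this relies on (illustrative names, not the c10 ones):

```
#include <type_traits>

enum class ScalarKind { Float, Half };

// Trait mapping an enum value to a C++ type via a member alias.
template <ScalarKind K>
struct KindToCPPType;

template <>
struct KindToCPPType<ScalarKind::Float> { using type = float; };

template <>
struct KindToCPPType<ScalarKind::Half> { using type = unsigned short; };  // stand-in for a half type

// Alias template: the shorthand that replaces spelling out typename ...::type
// (or a decltype of a member) at every use site.
template <ScalarKind K>
using KindToCPPTypeT = typename KindToCPPType<K>::type;

static_assert(std::is_same_v<KindToCPPTypeT<ScalarKind::Half>, unsigned short>);
```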

View File

@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
cret_t,
carg0_t,
carg1_t>(
numel,
f,
data,
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};

View File

@ -202,7 +202,6 @@ supported:
- select_backward
- _trilinear
- linalg_pinv.atol_rtol_tensor
- svd
- logsumexp.out
symint:
- empty.memory_format

View File

@ -9,6 +9,7 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

View File

@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(

View File

@ -1,6 +1,7 @@
#pragma once
#include <cstddef>
#include <new>
namespace c10 {
@ -18,4 +19,12 @@ constexpr size_t gPagesize = 4096;
// since the default thp pagesize is 2MB, enable thp only
// for buffers of size 2MB or larger to avoid memory bloating
constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;
// Cache line size used to avoid false sharing between threads. Falls back to 64
// bytes if C++17 feature is unavailable.
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
} // namespace c10

View File

@ -941,7 +941,7 @@ class EventPool {
private:
struct PerDevicePool {
alignas(64) std::mutex mutex_;
alignas(hardware_destructive_interference_size) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;
@ -3758,11 +3758,6 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
static constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
class NativeCachingAllocator : public CUDAAllocator {
private:

View File

@ -554,7 +554,7 @@ static void local_raw_delete(void* ptr);
class XPUAllocator : public DeviceAllocator {
private:
std::mutex mutex;
alignas(hardware_destructive_interference_size) std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
void add_allocated_block(Block* block) {

View File

@ -16,7 +16,7 @@ find_path(vecLib_INCLUDE_DIR vecLib.h
DOC "vecLib include directory"
PATHS /System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
/System/Library/${__veclib_include_suffix}
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
${CMAKE_OSX_SYSROOT}/System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
NO_DEFAULT_PATH)

View File

@ -207,6 +207,42 @@ templates_path = [
]
# TODO: document these and remove them from here.
# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}
coverage_ignore_functions = [
# torch
"typename",
@ -3193,6 +3229,11 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}
# -- katex javascript in header
#
# def setup(app):

View File

@ -253,7 +253,6 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
view
as_strided

View File

@ -15192,6 +15192,25 @@ graph():
filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
)
def test_invalid_pytree_dynamo_graph_capture(self):
class Block:
def __init__(self, a, b):
self.a = a
self.b = b
class Foo(torch.nn.Module):
def forward(self, block):
return block.a + block.b
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
with self.assertRaisesRegex(
torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
):
_dynamo_graph_capture_for_export(Foo())(
Block(torch.randn(4, 4), torch.randn(4, 4))
)
def test_enum_str(self):
class TensorDim(str, enum.Enum):
DDP = "ddp"

View File

@ -318,17 +318,19 @@ class inner_f(torch.nn.Module):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x, scale=1.0):
def forward(self, x, *, scale):
return self.linear(x) * scale
model = ModuleWithKwargs()
inputs = (torch.randn(4, 3),)
kwargs = {"scale": 2.0}
kwargs = {"scale": torch.tensor(2.0)}
gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
with ExitStack() as stack:
# Export joint with descriptors
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, kwargs, decompositions=decomposition_table
stack, gm, inputs, kwargs, decompositions=decomposition_table
)
# Test the exported graph structure
@ -336,9 +338,17 @@ class inner_f(torch.nn.Module):
print_output=False, expanded_def=True
)
# For some reason PYTORCH_TEST_WITH_CROSSREF will add extra spaces.
# I tried to fix this in normalize_gm but there are too many files
# depending on that behavior..
graph_code_str = normalize_gm(graph_code)
graph_code_str = "\n".join(
[line for line in graph_code_str.split("\n") if len(line.rstrip()) > 0]
)
# Expect test on the printed graph
self.assertExpectedInline(
normalize_gm(graph_code),
graph_code_str,
"""\
class inner_f(torch.nn.Module):
def forward(
@ -346,19 +356,20 @@ class inner_f(torch.nn.Module):
primals,
tangents,
):
primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight')
primals_2: "f32[2]" # ParamAOTInput(target='linear.bias')
primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear_weight')
primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear_bias')
primals_3: "f32[4, 3]" # PlainAOTInput(idx=0)
primals_4: "f32[]" # PlainAOTInput(idx=1)
tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0))
primals_1, primals_2, primals_3, primals_4 , tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
transpose: "f32[3, 2]" = torch.ops.prims.transpose.default(primals_1, [1, 0]); primals_1 = None
mm: "f32[4, 2]" = torch.ops.aten.mm.default(primals_3, transpose); transpose = None
mul: "f32[4, 2]" = torch.ops.prims.mul.default(mm, 1.0); mm = None
mul_1: "f32[2]" = torch.ops.prims.mul.default(primals_2, 1.0); primals_2 = None
broadcast_in_dim: "f32[4, 2]" = torch.ops.prims.broadcast_in_dim.default(mul_1, [4, 2], [1]); mul_1 = None
add: "f32[4, 2]" = torch.ops.prims.add.default(mul, broadcast_in_dim); mul = broadcast_in_dim = None
mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, 2.0); add = None
mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, 2.0); tangents_1 = None
mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, primals_4); add = None
mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, primals_4); tangents_1 = primals_4 = None
transpose_1: "f32[2, 4]" = torch.ops.prims.transpose.default(mul_3, [1, 0])
mm_1: "f32[2, 3]" = torch.ops.aten.mm.default(transpose_1, primals_3); transpose_1 = primals_3 = None
transpose_2: "f32[3, 2]" = torch.ops.prims.transpose.default(mm_1, [1, 0]); mm_1 = None
@ -368,12 +379,11 @@ class inner_f(torch.nn.Module):
transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None
return pytree.tree_unflatten([
mul_2, # PlainAOTOutput(idx=0)
transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight'))
as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias'))
transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_weight'))
as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_bias'))
None, # None
None, # None
], self._out_spec)
""",
], self._out_spec)""",
)
# Compile the result

View File

@ -7356,6 +7356,7 @@ metadata incorrectly.
aot_eager = torch.compile(backend="aot_eager")(fn)(x)
self.assertEqual(eager, aot_eager, atol=0, rtol=0)
@unittest.expectedFailure
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_rms_norm(self):
# Only CUDA rms norm fails to be decomposed

View File

@ -1,74 +0,0 @@
# Owner(s): ["module: inductor"]
import tempfile
import unittest
import zipfile
import torch
import torch._inductor.config
from torch._environment import is_fbcode
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import IS_CI
from torch.testing._internal.inductor_utils import HAS_GPU, requires_gpu
class Simple(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(16, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
class TestAOTInductorWindowsCrossCompilation(TestCase):
@requires_gpu()
def test_simple_so(self):
if is_fbcode() or IS_CI:
raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
# TODO: enable in CI
with torch.no_grad():
device = "cuda"
model = Simple().to(device=device)
example_inputs = (torch.randn(8, 10, device=device),)
batch_dim = torch.export.Dim("batch", min=1, max=1024)
exported = torch.export.export(
model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
)
package_path = torch._inductor.aoti_compile_and_package(
exported,
inductor_configs={
"aot_inductor.model_name_for_generated_files": "model",
"aot_inductor.cross_target_platform": "windows",
"aot_inductor.link_libtorch": False,
# TODO: need to add aoti_shim_library_path for CI
"aot_inductor.aoti_shim_library": "executorch",
# no fallback ops
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON,CPP",
"max_autotune_conv_backends": "TRITON,CPP",
"aot_inductor.embed_kernel_binary": True,
# simplify things for now
"aot_inductor.precompile_headers": False,
"aot_inductor.package_constants_on_disk_format": "binary_blob",
"aot_inductor.package_constants_in_so": False,
},
)
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(package_path, "r") as zf:
zf.extractall(tmpdir)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU:
run_tests(needs="filelock")

View File

@ -9,6 +9,7 @@ from typing import Any, Optional
import torch
import torch._inductor.config
from torch._environment import is_fbcode
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu
@ -77,6 +78,9 @@ class WindowsCrossCompilationTestFramework:
"This test should run on Linux for cross-compilation"
)
if is_fbcode():
raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
self.assertTrue("WINDOWS_CUDA_HOME" in os.environ)
with torch.no_grad():
@ -128,6 +132,9 @@ class WindowsCrossCompilationTestFramework:
if platform.system() != "Windows":
raise unittest.SkipTest("This test should run on Windows")
if is_fbcode():
raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
if not HAS_GPU:
raise unittest.SkipTest("Test requires GPU")

View File

@ -1839,12 +1839,22 @@ class TestStandaloneCompile(TestCase):
@parametrize("format", ("binary", "unpacked"))
@parametrize("dynamic", (False, True))
@parametrize("graph_partition", (False, True))
@parametrize("is_aot", (False, True))
def test_basic(
self, device: str, format: str, dynamic: bool, graph_partition: bool
self,
device: str,
format: str,
dynamic: bool,
graph_partition: bool,
is_aot: bool,
) -> None:
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
# AOT mode does not support unpacked format
if is_aot and format == "unpacked":
raise unittest.SkipTest("AOT mode does not support unpacked format")
mod = torch.nn.Linear(1, 3, device=device)
x = torch.randn(4, 1, device=device)
if dynamic:
@ -1869,7 +1879,9 @@ class TestStandaloneCompile(TestCase):
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact = torch._inductor.standalone_compile(
gm, args, aot=is_aot
)
compiled_artifact.save(path=path, format=format)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
@ -1885,13 +1897,15 @@ class TestStandaloneCompile(TestCase):
compiled_out = loaded(*concrete_args)
self.assertEqual(eager_out, compiled_out)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
if not is_aot:
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic", (False, True))
def test_call_in_backend(self, dynamic: bool) -> None:
@parametrize("is_aot", (False, True))
def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
if dynamic:
@ -1904,7 +1918,7 @@ class TestStandaloneCompile(TestCase):
eager_out = f(x)
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(gm, args)
return torch._inductor.standalone_compile(gm, args, aot=is_aot)
with fresh_cache():
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
@ -2055,7 +2069,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_dynamic_shapes_from_graph(self):
@parametrize("is_aot", (False, True))
def test_dynamic_shapes_from_graph(self, is_aot: bool):
def f(x):
return x.shape[0] * x
@ -2067,7 +2082,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes="from_graph"
gm, args, dynamic_shapes="from_graph", aot=is_aot
)
x = torch.ones(4)
(result,) = compiled_artifact(4, x)
@ -2077,7 +2092,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_normalize_inputs": True})
def test_split_module(self):
@parametrize("is_aot", (False, True))
def test_split_module(self, is_aot):
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
x = x + (a0**2) + (a1 / 2)
@ -2116,16 +2132,24 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
split = torch.fx.passes.split_module.split_module(gm, gm, split)
# Each of the split graphs only has one output.
ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1))
ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1))
ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1))
ca0 = torch._inductor.standalone_compile(
split.submod_0, (a0, x, a1), aot=is_aot
)
ca1 = torch._inductor.standalone_compile(
split.submod_1, (b0, x, b1), aot=is_aot
)
ca2 = torch._inductor.standalone_compile(
split.submod_2, (c0, x, c1), aot=is_aot
)
y = ca0(a0, x, a1)
y = ca1(b0, y, b1)
y = ca2(c0, y, c1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
if not is_aot:
# fx graph cache doesn't run in AOT mode
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
# TODO: split_module causes ca1 and ca2 to have different type annotations
# for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
@ -2138,8 +2162,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (False, True))
@parametrize("config_patches", [True, False])
def test_dynamic_shapes_from_example_inputs(self, config_patches):
def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot):
def f(x):
return x.shape[0] * x
@ -2161,6 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
(5, torch.ones(4)),
dynamic_shapes="from_example_inputs",
options={"config_patches": config_patches},
aot=is_aot,
)
x = torch.ones(4)
(result,) = compiled_artifact(3, x)
@ -2175,8 +2201,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
def test_static_shapes(self, dynamic_shapes):
def test_static_shapes(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x
@ -2186,7 +2213,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
static_gm, [static_x], dynamic_shapes=dynamic_shapes
static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot
)
x = torch.randn(3)
(result,) = compiled_artifact(x)
@ -2198,8 +2225,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
def test_backend(self, dynamic_shapes):
def test_backend(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x
@ -2208,7 +2236,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot
)
y = torch.randn(4)
(result,) = compiled_artifact(4, y)
@ -2221,7 +2249,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_backend_dynamic_shapes_from_example_inputs(self):
@parametrize("is_aot", (True, False))
def test_backend_dynamic_shapes_from_example_inputs(self, is_aot):
def f(x):
return x.shape[0] * x
@ -2230,7 +2259,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot
)
y = torch.ones(4)
(result,) = compiled_artifact(4, y)

View File

@ -1543,26 +1543,22 @@ class CPUReproTests(TestCase):
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
inputs = (
x,
scale,
zero_point,
use_dequant,
use_quant,
quant_min,
quant_max,
dtype,
dequant_out_dtype,
self.common(
fn,
(
x,
scale,
zero_point,
use_dequant,
use_quant,
quant_min,
quant_max,
dtype,
dequant_out_dtype,
),
)
self.common(fn, inputs)
check_metrics_vec_kernel_count(1)
# Check that both main and tail loops are vectorized
if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
compiled_fn = torch.compile(fn)
_, code = run_and_get_cpp_code(compiled_fn, *inputs)
FileCheck().check_count("loadu", 2, exactly=True).run(code)
@requires_vectorization
def test_dequant_quant_lowering_uint8(self):
self._test_dequant_quant_lowering_helper(torch.uint8)
@ -4814,22 +4810,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
).run(code)
def test_double_reduction_vec(self):
def fn(x):
return x.sum(dim=1)
@ -4839,22 +4819,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
).run(code)
def test_convert_fp32_to_double_vec(self):
def fn(x):
return x.to(torch.double)
@ -4864,22 +4828,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn(22, 22)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::convert<double,2,float,1>", 2, exactly=True
).run(code)
def test_convert_double_to_fp32_vec(self):
def fn(x):
return x.to(torch.float32)
@ -4889,22 +4837,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::convert<float,1,double,2>", 2, exactly=True
).run(code)
def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
# https://github.com/pytorch/pytorch/issues/115260
p0 = torch.tensor([1.0879], dtype=torch.float16)
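The CPU-repro hunks above rely on `run_and_get_cpp_code` plus `FileCheck` to inspect Inductor's generated C++. A minimal, hedged sketch of that inspection pattern follows; no vectorization counts are asserted, since this PR changes which dtypes get masked tail vectorization, and an Inductor-capable CPU build is assumed.

```python
import torch
from torch._inductor.utils import run_and_get_cpp_code

def fn(x):
    return x.sum(dim=1)

x = torch.randn(22, 22)
# Returns the eager-equivalent result plus the generated C++ source as a string.
_, code = run_and_get_cpp_code(torch.compile(fn), x)
print("at::vec" in code)  # whether any vectorized kernels were emitted on this machine
```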

View File

@ -85,7 +85,6 @@ def init_lists():
"linalg_inv_ex",
"linalg_pinv.atol_rtol_tensor",
"logsumexp",
"svd",
}
# For some ops, we don't support all variants. Here we use formatted_name
# to uniquely identify the variant.
@ -221,15 +220,20 @@ class TestLazyOpInfo(TestCase):
torch._lazy.wait_device_ops()
prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
metrics = remove_suffixes(torch._lazy.metrics.counter_names())
cands = [f"{prefix}::{op.name}{symint_suffix}"]
# check aliases
for alias in op.aliases:
cands.append(f"{prefix}::{alias.name}{symint_suffix}")
self.assertTrue(
any(c in metrics for c in cands), f"none of {cands} found in {metrics}"
found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(
torch._lazy.metrics.counter_names()
)
# check aliases
if not found:
for alias in op.aliases:
alias_found = (
f"{prefix}::{alias.name}{symint_suffix}"
in remove_suffixes(torch._lazy.metrics.counter_names())
)
found = found or alias_found
if found:
break
self.assertTrue(found)
@ops(
[

View File

@ -1258,10 +1258,11 @@ class DecompOneOffTests(TestCase):
)
# check RMSNorm was fused with sinh
self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0])
self.assertTrue(
"triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul"
in generated_codes[1]
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
)
self.assertTrue(
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
)

View File

@ -17,7 +17,7 @@ from unittest.mock import patch, MagicMock, ANY
import math
import itertools
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, largeTensorTest
from torch.testing._internal.common_device_type import expectedFailureMPS, instantiate_device_type_tests, onlyCUDA, largeTensorTest
from typing import Optional
import torch.utils.cpp_extension
from torch.testing._internal.common_nn import NNTestCase
@ -2022,6 +2022,7 @@ class TestSDPA(NNTestCase):
for both cpu and cuda. If your test is only applicable to cuda,
add it to TestSDPACudaOnly.
"""
@expectedFailureMPS # No double support
@parametrize("contiguous_inputs", [True, False])
def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool):
@ -4625,13 +4626,13 @@ class TestAttnBias(NNTestCase):
scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
if NOTEST_CPU:
device_types = ("cuda", )
device_types = ("cuda", "mps")
else:
device_types = ("cpu", "cuda")
device_types = ("cpu", "cuda", "mps")
instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types, allow_mps=True)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types, allow_mps=True)
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)

View File

@ -754,6 +754,10 @@ def align_trace_from_beginning(
# Rank 3: [0, 1, 2, 3, 4, 5, None]
# Then we should start from collective 2, not 0, because for any collective before that
# we don't have complete records from all ranks, so we need to ignore them.
# If we don't have any trace from some ranks, ignore them
# as well.
if len(entries[rank]) == 0:
continue
first_record_id = entries[rank][0]["record_id"]
maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)
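A small, illustrative sketch of the alignment rule described in the comments above (the helper name and dict layout are assumptions, not the real flight-recorder API): ranks that provided no entries are skipped, and alignment starts at the largest first record id among the remaining ranks.

```python
def starting_record_id(entries):
    # entries: rank -> list of trace records; ranks with empty dumps are ignored
    start = -1
    for recs in entries.values():
        if not recs:
            continue
        start = max(start, recs[0]["record_id"])
    return start

entries = {
    0: [{"record_id": 2}, {"record_id": 3}],
    1: [],                                   # this rank provided no dump
    2: [{"record_id": 0}, {"record_id": 1}, {"record_id": 2}],
}
print(starting_record_id(entries))  # 2
```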

View File

@ -404,7 +404,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.max_unpool3d,
aten.mish,
aten.mish_,
aten.mish_backward,
aten.mse_loss,
aten.mse_loss_backward,
aten.multi_margin_loss,
@ -420,7 +419,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
aten._fused_rms_norm,
aten._fused_rms_norm_backward,
aten.new_empty,
aten.new_full,
@ -477,7 +475,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.silu,
aten.silu_,
aten.silu_backward.grad_input,
aten.silu_backward,
aten.sinc,
aten.sinc_,
aten.slice_backward,

View File

@ -1757,61 +1757,6 @@ def native_layer_norm_backward_out(
return grad_input
@register_decomposition(aten._fused_rms_norm.default)
def _fused_rms_norm(
input: Tensor,
normalized_shape: list[int],
weight: Optional[Tensor],
eps: Optional[float],
) -> tuple[Tensor, Tensor]:
dims_to_reduce: list[int] = []
for i in range(len(normalized_shape)):
dims_to_reduce.append(input.dim() - i - 1)
# upcast is needed for fp16 and bf16
computation_dtype = utils.get_computation_dtype(input.dtype)
upcasted_input = input.to(computation_dtype)
# computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
if eps is None:
if computation_dtype in (torch.float32, torch.complex64):
eps_val = torch.finfo(torch.float32).eps
else:
eps_val = torch.finfo(torch.float64).eps
else:
eps_val = eps
rqrst_input = torch.rsqrt(
# NB: don't inplace here, will violate functional IR invariant
# NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp
torch.ops.aten.add.Scalar(
torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val
)
)
upcasted_result = upcasted_input.mul(rqrst_input)
if weight is not None:
upcasted_result = upcasted_result.mul(weight)
# NB: nested should be dead here, just here for fidelity
is_nested = input.is_nested or (weight is not None and weight.is_nested)
memory_format = utils.suggest_memory_format(input)
is_channels_last = memory_format in (
torch.channels_last,
torch.channels_last_3d,
)
if not is_nested and not is_channels_last:
upcasted_result = upcasted_result.contiguous()
rqrst_input = rqrst_input.contiguous()
# Cast normalized result back to original input type
result = upcasted_result.type_as(input)
return result, rqrst_input
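In math, the `_fused_rms_norm` decomposition above computes, for the reduced elements $x_1,\dots,x_n$ of `normalized_shape`, optional weight $w$, and epsilon $\varepsilon$ (all in the upcast computation dtype):

\[
r(x) = \frac{1}{\sqrt{\tfrac{1}{n}\sum_{i=1}^{n} x_i^{2} + \varepsilon}}, \qquad
y = \bigl(x \cdot r(x)\bigr) \odot w,
\]

where `rqrst_input` is exactly $r(x)$ and is returned as the second output for use by the backward decomposition.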
@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
grad_out: Tensor,

View File

@ -1,4 +1,3 @@
import abc
import dataclasses
import importlib
import inspect
@ -15,6 +14,10 @@ from torch._dynamo.graph_utils import _graph_device_type
from torch._dynamo.package import SystemInfo
from . import convert_frame
from .aot_compile_types import (
BundledAOTAutogradSerializableCallable,
SerializableCallable,
)
from .hooks import Hooks
@ -26,18 +29,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
def bind_locals(
signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
@ -149,53 +140,6 @@ class AOTCompiledFunction:
self._guard_check_enabled = False
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
if hasattr(self, attr):
return getattr(super(), attr)
else:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
def aot_compile_fullgraph(
model: Any,
example_inputs: tuple[tuple[Any, ...], dict[str, Any]],

View File

@ -0,0 +1,61 @@
import abc
import pickle
from typing import Any
import torch
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
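A hedged round-trip sketch for the hooks in this new module; `compiled_fn` is assumed to be an AOTAutograd-produced callable exposing `.serialize()`, as the constructor asserts, and producing one is outside the scope of the sketch.

```python
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable

def roundtrip(compiled_fn):
    wrapped = BundledAOTAutogradSerializableCallable(compiled_fn)
    # Serialize under bundled_autograd_cache, then rebuild an equivalent callable.
    blob = BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(wrapped)
    restored = BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(blob)
    return restored  # callable like the original: restored(*args, **kwargs)
```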

View File

@ -1707,6 +1707,39 @@ def check_signature_rewritable(graph: torch.fx.GraphModule) -> None:
)
def check_user_input_output(flat_values: list[Any], error_type: UserErrorType) -> None:
supported_types = [
torch.Tensor,
torch.SymInt,
torch.SymFloat,
torch.SymBool,
torch._C.ScriptObject,
_IntWrapper,
] + list(common_constant_types)
def is_supported_type(val: Any) -> bool:
return isinstance(val, tuple(supported_types))
value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
# We only check that the outputs are not None. Inputs can be None.
for v in flat_values:
if not is_supported_type(v):
if error_type == UserErrorType.INVALID_INPUT and v is None:
continue
raise UserError(
error_type,
f"It looks like one of the {value_type}s with type `{type(v)}` "
"is not supported or pytree-flattenable. \n"
f"Exported graphs {value_type}s can only contain the "
f"following supported types: {supported_types}. \n"
"If you are using a custom class object, "
"please register a pytree_flatten/unflatten function "
"using `torch.utils._pytree.register_pytree_node` or "
"`torch.export.register_dataclass`.",
)
def rewrite_signature(
f_sig: inspect.Signature,
graph: torch.fx.GraphModule,
@ -1721,40 +1754,6 @@ def rewrite_signature(
) -> torch.fx.GraphModule:
orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
def check_user_input_output(
flat_values: list[Any], error_type: UserErrorType
) -> None:
supported_types = [
torch.Tensor,
torch.SymInt,
torch.SymFloat,
torch.SymBool,
torch._C.ScriptObject,
_IntWrapper,
] + list(common_constant_types)
def is_supported_type(val: Any) -> bool:
return isinstance(val, tuple(supported_types))
value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
# We only check that the outputs are not None. Inputs can be None.
for v in flat_values:
if not is_supported_type(v):
if error_type == UserErrorType.INVALID_INPUT and v is None:
continue
raise UserError(
error_type,
f"It looks like one of the {value_type}s with type `{type(v)}` "
"is not supported or pytree-flattenable. \n"
f"Exported graphs {value_type}s can only contain the "
f"following supported types: {supported_types}. \n"
"If you are using a custom class object, "
"please register a pytree_flatten/unflatten function "
"using `torch.utils._pytree.register_pytree_node` or "
"`torch.export.register_dataclass`.",
)
check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)
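The error message above points users at pytree registration; a hedged illustration of what that looks like for a custom class (the `Point` class is hypothetical) is below. Once registered, the object is pytree-flattenable and passes the input/output check.

```python
import torch.utils._pytree as pytree

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

pytree.register_pytree_node(
    Point,
    lambda p: ([p.x, p.y], None),             # flatten: (children, context)
    lambda children, _ctx: Point(*children),  # unflatten
)

leaves, spec = pytree.tree_flatten(Point(1, 2))
print(leaves)  # [1, 2]
```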

View File

@ -10,7 +10,8 @@ import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
from torch._dynamo.eval_frame import argument_names
from torch._dynamo.eval_frame import argument_names, check_user_input_output
from torch._dynamo.exc import UserErrorType
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._export.utils import _compiling_state_context
from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
@ -479,6 +480,7 @@ def _dynamo_graph_capture_for_export(
# This sets the is_exporting flag when building guards.
with _compiling_state_context():
flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
check_user_input_output(flat_inputs, UserErrorType.INVALID_INPUT)
module_to_trace = ModuleToTrace(mod, in_spec)
orig_callable = mod.forward if isinstance(mod, torch.nn.Module) else mod

View File

@ -200,10 +200,9 @@ class SuperVariable(VariableTracker):
and not (args or kwargs)
):
with do_not_convert_to_tracable_parameter():
fn_vt = VariableTracker.build(
tx, unpatched_nn_module_init, source=source
)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented_v2(
gb_type="Unsupported super().__init__() call",
@ -231,8 +230,9 @@ class SuperVariable(VariableTracker):
elif isinstance(inner_fn, staticmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(
inner_fn.__func__, source=source
).call_function(tx, args, kwargs)
elif isinstance(inner_fn, classmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
@ -255,13 +255,13 @@ class SuperVariable(VariableTracker):
tx, self.objvar.value_type, cls_source
)
fn_vt = VariableTracker.build(
tx, inner_fn.__func__, source=AttrSource(source, "__func__")
)
return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
return variables.UserFunctionVariable(
inner_fn.__func__, source=AttrSource(source, "__func__")
).call_function(tx, [cls_variable, *args], kwargs)
elif isinstance(inner_fn, types.FunctionType):
fn_vt = VariableTracker.build(tx, inner_fn, source=source)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
return variables.UserFunctionVariable(
inner_fn, source=source
).call_function(tx, [self.objvar] + args, kwargs)
elif isinstance(inner_fn, types.MethodType):
return variables.UserMethodVariable(
inner_fn.__func__, self.objvar, source=source
@ -574,8 +574,10 @@ class ComptimeVariable(VariableTracker):
from ..comptime import comptime
# To support the comptime.print_graph convenience accessors
return VariableTracker.build(
tx, getattr(comptime, name), source=AttrSource(self.source, name)
from .functions import UserFunctionVariable
return UserFunctionVariable(
getattr(comptime, name), source=AttrSource(self.source, name)
)
def call_function(
@ -769,8 +771,9 @@ class AutogradFunctionVariable(VariableTracker):
sig = inspect.signature(fn)
if len(args) - 1 == len(sig._parameters):
args = args[1:] # Don't use context
fn_vt = VariableTracker.build(tx, fn, source=source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(fn, source=source).call_function(
tx, args, kwargs
)
elif isinstance(fn, types.MethodType):
return variables.UserMethodVariable(
fn.__func__,
@ -796,8 +799,9 @@ class AutogradFunctionVariable(VariableTracker):
assert isinstance(fn, types.FunctionType)
fn_source = AttrSource(self.source, "backward")
fn_vt = VariableTracker.build(tx, fn, source=fn_source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx, args, kwargs
)
def call_function(self, tx: "InstructionTranslator", args, kwargs):
return AutogradFunctionVariable(self.fn_cls)
@ -1022,12 +1026,10 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
assert tx.one_graph or tx.error_on_graph_break, (
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
)
fn_vt = VariableTracker.build(
tx,
return variables.UserFunctionVariable(
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
source=self.source,
)
return fn_vt.call_function(
).call_function(
tx,
(tx.output.side_effects.get_ca_final_callbacks_var(), *args),
kwargs,

View File

@ -1347,6 +1347,15 @@ def create_functional_call(
maybe_disable_thunkify(),
):
if isinstance(mod, torch.fx.GraphModule):
if kwargs:
# Handle **kwargs. FX only natively supports positional
# arguments (through placeholders).
arg_list = list(args[params_len:])
arg_list.extend(list(kwargs.values()))
args = tuple(arg_list)
else:
args = args[params_len:]
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Anomaly Detection has been enabled."
@ -1355,9 +1364,7 @@ def create_functional_call(
fake_mode = detect_fake_mode()
assert fake_mode is not None
fake_mode.epoch += 1
out = PropagateUnbackedSymInts(mod).run(
*args[params_len:], **kwargs
)
out = PropagateUnbackedSymInts(mod).run(*args)
else:
out = mod(*args[params_len:], **kwargs)
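A minimal sketch of the kwargs flattening added above (the helper name is illustrative): FX placeholders are positional, so keyword values are appended, in insertion order, after the positional args that follow the parameter prefix.

```python
def flatten_call_args(args, kwargs, params_len=0):
    arg_list = list(args[params_len:])
    arg_list.extend(kwargs.values())
    return tuple(arg_list)

print(flatten_call_args(("param0", 1, 2), {"y": 3}, params_len=1))  # (1, 2, 3)
```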

View File

@ -391,6 +391,7 @@ def standalone_compile(
"from_example_inputs", "from_tracing_context", "from_graph"
] = "from_graph",
options: Optional[dict[str, Any]] = None,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Precompilation API for inductor.
@ -422,5 +423,5 @@ def standalone_compile(
options = options if options else {}
return standalone_compile(
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot
)

View File

@ -159,14 +159,11 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
]
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
torch.float64,
torch.float,
torch.bfloat16,
torch.float16,
torch.uint8,
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
]

View File

@ -5,10 +5,12 @@ import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
@ -30,9 +32,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class CompiledArtifact:
class CompiledArtifact(ABC):
"""
CompiledArtifact class represents the precompiled inductor artifact that
CompiledArtifact class represents the inductor cache artifacts that
can be invoked in order to avoid repeated compilation.
CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
@ -45,11 +47,68 @@ class CompiledArtifact:
binary or unpacked data.
Finally, the CompiledArtifact can be invoked via the __call__ method
to execute the precompiled artifact.
to execute the cached artifact.
"""
_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
@abstractmethod
def __call__(self, *args: Any) -> Any: ...
@abstractmethod
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None: ...
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
# If format is unpacked, it must be a CacheCompiledArtifact
return CacheCompiledArtifact.load(path=path, format=format)
assert format == "binary"
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
if header == AOTCompiledArtifact.AOT_HEADER:
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
# Otherwise, it's in the CacheCompiledArtifact format
elif header == CacheCompiledArtifact.CACHE_HEADER:
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return CacheCompiledArtifact._load_impl(nullcontext(), key)
else:
raise RuntimeError(
"Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
+ header.decode("utf-8")
)
class CacheCompiledArtifact(CompiledArtifact):
"""
CompiledArtifact that depends on torch.compiler.save_cache_artifacts
"""
CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")
def __init__(
self,
@ -83,6 +142,7 @@ class CompiledArtifact:
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
@ -116,9 +176,51 @@ class CompiledArtifact:
log.info("Output code written to: %s", output_file)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
def _load_impl(
cache_dir_ctx: AbstractContextManager[Any], key: str
) -> CompiledArtifact:
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
assert result is not None
(entry, _) = result
from .compile_fx import _CompileFxKwargs
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)
@staticmethod
def _prepare_load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> tuple[str, AbstractContextManager[Any]]:
"""
Do format-specific prep and loading; return the key and a context manager.
"""
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
@ -137,8 +239,7 @@ class CompiledArtifact:
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
cache_dir_ctx: AbstractContextManager[None] = nullcontext()
return key, nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
@ -148,43 +249,105 @@ class CompiledArtifact:
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
return key, cache_dir_ctx
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
path=path, format=format
)
return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
assert result is not None
(entry, _) = result
class AOTCompiledArtifact(CompiledArtifact):
"""
Similar to CompiledArtifact, but the object is a single, bundled precompiled function.
This object is always a serializable callable function.
from .compile_fx import _CompileFxKwargs
This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which
is used by torch._dynamo.aot_compile for AOT Precompilation.
"""
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")
context = torch._guards.TracingContext(
FakeTensorMode(shape_env=ShapeEnv())
)
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
def __init__(
self,
compiled_fn: Callable[..., Any],
):
self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
self._artifacts = (
None # We don't need artifacts, the inner object handles everything
)
@staticmethod
def from_bundled_callable(
bundled_fn: BundledAOTAutogradSerializableCallable,
) -> AOTCompiledArtifact:
return AOTCompiledArtifact(bundled_fn.compiled_fn)
def __call__(self, *args: Any) -> Any:
return self.inner_fn(*args)
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
result_bytes = self.serialize()
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
writer.write_bytes(torch_key())
writer.write_bytes(result_bytes)
from torch._inductor.codecache import write_atomic
# Atomically write the serialized AOT artifact (header + torch_key + payload)
write_atomic(path, writer.to_bytes())
def serialize(self) -> bytes:
return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
self.inner_fn
)
@staticmethod
def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
deserialized = (
BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
result_bytes
)
)
assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
return AOTCompiledArtifact.from_bundled_callable(deserialized)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
assert header == AOTCompiledArtifact.AOT_HEADER
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
def standalone_compile(
@ -193,7 +356,11 @@ def standalone_compile(
*,
dynamic_shapes: Any,
options: Any,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Implementation of torch._inductor.standalone_compile
"""
from torch.compiler._cache import CacheArtifactManager
from .compile_fx import compile_fx
@ -249,6 +416,7 @@ def standalone_compile(
torch._guards.tracing(context),
CacheArtifactManager.with_fresh_cache(),
config.patch("triton.autotune_at_compile_time", True),
torch._functorch.config.patch("bundled_autograd_cache", aot),
):
# compile_fx can mutate gm
gm = copy.deepcopy(gm)
@ -256,7 +424,12 @@ def standalone_compile(
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)
if aot:
if not hasattr(compiled_fn, "serialize"):
raise RuntimeError(
"Compiled function should have serialize method when aot=True"
)
return AOTCompiledArtifact(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(
@ -264,4 +437,4 @@ def standalone_compile(
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)
return CompiledArtifact(compiled_fn, artifacts)
return CacheCompiledArtifact(compiled_fn, artifacts)
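A hedged end-to-end sketch of the `aot=True` path defined above: compile, save in the binary format (`AOTCompiledArtifact` rejects `"unpacked"`), then reload through the header-dispatching `CompiledArtifact.load`. As before, `symbolic_trace` stands in for the real capture step, the path is illustrative, and a working Inductor toolchain is assumed.

```python
import torch
import torch._inductor
from torch._inductor.standalone_compile import CompiledArtifact

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)  # simplified stand-in for dynamo capture
x = torch.randn(4)

artifact = torch._inductor.standalone_compile(gm, [x], aot=True)
artifact.save(path="/tmp/aot_artifact.bin", format="binary")

loaded = CompiledArtifact.load(path="/tmp/aot_artifact.bin", format="binary")
(out,) = loaded(x)
assert torch.allclose(out, f(x))
```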

View File

@ -15,7 +15,6 @@ from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
_detect_infra_mode,
_disable_infra_mode,
autograd_would_have_decomposed,
return_and_correct_aliasing,
TorchDispatchMode,
)
@ -410,13 +409,8 @@ class FunctionalTensorMode(TorchDispatchMode):
return False
return True
# in normal torch.compile IR, we only decompose an op if autograd
# would have decomposed it (NB: autograd may have been skipped if
# we are in inference mode)
# TODO: the flatten here can potentially be deduped with the
# unwrapping pytree_map later
flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
return autograd_would_have_decomposed(func, flat_args_kwargs)
# in normal torch.compile IR, we decompose functional composite ops
return True
if (
func not in FunctionalTensor.metadata_fns

View File

@ -1,5 +1,6 @@
#include <torch/csrc/jit/runtime/logging.h>
#include <c10/util/Exception.h>
#include <atomic>
#include <chrono>
#include <mutex>
@ -33,7 +34,7 @@ int64_t LockingLogger::getCounterValue(const std::string& name) const {
return raw_counter.sum / raw_counter.count;
} break;
}
throw std::runtime_error("Unknown aggregation type!");
TORCH_CHECK(false, "Unknown aggregation type!");
}
void LockingLogger::setAggregationType(

View File

@ -11,6 +11,7 @@
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <limits>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
namespace torch::jit {
@ -112,20 +113,17 @@ void listRemove<at::Tensor>(Stack& stack) {
}
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) {
if (t.requires_grad()) {
throw std::runtime_error(
"Cannot input a tensor that requires grad as a scalar argument");
}
if (!t.sizes().empty()) {
throw std::runtime_error(
"Cannot input a tensor of dimension other than 0 as a scalar argument");
}
if (toInt && !isIntegralType(t.scalar_type(), /*includeBool=*/false)) {
std::stringstream ss;
ss << "Cannot input a tensor of type " << t.scalar_type()
<< " as an integral argument";
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
!t.requires_grad(),
"Cannot input a tensor that requires grad as a scalar argument");
TORCH_CHECK(
t.sizes().empty(),
"Cannot input a tensor of dimension other than 0 as a scalar argument");
TORCH_CHECK(
!toInt || isIntegralType(t.scalar_type(), /*includeBool=*/false),
"Cannot input a tensor of type ",
t.scalar_type(),
" as an integral argument");
}
void checkDoubleInRange(double a) {

View File

@ -1,5 +1,6 @@
#include <ATen/autocast_mode.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
@ -159,9 +160,8 @@ void sort_op(Stack& stack) {
if (!g_list.empty()) {
std::stringstream error_str;
if (!isSortableListOfObjectsOrTuples(g_list, error_str)) {
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
isSortableListOfObjectsOrTuples(g_list, error_str), error_str.str());
c10::IValueComparator comparator;
if (reverse) {
@ -254,9 +254,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
int64_t lo = 0, hi = 0, step = 0;
pop(stack, lo, hi, step);
// error handling when step_val = 0 during runtime
if (step == 0) {
throw std::runtime_error("range() arg 3 must not be zero");
}
TORCH_CHECK(step != 0, "range() arg 3 must not be zero");
if (step > 0 && lo < hi) {
push(stack, 1 + (hi - 1 - lo) / step);
} else if (step < 0 && lo > hi) {
@ -382,14 +380,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
int64_t val = static_cast<int64_t>(std::stoll(s->string(), &sz));
if (sz == s->string().size()) {
push(stack, val);
} else {
std::stringstream error_str;
error_str << "invalid literal for int() "
<< "with base 10: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"invalid literal for int() ",
"with base 10: '",
s->string(),
"'");
push(stack, val);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
@ -436,14 +433,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
double b = std::stod(s->string(), &sz);
if (sz == s->string().size()) {
push(stack, b);
} else {
std::stringstream error_str;
error_str << "could not convert string "
<< "to float: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"could not convert string ",
"to float: '",
s->string(),
"'");
push(stack, b);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
@ -1793,10 +1789,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
}
const std::string& separator = ivalue.toStringRef();
if (separator.empty()) {
throw std::runtime_error("ValueError: empty separator");
}
TORCH_CHECK(!separator.empty(), "ValueError: empty separator");
auto count = 0;
@ -1919,11 +1912,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
if (string.size() > static_cast<std::string::size_type>(width)) {
push(stack, string);
return;
@ -2092,9 +2083,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string substr = pop(stack).toStringRef();
std::string string = pop(stack).toStringRef();
auto result = stringFindImpl(string, substr, start, end);
if (result < 0) {
throw std::runtime_error("ValueError: substring not found");
}
TORCH_CHECK(result >= 0, "ValueError: substring not found");
push(stack, result);
},
aliasAnalysisFromSchema()),
@ -2107,9 +2096,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string substr = pop(stack).toStringRef();
std::string string = pop(stack).toStringRef();
auto result = stringFindImpl(string, substr, start, end, true);
if (result < 0) {
throw std::runtime_error("ValueError: substring not found");
}
TORCH_CHECK(result >= 0, "ValueError: substring not found");
push(stack, result);
},
aliasAnalysisFromSchema()),
@ -2183,11 +2170,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
auto to_append =
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
@ -2207,11 +2192,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
auto to_append =
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
@ -3358,10 +3341,8 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
int64_t a = 0, b = 0;
lldiv_t divresult = {};
pop(stack, a, b);
if (b == 0) {
throw std::runtime_error(
"ZeroDivisionError: integer division or modulo by zero");
}
TORCH_CHECK(
b != 0, "ZeroDivisionError: integer division or modulo by zero");
divresult = lldiv(a, b);
if (divresult.rem && (a < 0) != (b < 0)) {
divresult.quot -= 1;
@ -3379,9 +3360,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
[](Stack& stack) {
double a = 0, b = 0;
pop(stack, a, b);
if (b == 0) {
throw std::runtime_error("ZeroDivisionError: float divmod()");
}
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()");
double rem = fmod(a, b);
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
if (rem && (a < 0) != (b < 0)) {
@ -3426,9 +3405,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
type_a a; \
type_b b; \
pop(stack, a, b); \
if (b == 0) { \
throw std::runtime_error("ZeroDivisionError: float divmod()"); \
} \
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()"); \
double quot = floor(a / b); \
double rem = a - (quot * b); \
push(stack, quot, rem); \

View File

@ -466,14 +466,6 @@ at::Tensor LazyNativeFunctions::linalg_pinv(
linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::svd(
const at::Tensor& self,
bool some,
bool compute_uv) {
return at::functionalization::functionalize_aten_op<ATEN_OP(svd)>::call(
self, some, compute_uv);
}
// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy and mutations
// back to the inputs.

View File

@ -21,10 +21,6 @@ backends are ready, this list allows opt-in one at a time.
PRESERVED_ATEN_CIA_OPS = {
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
# NB: don't use the C++ decomp, because it is not functional!
torch.ops.aten.silu_backward.default,
torch.ops.aten.mish_backward.default,
torch.ops.aten._fused_rms_norm.default,
}

View File

@ -63,7 +63,6 @@ from torch.utils._python_dispatch import (
_disable_infra_mode,
_push_mode,
_unset_infra_mode,
autograd_would_have_decomposed,
TorchDispatchMode,
)
from torch.utils._stats import count
@ -1033,16 +1032,11 @@ def proxy_call(
return r
# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
if (
not pre_dispatch
and func
not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]
and autograd_would_have_decomposed(func, flat_args_kwargs)
):
if not pre_dispatch and func not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]:
with proxy_mode:
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:

View File

@ -43,6 +43,8 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix):
def _compile_submod(gm, prefix):
from torch._inductor.standalone_compile import AOTCompiledArtifact
for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith(prefix):
fake_inputs = []
@ -56,13 +58,12 @@ def _compile_submod(gm, prefix):
submod = getattr(gm, node.target)
# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(
torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context"
)
compiled_fn = torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True
)
assert isinstance(compiled_fn, AOTCompiledArtifact)
# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(compiled_fn)
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
compiled_submod, args=node.args, kwargs=node.kwargs

View File

@ -63,15 +63,15 @@ struct dummy_int1_7_t {};
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(c10::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
_(c10::BFloat16, BFloat16) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn)
// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
@ -81,19 +81,19 @@ struct dummy_int1_7_t {};
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(c10::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<c10::Half>, ComplexHalf) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
_(c10::BFloat16, BFloat16) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn) \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu)
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
@ -103,7 +103,7 @@ struct dummy_int1_7_t {};
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(c10::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
@ -113,7 +113,7 @@ struct dummy_int1_7_t {};
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
@ -176,24 +176,19 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE>, SCALARTYPE)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
@ -203,53 +198,41 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE4>::t), \
SCALARTYPE4) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE5>::t), \
SCALARTYPE5) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE6>::t), \
SCALARTYPE6) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE7>::t), \
SCALARTYPE7)
#define AT_FORALL_SCALAR_TYPES_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
SCALARTYPE2) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, \
SCALARTYPE3) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE4>, \
SCALARTYPE4) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE5>, \
SCALARTYPE5) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE6>, \
SCALARTYPE6) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE7>, SCALARTYPE7)
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
@ -258,12 +241,12 @@ struct dummy_int1_7_t {};
_(c10::quint4x2, QUInt4x2) \
_(c10::quint2x4, QUInt2x4)
#define AT_FORALL_FLOAT8_TYPES(_) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
#define AT_FORALL_FLOAT8_TYPES(_) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn) \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu)
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
@ -298,7 +281,12 @@ struct ScalarTypeToCPPType;
/* can't pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
/* UPDATE: while the CUDA bug is fixed, we cannot remove the */ \
/* workaround as it is BC breaking. However, it is recommended to */ \
/* update any code that contains */ \
/* decltype(ScalarTypeToCPPType<T>::t) */ \
/* with */ \
/* ScalarTypeToCPPTypeT<T> */ \
static type t; \
};

View File

@ -38,8 +38,7 @@ class BytesWriter:
digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
4, byteorder="big", signed=False
)
if len(digest) != CHECKSUM_DIGEST_SIZE:
raise AssertionError("Computed checksum digest has unexpected size")
assert len(digest) == CHECKSUM_DIGEST_SIZE
self._data[0:CHECKSUM_DIGEST_SIZE] = digest
return bytes(self._data)
@ -47,13 +46,11 @@ class BytesWriter:
class BytesReader:
def __init__(self, data: bytes) -> None:
# Check for data corruption
if len(data) < CHECKSUM_DIGEST_SIZE:
raise AssertionError("Input data is too short to contain checksum")
assert len(data) >= CHECKSUM_DIGEST_SIZE
digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
4, byteorder="big", signed=False
)
if len(digest) != CHECKSUM_DIGEST_SIZE:
raise AssertionError("Computed checksum digest has unexpected size")
assert len(digest) == CHECKSUM_DIGEST_SIZE
if data[0:CHECKSUM_DIGEST_SIZE] != digest:
raise RuntimeError(
"Bytes object is corrupted, checksum does not match. "
@ -123,11 +120,7 @@ class AppendingByteSerializer(Generic[T]):
@staticmethod
def to_list(data: bytes, *, deserialize_fn: Callable[[BytesReader], T]) -> list[T]:
reader = BytesReader(data)
if reader.read_uint64() != _ENCODING_VERSION:
raise AssertionError(
f"Encoding version mismatch in AppendingByteSerializer.to_list, \
got {reader.read_uint64()}"
)
assert reader.read_uint64() == _ENCODING_VERSION
result: list[T] = []
while not reader.is_finished():
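A hedged round-trip through the writer/reader edited above, using only the methods that appear in this diff:

```python
from torch.utils._appending_byte_serializer import BytesReader, BytesWriter

w = BytesWriter()
w.write_bytes(b"HEADER")
w.write_str("cache-key")
blob = w.to_bytes()      # prefixes a CRC32 checksum over the payload

r = BytesReader(blob)    # __init__ recomputes and verifies the checksum
assert r.read_bytes() == b"HEADER"
assert r.read_str() == "cache-key"
assert r.is_finished()
```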

View File

@ -85,16 +85,12 @@ class _Config(Generic[T]):
)
if self.alias is not None:
if (
self.default is not _UNSET_SENTINEL
or self.justknob is not None
or self.env_name_default is not None
or self.env_name_force is not None
):
raise AssertionError(
"if alias is set, none of {default, justknob, \
env_name_default and env_name_force} can be set"
)
assert (
self.default is _UNSET_SENTINEL
and self.justknob is None
and self.env_name_default is None
and self.env_name_force is None
), "if alias is set, none of {default, justknob and env var} can be set"
@staticmethod
def string_or_list_of_string_to_list(
@ -104,8 +100,7 @@ class _Config(Generic[T]):
return None
if isinstance(val, str):
return [val]
if not isinstance(val, list):
raise AssertionError(f"val is not a list, got {type(val)}")
assert isinstance(val, list)
return val
@ -198,10 +193,7 @@ def install_config_module(module: ModuleType) -> None:
if dest is module:
delattr(module, key)
elif isinstance(value, type):
if value.__module__ != module.__name__:
raise AssertionError(
f"subconfig class {value} must be defined in module {module.__name__}"
)
assert value.__module__ == module.__name__
# a subconfig with `class Blah:` syntax
proxy = SubConfigProxy(module, f"{name}.")
visit(value, proxy, f"{name}.")
@ -242,8 +234,10 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str
prev_name = ""
maybe_current = token.string.strip()
if COMPILE_IGNORED_MARKER in maybe_current:
if current_comment != ("", -1):
raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
assert current_comment == (
"",
-1,
), f"unconsumed {COMPILE_IGNORED_MARKER}"
current_comment = maybe_current, token.start[0]
elif token.type == tokenize.NAME:
# Only accept the first name token, to handle if you have
@ -260,8 +254,7 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str
assignments.add(prev_name)
current_comment = "", -1 # reset
prev_name = ""
if current_comment != ("", -1):
raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}"
return assignments
@ -313,22 +306,20 @@ class _ConfigEntry:
# Ensure justknobs and envvars are allowlisted types
if self.justknob is not None and self.default is not None:
if not isinstance(self.default, bool):
raise AssertionError(
f"justknobs only support booleans, {self.default} is not a boolean"
)
assert isinstance(self.default, bool), (
f"justknobs only support booleans, {self.default} is not a boolean"
)
if self.value_type is not None and (
config.env_name_default is not None or config.env_name_force is not None
):
if self.value_type not in (
assert self.value_type in (
bool,
str,
Optional[bool],
Optional[str],
):
raise AssertionError(
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
)
), (
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
)
class ConfigModule(ModuleType):
@ -426,10 +417,7 @@ class ConfigModule(ModuleType):
def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None:
data = self._get_alias_module_and_name(entry)
if data is None:
raise AssertionError(
"alias data should not be None when setting alias value"
)
assert data is not None
module, constant_name = data
setattr(module, constant_name, val)
@ -654,32 +642,19 @@ class ConfigModule(ModuleType):
changes: dict[str, Any]
if arg1 is not None:
if arg2 is not None:
if not isinstance(arg1, str):
raise AssertionError(
"first argument must be a string when passing 2 positional args to patch"
)
assert isinstance(arg1, str)
# patch("key", True) syntax
changes = {arg1: arg2}
else:
if not isinstance(arg1, dict):
raise AssertionError(
"first argument must be a dict when passing a single positional arg to patch"
)
assert isinstance(arg1, dict)
# patch({"key": True}) syntax
changes = arg1
if kwargs:
raise AssertionError(
"cannot pass both positional and keyword arguments to patch"
)
assert not kwargs
else:
# patch(key=True) syntax
changes = kwargs
if arg2 is not None:
raise AssertionError(
"second positional argument is only valid when first argument is a key string"
)
if not isinstance(changes, dict):
raise AssertionError(f"expected `dict` got {type(changes)}")
assert arg2 is None
assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
prior: dict[str, Any] = {}
config = self
@ -688,10 +663,7 @@ class ConfigModule(ModuleType):
self.changes = changes
def __enter__(self) -> None:
if prior:
raise AssertionError(
"prior should be empty when entering ConfigPatch"
)
assert not prior
for key in self.changes.keys():
# KeyError on invalid entry
prior[key] = config.__getattr__(key)
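For reference, the three `patch()` call forms validated above look like this in use, shown against `torch._dynamo.config` as a representative installed config module:

```python
import torch._dynamo.config as dynamo_config

with dynamo_config.patch("verbose", True):    # patch("key", value)
    pass

with dynamo_config.patch({"verbose": True}):  # patch({"key": value})
    pass

with dynamo_config.patch(verbose=True):       # patch(key=value)
    pass
```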

View File

@ -21,8 +21,7 @@ This file should be imported into any file that uses install_config_module like
Note that the import should happen before the call to install_config_module(), otherwise runtime errors may occur.
"""
if not TYPE_CHECKING: # noqa: PYI002
raise AssertionError("Do not use at runtime") # noqa: W291
assert TYPE_CHECKING, "Do not use at runtime"
def save_config() -> bytes: ...
def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...

View File

@ -217,10 +217,7 @@ class ContentStoreReader:
weights_only=True,
map_location=device,
)._untyped_storage
if s is None:
raise AssertionError(
f"expected storage for hash {h} in {os.path.join(self.loc, 'storages')}, got None"
)
assert s is not None
if self.storage_cache is not None:
self.storage_cache[device][h] = StorageWeakRef(s)
return s


@ -86,14 +86,13 @@ def context_decorator(ctx, func):
be a multi-shot context manager that can be directly invoked multiple times)
or a callable that produces a context manager.
"""
if callable(ctx) and hasattr(ctx, "__enter__"):
raise AssertionError(
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
"context_decorator(lambda: ctx()); if you intended to pass a context "
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
)
assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
"context_decorator(lambda: ctx()); if you intended to pass a context "
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
)
if not callable(ctx):
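
To make the ambiguity described above concrete, a small sketch (the class is hypothetical; context_decorator is the helper from the hunk above):

    import contextlib

    class Ambiguous:
        # callable AND a context manager: context_decorator cannot tell whether
        # to call it to obtain a fresh context manager or to enter it directly
        def __call__(self):
            return contextlib.nullcontext()
        def __enter__(self):
            return self
        def __exit__(self, *exc):
            return False

    ctx = Ambiguous()
    # unambiguous spellings suggested by the error message:
    #   context_decorator(lambda: ctx(), fn)   # treat ctx as a factory
    #   context_decorator(lambda: ctx, fn)     # treat ctx as a context manager
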


@ -933,10 +933,7 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
if not _is_pytreespec_instance(treespec):
raise AssertionError(
f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}"
)
assert _is_pytreespec_instance(treespec)
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)


@ -87,18 +87,12 @@ class DeviceContext(TorchFunctionMode):
# or else someone else has popped it!
for _ in range(_len_torch_function_stack() - 1):
mode = _pop_mode()
if isinstance(mode, DeviceContext):
raise AssertionError(
"Found nested DeviceContext on the mode stack where none expected"
)
assert not isinstance(mode, DeviceContext)
cur_stack.append(mode)
if _len_torch_function_stack() > 0:
mode = _pop_mode()
if not isinstance(mode, DeviceContext):
raise AssertionError(
"Expected a DeviceContext at the bottom of the mode stack"
)
assert isinstance(mode, DeviceContext)
for mode in reversed(cur_stack):
_push_mode(mode)
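
For reference, DeviceContext is the torch-function mode pushed by the device context manager, which is why the code above expects at most one of them, sitting at the bottom of the stack. A minimal trigger:

    import torch

    with torch.device("meta"):
        # factory calls inside the block are redirected by the DeviceContext mode
        t = torch.empty(3)

    print(t.device)   # meta
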


@ -31,8 +31,7 @@ def cache_method(
@functools.wraps(f)
def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T:
if kwargs:
raise AssertionError("cache_method does not accept keyword arguments")
assert not kwargs
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)


@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
import functools
import warnings
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
from typing import Optional, overload, Protocol, Union
from typing_extensions import TypeIs
import torch
@ -21,10 +20,6 @@ from torch._C import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
@ -88,8 +83,7 @@ class TorchDispatchMode:
def __init__(self, _dispatch_key=None):
if _dispatch_key is not None:
if not isinstance(_dispatch_key, torch._C.DispatchKey):
raise AssertionError("_dispatch_key must be a torch._C.DispatchKey")
assert isinstance(_dispatch_key, torch._C.DispatchKey)
self.__dict__["_dispatch_key"] = _dispatch_key
self.old_dispatch_mode_flags: deque[bool] = deque()
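
A minimal TorchDispatchMode subclass, for context on the constructor being changed above (the logging mode is illustrative):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LoggingMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"dispatch: {func}")
            return func(*args, **kwargs)

    with LoggingMode():
        torch.ones(2) + 1
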
@ -218,24 +212,16 @@ def _get_current_dispatch_mode() -> Optional[TorchDispatchMode]:
def _detect_infra_mode(key):
if key not in (
assert key in [
torch._C._TorchDispatchModeKey.FUNCTIONAL,
torch._C._TorchDispatchModeKey.PROXY,
):
raise AssertionError(
f"key must be either FUNCTIONAL ({torch._C._TorchDispatchModeKey.FUNCTIONAL}) \
or PROXY ({torch._C._TorchDispatchModeKey.PROXY}) _TorchDispatchModeKey, \
got {key}"
)
]
from torch._ops import _get_dispatch_mode_pre_dispatch
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
post_dispatch_mode = torch._C._get_dispatch_mode(key)
if pre_dispatch_mode is not None and post_dispatch_mode is not None:
raise AssertionError(
"At most one of pre_dispatch_mode and post_dispatch_mode may be active"
)
assert (pre_dispatch_mode is None) or (post_dispatch_mode is None)
if pre_dispatch_mode is None:
return post_dispatch_mode
@ -261,13 +247,10 @@ def _unset_infra_mode(key):
def _disable_infra_mode(key):
if key not in (
assert key in (
torch._C._TorchDispatchModeKey.FUNCTIONAL,
torch._C._TorchDispatchModeKey.PROXY,
):
raise AssertionError(
"key must be either FUNCTIONAL or PROXY _TorchDispatchModeKey"
)
)
mode_unset = _unset_infra_mode(key)
try:
yield mode_unset
@ -288,10 +271,7 @@ def _get_current_dispatch_mode_stack() -> list[TorchDispatchMode]:
def _push_mode(mode: TorchDispatchMode):
k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None
if k is not None and k != torch._C.DispatchKey.PreDispatch:
raise AssertionError(
"mode._dispatch_key must be None or DispatchKey.PreDispatch"
)
assert k is None or k == torch._C.DispatchKey.PreDispatch
if k is None:
_push_on_torch_dispatch_stack(mode)
return
@ -434,7 +414,7 @@ class TensorWithFlatten(Protocol):
@overload
def to(
self,
device: Optional[torch._prims_common.DeviceLikeType] = None,
device: Optional["torch._prims_common.DeviceLikeType"] = None,
dtype: Optional[torch.types._dtype] = None,
non_blocking: bool = False,
copy: bool = False,
@ -529,16 +509,14 @@ def transform_subclass(t, callback, outer_size=None, outer_stride=None):
# NB: Purposefully guard here to simplify the inner / outer symbols.
# Using sym_eq() for symbolic comparison can result in an expression that's too
# difficult to guard on, so we use == here.
if sub.shape != outer_size:
raise AssertionError(
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
f"shape equal to {outer_size}, but got: {sub.shape}"
)
if sub.stride() != outer_stride:
raise AssertionError(
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
f"stride equal to {outer_stride}, but got: {sub.stride()}"
)
assert sub.shape == outer_size, (
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
f"shape equal to {outer_size}, but got: {sub.shape}"
)
assert sub.stride() == outer_stride, (
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
f"stride equal to {outer_stride}, but got: {sub.stride()}"
)
return sub
@ -555,12 +533,9 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
It does this by unsafely overwriting the storage field of the output tensor
to be the same storage as the input.
"""
if not isinstance(func, torch._ops.OpOverload):
raise AssertionError(f"func must be an OpOverload, got {type(args)}")
if not isinstance(args, tuple):
raise AssertionError(f"args must be a tuple, got {type(args)}")
if not isinstance(outs, (list, tuple)):
raise AssertionError(f"outs must be a list or tuple, got {type(args)}")
assert isinstance(func, torch._ops.OpOverload)
assert isinstance(args, tuple)
assert isinstance(outs, (list, tuple))
def alias_non_inplace_storage(arg, ret):
# This is hopefully a reasonable assert:
@ -581,11 +556,10 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
):
ret_list = ret if isinstance(ret, list) else [ret]
for r in ret_list:
if type(arg) is not type(r):
raise AssertionError(
f"Called {str(func)} with input of type {type(arg)}\n"
f"and output of type {type(ret)}. But expected types to match."
)
assert type(arg) is type(
r
), f"""Called {str(func)} with input of type {type(arg)}
and output of type {type(ret)}. But expected types to match."""
# Need to call a non-dispatcher helper, because we explicitly do **not**
# want our subclass to intercept the set_() call.
# instead, our subclass should directly have its storage swapped out.
@ -601,8 +575,7 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
for r in ret:
torch._functionalize_unsafe_set(r, arg)
else:
if not isinstance(ret, torch.Tensor):
raise AssertionError(f"expected torch.Tensor, got {type(ret)}")
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
torch._functionalize_unsafe_set(ret, arg)
for arg_idx, schema_arg in enumerate(schema_info.args):
@ -646,10 +619,7 @@ def get_alias_info(func) -> SchemaInfo:
# properly for some ops that output tensorlists)
if func.namespace == "aten":
torchgen_schema_str = str(func._schema)
if not torchgen_schema_str.startswith("aten::"):
raise AssertionError(
"Expected torchgen schema string to start with 'aten::'"
)
assert torchgen_schema_str.startswith("aten::")
# remove the aten:: namespace, which is added by the torchscript parser,
# and torchgen doesn't know how to handle
torchgen_schema_str = torchgen_schema_str[6:]
@ -712,64 +682,6 @@ def get_alias_info(func) -> SchemaInfo:
return schema_info
def autograd_would_have_decomposed(
func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
) -> bool:
"""
Suppose that an operator has CompositeImplicitAutograd decomp registered.
Would autograd have used this decomposition? It will only use it if there
isn't an explicit backend registration for the device as well. This function
will tell if this would have occurred.
Why do we need to apply these decompositions later? When inference mode is
on, the autograd key is bypassed entirely, so a lower level mode cannot rely
on the decomposition having been applied. It's easy to accidentally never apply
the decomposition, resulting in an operator showing up in a graph that
is unexpected.
Why do we need to AVOID applying the decomposition when autograd wouldn't
have decomposed? If autograd doesn't decompose, this means in eager mode
we would have run the fused kernel. It must be possible to trace this
fused kernel directly into the graph for fidelity with eager (NB: a user
has the option of then further decomposing at proxy tensor mode via
decomposition table, but we must preserve it to proxy mode to have the
choice.)
Why does functionalization need to also perform the test here? This is
because some CompositeImplicitAutograd decompositions are not functional.
If we are eventually going to decompose, we need to do this while we can
still turn functionalization back on, so those decompositions get functionalized.
So an early decomposition in functionalization may still be necessary. Note that
if proxy tensor decomposition process could turn functionalization back on, this
wouldn't be necessary, and maybe that is a useful thing to do anyway because
the decomposition table is user specified and a user could violate the functional
decomp requirement with a bad decomp. If this happened, then you could always
pass through functionalization.
"""
has_backend_registration = False
for a in flat_args:
if isinstance(a, torch.Tensor):
backend_key = torch._C._parse_dispatch_key(
torch._C._dispatch_key_for_device(a.device.type)
)
if backend_key is None:
raise AssertionError(
"Failed to infer backend dispatch key from tensor device"
)
# TODO: use func.has_kernel_for_dispatch_key(backend_key)
# but this one checks py_impl and CompositeImplicitAutograd
# incorrectly shows up as has backend reg here
has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), backend_key
)
# in theory we should take all backend keys and take the highest priority one
# to properly mimic the dispatcher,
# this just grabs the first tensor and takes its device key
break
return not has_backend_registration
# See NOTE[SchemaInfo int_tags] above.
_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload]
@ -799,8 +711,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
if not alias_set or not x.is_write:
return None
# torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
if len(alias_set) != 1:
raise AssertionError("Expected alias_set to contain exactly one element")
assert len(alias_set) == 1
# timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
# set of size 1 on Python 3.13.
return next(iter(alias_set))
@ -814,10 +725,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set
]
# For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments.
if len(arg_indices) != 1:
raise AssertionError(
"Expected exactly one argument index for the given output alias"
)
assert len(arg_indices) == 1
idx = arg_indices[0]
arg_info = schema_info.args[idx]
if arg_info.name is not None and arg_info.name in new_kwargs:
@ -843,10 +751,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
]
# Assumption: we have a very small number of inplace_view ops that follow a strict schema:
# there is only a single argument that gets its metadata mutated.
if len(mutated_args) != 1:
raise AssertionError(
"expected exactly one mutated arg for inplace_view ops"
)
assert len(mutated_args) == 1
# This check exists because we generally *do* want to update the metadata of any wrapper subclasses,
# but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor.
# so we don't actually need to update the metadata (and attempting to do so causes errors)
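
For context on where these alias checks run, a heavily abridged wrapper-subclass sketch; a real wrapper subclass also needs __new__, __tensor_flatten__/__tensor_unflatten__, and friends:

    import torch
    from torch.utils._python_dispatch import return_and_correct_aliasing

    class Wrapper(torch.Tensor):
        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            out = func(*args, **kwargs)
            # fixes up output storage aliasing / inplace-view metadata per the schema
            return return_and_correct_aliasing(func, args, kwargs, out)
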


@ -476,8 +476,7 @@ def _is_constant_holder(spec: "TreeSpec") -> bool:
def _retrieve_constant(spec: "TreeSpec") -> Any:
"""Given a spec from a pytree registered with register_constant, retrieves the constant"""
if not _is_constant_holder(spec):
raise AssertionError("spec does not correspond to a registered constant pytree")
assert _is_constant_holder(spec)
return tree_unflatten([], spec)
@ -900,25 +899,17 @@ def _defaultdict_serialize(context: Context) -> DumpableContext:
def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
if not isinstance(dumpable_context, dict):
raise AssertionError("dumpable_context must be a dict")
expected_keys = {
assert isinstance(dumpable_context, dict)
assert set(dumpable_context) == {
"default_factory_module",
"default_factory_name",
"dict_context",
}
if set(dumpable_context) != expected_keys:
raise AssertionError(
f"dumpable_context keys must be {expected_keys}, got {set(dumpable_context)}"
)
default_factory_module = dumpable_context["default_factory_module"]
default_factory_name = dumpable_context["default_factory_name"]
if not isinstance(default_factory_module, str):
raise AssertionError("default_factory_module must be a string")
if not isinstance(default_factory_name, str):
raise AssertionError("default_factory_name must be a string")
assert isinstance(default_factory_module, str)
assert isinstance(default_factory_name, str)
module = importlib.import_module(default_factory_module)
default_factory = getattr(module, default_factory_name)
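
A sketch of the payload this deserializer expects, for a defaultdict(list) (the dict_context value is assumed to be the key list produced by the dict flatten spec):

    dumpable_context = {
        "default_factory_module": "builtins",   # list.__module__
        "default_factory_name": "list",         # the factory's name
        "dict_context": ["a", "b"],             # keys of the flattened dict
    }
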
@ -1742,8 +1733,7 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
if not isinstance(treespec, TreeSpec):
raise AssertionError("treespec must be a TreeSpec")
assert isinstance(treespec, TreeSpec)
if tree_is_leaf(tree, is_leaf=is_leaf):
return [tree] * treespec.num_leaves


@ -206,8 +206,7 @@ class CapturedTraceback:
import torch._C._profiler
if script or cpp:
if skip != 0:
raise AssertionError("skip with script/cpp NYI")
assert skip == 0, "skip with script/cpp NYI"
return CapturedTraceback(
torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),


@ -430,8 +430,9 @@ def _get_custom_mod_func(func_name: str):
it is marked as private. It is a convenience function for backend implementers to
more easily call the hooks into their backend extensions.
"""
if not isinstance(func_name, str):
raise AssertionError(f"func_name must be `str`, but got `{type(func_name)}`.")
assert isinstance(func_name, str), (
f"func_name must be `str`, but got `{type(func_name)}`."
)
backend_name = _get_privateuse1_backend_name()
custom_device_mod = getattr(torch, backend_name, None)
function = getattr(custom_device_mod, func_name, None)


@ -119,12 +119,10 @@ def bundle_inputs(
# Fortunately there is a function in _recursive that does exactly that conversion.
cloned_module = wrap_cpp_module(clone)
if isinstance(inputs, dict):
if not isinstance(info, dict) and info is not None:
raise AssertionError("If inputs is a dict, info must be a dict or None")
assert isinstance(info, dict) or info is None
augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
else:
if not isinstance(info, list) and info is not None:
raise AssertionError("If inputs is a list, info must be a list or None")
assert isinstance(info, list) or info is None
augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
return cloned_module


@ -1034,10 +1034,8 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
out += f"{line['filename']}:{line['line']}:{line['name']}\n"
out += "\n\n"
return out
if capture_logs_fwd.logs is None:
raise AssertionError("capture_logs_fwd.logs is None")
if capture_logs_recompute.logs is None:
raise AssertionError("capture_logs_recompute.logs is None")
assert capture_logs_fwd.logs is not None
assert capture_logs_recompute.logs is not None
raise CheckpointError(
_checkpoint_error_template.format(
forward_traces=get_str_tb("original", capture_logs_fwd),
@ -1075,14 +1073,12 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
def pack_hook(x):
x = x.detach() if x.requires_grad else x
target_frame = target_frame_ref()
if target_frame is None:
raise AssertionError("Internal error: target_frame reference is None")
assert target_frame is not None # appease mypy
recomp_idx = target_frame.recomp_counter[gid]
target_frame.recomp_counter[gid] += 1
if recomp_idx >= len(target_frame.weak_holders):
if target_frame.early_stop:
raise AssertionError("Unexpected state: target_frame.early_stop is set")
assert not target_frame.early_stop
if not target_frame.forward_completed:
# We run into this case when early stop is not enabled and do
# grad within checkpoint.
@ -1519,14 +1515,12 @@ def _checkpoint_without_reentrant_generator(
device_module = _get_device_module(device_type)
forward_context, recompute_context = context_fn()
if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
if (
not isinstance(forward_context, TorchDispatchMode)
or not isinstance(recompute_context, TorchDispatchMode)
):
raise AssertionError(
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` "
"must generate a tuple of two `TorchDispatchMode`s."
)
assert (
isinstance(forward_context, TorchDispatchMode) and
isinstance(recompute_context, TorchDispatchMode)
), \
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
"must generate a tuple of two `TorchDispatchMode`s."
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)
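
For context, a hedged sketch of the context_fn contract checked above; in eager mode plain context managers are fine, while under torch.compile both returned objects must be TorchDispatchModes, per the message:

    import contextlib
    import torch
    from torch.utils.checkpoint import checkpoint

    def context_fn():
        # returns (forward context, recompute context)
        return contextlib.nullcontext(), contextlib.nullcontext()

    def fn(x):
        return torch.relu(x).sum()

    x = torch.randn(8, requires_grad=True)
    out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
    out.backward()
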


@ -290,8 +290,7 @@ def _get_icpx_version() -> str:
match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip())
version = ['0', '0', '0'] if match is None else list(match.groups())
version = list(map(int, version))
if len(version) != 3:
raise AssertionError("Failed to parse DPC++ compiler version")
assert len(version) == 3, "Failed to parse DPC++ compiler version"
# Aligning version format with what torch.version.xpu() returns
return f"{version[0]}{version[1]:02}{version[2]:02}"
@ -325,8 +324,7 @@ def _get_sycl_device_flags(cflags):
# We need last occurrence of -fsycl-targets as it will be the one taking effect.
# So searching in reversed list.
flags = [f for f in reversed(cflags) if f.startswith('-fsycl-targets=')]
if not flags:
raise AssertionError("bug: -fsycl-targets should have been amended to cflags")
assert flags, "bug: -fsycl-targets should have been amended to cflags"
arch_list = _get_sycl_arch_list()
if arch_list != '':
@ -664,8 +662,7 @@ class BuildExtension(build_ext):
extension = next(extension_iter, None)
if sycl_ext:
if not self.use_ninja:
raise AssertionError("ninja is required to build sycl extensions.")
assert self.use_ninja, "ninja is required to build sycl extensions."
if cuda_ext and not IS_HIP_EXTENSION:
_check_cuda_version(compiler_name, compiler_version)
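
A minimal setup.py sketch around the ninja requirement above (names and sources are placeholders; SYCL and dlink builds hit the same use_ninja check):

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    setup(
        name="my_ext",
        ext_modules=[CppExtension("my_ext", ["my_ext.cpp"])],
        # ninja is mandatory for sycl extensions (and for dlink=True cuda builds)
        cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
    )
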
@ -697,10 +694,7 @@ class BuildExtension(build_ext):
self._define_torch_extension_name(extension)
if 'nvcc_dlink' in extension.extra_compile_args:
if not self.use_ninja:
raise AssertionError(
f"With dlink=True, ninja is required to build cuda extension {extension.name}."
)
assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
# Register .cu, .cuh, .hip, .mm and .sycl as valid source extensions.
# NOTE: At the moment .sycl is not a standard extension for SYCL supported
@ -2659,11 +2653,9 @@ def _import_module_from_library(module_name, path, is_python_module):
if is_python_module:
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
spec = importlib.util.spec_from_file_location(module_name, filepath)
if spec is None:
raise AssertionError(f"Failed to create spec for module {module_name} at {filepath}")
assert spec is not None
module = importlib.util.module_from_spec(spec)
if not isinstance(spec.loader, importlib.abc.Loader):
raise AssertionError("spec.loader is not a valid importlib Loader")
assert isinstance(spec.loader, importlib.abc.Loader)
spec.loader.exec_module(module)
return module
else:
@ -2865,10 +2857,8 @@ e.
ldflags = sanitize_flags(ldflags)
# Sanity checks...
if len(sources) != len(objects):
raise AssertionError("sources and objects lists must be the same length")
if len(sources) == 0:
raise AssertionError("At least one source is required to build a library")
assert len(sources) == len(objects)
assert len(sources) > 0
compiler = get_cxx_compiler()


@ -133,8 +133,9 @@ def from_dlpack(
if device is not None:
if isinstance(device, str):
device = torch.device(device)
if not isinstance(device, torch.device):
raise AssertionError(f"from_dlpack: unsupported device type: {type(device)}")
assert isinstance(device, torch.device), (
f"from_dlpack: unsupported device type: {type(device)}"
)
kwargs["dl_device"] = torch._C._torchDeviceToDLDevice(device)
ext_device = ext_tensor.__dlpack_device__()
@ -162,10 +163,10 @@ def from_dlpack(
dlpack = ext_tensor.__dlpack__(**kwargs)
else:
if device is not None or copy is not None:
raise AssertionError(
"device and copy kwargs not supported when ext_tensor is already a DLPack capsule."
)
assert device is None and copy is None, (
"device and copy kwargs not supported when ext_tensor is "
"already a DLPack capsule."
)
# Old versions just call the converter
dlpack = ext_tensor
return torch._C._from_dlpack(dlpack)
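
A basic round trip through the DLPack path above (CPU tensor; the legacy capsule branch is the one that rejects device/copy):

    import torch
    from torch.utils.dlpack import from_dlpack, to_dlpack

    src = torch.arange(4)
    capsule = to_dlpack(src)      # legacy capsule: no device/copy kwargs allowed
    dst = from_dlpack(capsule)
    assert dst.data_ptr() == src.data_ptr()   # zero-copy: same storage
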


@ -62,8 +62,7 @@ def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
# Inputs contains the shapes of two matrices.
m, k = a_shape
k2, n = b_shape
if k != k2:
raise AssertionError(f"matmul: inner dimensions must match (k == k2), got {k} and {k2}")
assert k == k2
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
return m * n * 2 * k
@ -79,10 +78,8 @@ def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
# Inputs contains the shapes of two tensor.
b, m, k = a_shape
b2, k2, n = b_shape
if b != b2:
raise AssertionError(f"bmm: batch dimensions must match (b == b2), got {b} and {b2}")
if k != k2:
raise AssertionError(f"bmm: inner dimensions must match (k == k2), got {k} and {k2}")
assert b == b2
assert k == k2
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
flop = b * m * n * 2 * k
return flop
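
A worked instance of the counts above (shapes are arbitrary):

    # mm: (m, k) @ (k, n) performs m*n dot products of length k  ->  ~2*m*k*n FLOPs
    m, k, n = 128, 256, 64
    print(2 * m * k * n)       # 4194304

    # bmm repeats the same count over a batch dimension b
    b = 8
    print(b * 2 * m * k * n)   # 33554432
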
@ -269,8 +266,7 @@ def sdpa_flop_count(query_shape, key_shape, value_shape):
b, h, s_q, d_q = query_shape
_b2, _h2, s_k, _d2 = key_shape
_b3, _h3, _s3, d_v = value_shape
if not b == _b2 == _b3 or not h == _h2 == _h3 or not d_q == _d2 or not s_k == _s3 or not d_q == _d2:
raise AssertionError("sdpa_flop_count: query/key/value shapes are incompatible")
assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2
total_flops = 0
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
@ -324,21 +320,15 @@ def _unpack_flash_attention_nested_shapes(
# In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
# To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
# So the flops calculation in this case is an overestimate of the actual flops.
if len(key.shape) != 3:
raise AssertionError("sdpa_flop_count: expected key.shape to be 3-dimensional")
if len(value.shape) != 3:
raise AssertionError("sdpa_flop_count: expected value.shape to be 3-dimensional")
if grad_out is not None and grad_out.shape != query.shape:
raise AssertionError("sdpa_flop_count: grad_out.shape must match query.shape when provided")
assert len(key.shape) == 3
assert len(value.shape) == 3
assert grad_out is None or grad_out.shape == query.shape
_, h_q, d_q = query.shape
_, h_k, d_k = key.shape
_, h_v, d_v = value.shape
if cum_seq_q is None:
raise AssertionError("sdpa_flop_count: cum_seq_q must not be None")
if cum_seq_k is None:
raise AssertionError("sdpa_flop_count: cum_seq_k must not be None")
if cum_seq_q.shape != cum_seq_k.shape:
raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape")
assert cum_seq_q is not None
assert cum_seq_k is not None
assert cum_seq_q.shape == cum_seq_k.shape
seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
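
For intuition on the cumulative-sequence tensors above, a pure-Python sketch of the offsets-to-lengths conversion they are assumed to encode:

    cum_seq_q = [0, 3, 7, 12]    # offsets of 3 packed (nested) sequences
    lengths = [b - a for a, b in zip(cum_seq_q, cum_seq_q[1:])]
    print(lengths)               # [3, 4, 5]
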
@ -378,22 +368,15 @@ def _unpack_efficient_attention_nested_shapes(
# In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
# To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
# So the flops calculation in this case is an overestimate of the actual flops.
if len(key.shape) != 4:
raise AssertionError("_unpack_efficient_attention_nested_shapes: expected key.shape to be 4-dimensional")
if len(value.shape) != 4:
raise AssertionError("_unpack_efficient_attention_nested_shapes: expected value.shape to be 4-dimensional")
if grad_out is not None and grad_out.shape != query.shape:
raise AssertionError("_unpack_efficient_attention_nested_shapes: grad_out.shape must match query.shape when provided")
assert len(key.shape) == 4
assert len(value.shape) == 4
assert grad_out is None or grad_out.shape == query.shape
_, _, h_q, d_q = query.shape
_, _, h_k, d_k = key.shape
_, _, h_v, d_v = value.shape
if cu_seqlens_q is None:
raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_q must not be None")
if cu_seqlens_k is None:
raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_k must not be None")
if cu_seqlens_q.shape != cu_seqlens_k.shape:
raise AssertionError("_unpack_efficient_attention_nested_shapes: "
"cu_seqlens_q and cu_seqlens_k must have the same shape")
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None
assert cu_seqlens_q.shape == cu_seqlens_k.shape
seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
for len_q, len_k in zip(seqlens_q, seqlens_k):
@ -477,10 +460,8 @@ def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape
_b2, _h2, s_k, _d2 = key_shape
_b3, _h3, _s3, d_v = value_shape
_b4, _h4, _s4, _d4 = grad_out_shape
if not b == _b2 == _b3 == _b4 or not h == _h2 == _h3 == _h4 or not d_q == _d2:
raise AssertionError("sdpa_backward_flop_count: batch/heads/dimension mismatch among tensors")
if not d_v == _d4 or not s_k == _s3 or not s_q == _s4:
raise AssertionError("sdpa_backward_flop_count: grad_out/value/key/query shapes are incompatible")
assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
assert d_v == _d4 and s_k == _s3 and s_q == _s4
total_flops = 0
# Step 1: We recompute the scores matrix.
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
@ -761,8 +742,7 @@ class FlopCounterMode:
return self
def __exit__(self, *args):
if self.mode is None:
raise AssertionError("Internal error: FlopCounter.__exit__ called but mode is None")
assert self.mode is not None
b = self.mode.__exit__(*args)
self.mode = None # break cycles
self.mod_tracker.__exit__()
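
For reference, a typical use of FlopCounterMode that exercises the __enter__/__exit__ pair above (module and sizes are illustrative):

    import torch
    from torch.utils.flop_counter import FlopCounterMode

    lin = torch.nn.Linear(256, 128)
    x = torch.randn(32, 256)

    with FlopCounterMode(display=False) as flop_counter:
        lin(x)

    print(flop_counter.get_total_flops())   # expected: 2 * 32 * 256 * 128 = 2097152 for the matmul
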


@ -238,8 +238,7 @@ class BackwardHook:
self.grad_outputs = None
if local_grad_outputs is not None:
if self.output_tensors_index is None:
raise AssertionError("output_tensors_index should not be None when grad_outputs is not None")
assert self.output_tensors_index is not None # mypy
return tuple(local_grad_outputs[i] for i in self.output_tensors_index)
grad_fn.register_hook(hook)


@ -137,12 +137,9 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
def __init__(self, dense_module):
super().__init__()
if dense_module.training:
raise AssertionError("Only support eval mode batchnorm for mkldnn path now")
if not dense_module.track_running_stats:
raise AssertionError("Only support track_running_stats=True for mkldnn path now")
if not dense_module.affine:
raise AssertionError("Only support affine=True for mkldnn path now")
assert not dense_module.training
assert dense_module.track_running_stats
assert dense_module.affine
if dense_module.momentum is None:
self.exponential_average_factor = 0.0
@ -207,9 +204,8 @@ class MkldnnPrelu(torch.jit.ScriptModule):
return y
def to_mkldnn(module, dtype=torch.float):
if dtype not in (torch.float, torch.bfloat16, torch.half):
raise AssertionError("MKLDNN only support float, bfloat16, and half path now")
assert dtype in [torch.float, torch.bfloat16, torch.half], \
"MKLDNN only support float, bfloat16, and half path now"
def m_fn(m, d):
if isinstance(m, torch.nn.Linear):


@ -5,8 +5,7 @@ import torch._C
def format_time(time_us=None, time_ms=None, time_s=None):
"""Define time formatting."""
if time_us is not None or time_ms is not None or time_s is not None:
raise AssertionError("Expected at least one of time_us, time_ms, time_s is not None.")
assert sum([time_us is not None, time_ms is not None, time_s is not None]) == 1
US_IN_SECOND = 1e6
US_IN_MS = 1e3


@ -33,12 +33,12 @@ function version_space() {
};
}
function Segment(addr, size, stream, frames, version) {
return {addr, size, stream, version, frames};
function Segment(addr, size, stream, frames, version, user_metadata) {
return {addr, size, stream, version, frames, user_metadata};
}
function Block(addr, size, requested_size, frames, free_requested, version) {
return {addr, size, requested_size, frames, free_requested, version};
function Block(addr, size, requested_size, frames, free_requested, version, user_metadata) {
return {addr, size, requested_size, frames, free_requested, version, user_metadata};
}
function EventSelector(outer, events, stack_info, memory_view) {
@ -140,7 +140,9 @@ function eventStack(e, allocated, reserved) {
reserved,
)} reserved)\n${event}`;
}
return event + '\n' + format_frames(e.frames);
const user_metadata_str = format_user_metadata(e.user_metadata);
const frames_str = format_frames(e.frames);
return event + '\n' + (user_metadata_str ? user_metadata_str + '\n' : '') + frames_str;
}
function hashCode(num) {
@ -216,6 +218,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
seg.stream,
seg.frames || [],
seg.version,
seg.user_metadata,
),
);
for (const b of seg.blocks) {
@ -229,6 +232,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
b.frames,
b.state === 'active_pending_free',
b.version,
b.user_metadata,
);
}
}
@ -307,6 +311,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
event.frames,
false,
event.version,
event.user_metadata,
);
break;
case 'free_requested':
@ -320,6 +325,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
event.frames,
true,
event.version,
event.user_metadata,
);
break;
case 'alloc':
@ -335,6 +341,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
event.stream,
event.frames,
event.version,
event.user_metadata,
),
);
break;
@ -348,6 +355,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
event.stream,
event.frames,
event.version,
event.user_metadata,
),
);
break;
@ -426,13 +434,17 @@ function MemoryView(outer, stack_info, snapshot, device) {
if (t.internal_free > 0) {
internal = ` (${(t.internal_free / free) * 100}% internal)`;
}
const user_metadata_str = format_user_metadata(t.user_metadata);
const frames_str = format_frames(t.frames);
return (
`s${t.addr.toString(16)}_${t.version}: segment ${formatSize(
t.size,
)} allocated, ` +
`${formatSize(free)} free${internal} (stream ${
t.stream
})\n${format_frames(t.frames)}`
})\n` +
(user_metadata_str ? user_metadata_str + '\n' : '') +
frames_str
);
},
d => {
@ -493,12 +505,15 @@ function MemoryView(outer, stack_info, snapshot, device) {
if (t.free_requested) {
requested = ' (block freed but waiting due to record_stream)';
}
const user_metadata_str = format_user_metadata(t.user_metadata);
const frames_str = format_frames(t.frames);
return (
`b${t.addr.toString(16)}_${t.version} ` +
`${formatSize(t.requested_size)} allocation${requested} (stream ${
t.segment.stream
})\n` +
format_frames(t.frames)
(user_metadata_str ? user_metadata_str + '\n' : '') +
frames_str
);
},
removeStroke,
@ -524,12 +539,15 @@ function MemoryView(outer, stack_info, snapshot, device) {
d => {
addStroke(d);
const t = d.datum();
const user_metadata_str = format_user_metadata(t.user_metadata);
const frames_str = format_frames(t.frames);
return (
`Free space lost due to rounding ${formatSize(
t.size - t.requested_size,
)}` +
` (stream ${t.segment.stream})\n` +
format_frames(t.frames)
(user_metadata_str ? user_metadata_str + '\n' : '') +
frames_str
);
},
removeStroke,
@ -760,6 +778,23 @@ function frameFilter({name, filename}) {
return true;
}
function format_user_metadata(user_metadata) {
if (!user_metadata) {
return '';
}
// Handle string metadata
if (typeof user_metadata === 'string') {
return `User Metadata:\n ${user_metadata}`;
}
// Handle object metadata
if (typeof user_metadata === 'object' && Object.keys(user_metadata).length === 0) {
return '';
}
const metadata_lines = Object.entries(user_metadata)
.map(([key, value]) => ` ${key}: ${value}`);
return 'User Metadata:\n' + metadata_lines.join('\n');
}
function format_frames(frames) {
if (frames.length === 0) {
return (
@ -992,6 +1027,10 @@ function process_alloc_data(snapshot, device, plot_segments, max_entries) {
if (!elem.action.includes('alloc')) {
text = `${text}\nalloc not recorded, stack trace for free:`;
}
const user_metadata_str = format_user_metadata(elem.user_metadata);
if (user_metadata_str) {
text = `${text}\n${user_metadata_str}`;
}
text = `${text}\n${format_frames(elem.frames)}`;
return text;
},


@ -351,16 +351,14 @@ class TensorWeakRef:
ref: WeakRef[Tensor]
def __init__(self, tensor: Tensor):
if not isinstance(tensor, Tensor):
raise AssertionError(f"expected torch.Tensor, got {type(tensor)}.")
assert isinstance(tensor, Tensor)
self.ref = weakref.ref(tensor)
def __call__(self):
out = self.ref()
if out is None:
return out
if not isinstance(out, Tensor):
raise AssertionError(f"expected torch.Tensor, got {type(out)}.")
assert isinstance(out, Tensor)
# TODO, add _fix_weakref type binding
out._fix_weakref() # type: ignore[attr-defined]
return out


@ -1024,22 +1024,8 @@ def gen_functionalization_registration(
) -> list[str]:
@with_native_function
def emit_registration_helper(f: NativeFunction) -> str:
if f.has_composite_implicit_autograd_kernel:
metadata = composite_implicit_autograd_index.get_kernel(f)
assert metadata is not None
native_api_name = metadata.kernel
sig = NativeSignature(f.func, symint=metadata.supports_symint())
# Note [Composite view ops in the functionalization pass]
# We don't need to worry about implementing functionalization kernels for views with
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
# We can't just opt the entire Functionalization dispatch key into the composite keyset though,
# because we don't want to decompose non-view ops that are composite, like `at::ones`.
registration_str = (
f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
)
else:
# non-composite view ops (and inplace ops) get a normal registration.
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
assert not f.has_composite_implicit_autograd_kernel
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
return f'm.impl("{f.func.name}", {registration_str});'
# Don't generate kernels in mobile build
@ -1052,8 +1038,12 @@ def gen_functionalization_registration(
if str(g.view.func.name) == "lift_fresh":
return []
view_str = []
view_str.append(emit_registration_helper(g.view))
if g.view_inplace is not None:
if not g.view.has_composite_implicit_autograd_kernel:
view_str.append(emit_registration_helper(g.view))
if (
g.view_inplace is not None
and not g.view_inplace.has_composite_implicit_autograd_kernel
):
assert g.view_inplace.is_view_op
view_str.append(emit_registration_helper(g.view_inplace))
return view_str