mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-23 14:59:34 +08:00

Compare commits: viable/str ... cpp-docs-d
22 Commits
Commits (SHA1):
5b6cc8215f
1c43c9cfd0
102e0d5437
0bd12c1168
ce8a7764e2
d1269a0434
c87cf1be32
2fc5e45a41
f9022ba93b
ff8be889ad
292454942e
6c4412f72b
78bf6186f2
c40048472c
3dfd0c7584
e6ba4d0725
bdf7cb9d9c
6aed378958
8b3dc0d1b0
06773663b5
0bff65503c
21131a2444
@@ -1,15 +1,11 @@
-sphinx==5.3.0
+sphinx==7.2.6
 #Description: This is used to generate PyTorch docs
-#Pinned versions: 5.3.0
+#Pinned versions: 7.2.6
 
-standard-imghdr==3.13.0; python_version >= "3.13"
-#Description: This is needed by Sphinx, so it needs to be added here.
-# The reasons are as follows:
-# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
-# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
-# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
+pytorch_sphinx_theme2==0.1.0
+#Description: This is needed to generate PyTorch docs
+#Pinned versions: 0.1.0
 
--e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
 # but it doesn't seem to work and hangs around idly. The initial thought that it is probably
 # something related to Docker setup. We can investigate this later.
@@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
 #Description: This is used to generate PyTorch docs
 #Pinned versions: 2.13.0
 
-breathe==4.34.0
+breathe==4.36.0
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 4.34.0
+#Pinned versions: 4.36.0
 
-exhale==0.2.3
+exhale==0.3.7
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 0.2.3
+#Pinned versions: 0.3.7
 
-docutils==0.16
+docutils==0.20
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 0.16
+#Pinned versions: 0.20
 
 bs4==0.0.1
 #Description: This is used to generate PyTorch C++ docs
@@ -56,13 +52,13 @@ IPython==8.12.0
 #Description: This is used to generate PyTorch functorch docs
 #Pinned versions: 8.12.0
 
-myst-nb==0.17.2
+myst-nb==1.3.0
 #Description: This is used to generate PyTorch functorch and torch.compile docs.
-#Pinned versions: 0.17.2
+#Pinned versions: 1.3.0
 
 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
 python-etcd==0.4.5
 sphinx-copybutton==0.5.0
-sphinx-design==0.4.0
+sphinx-design==0.6.1
 sphinxcontrib-mermaid==1.0.0
-myst-parser==0.18.1
+myst-parser==4.0.1
@@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
 echo coverage output not found
 exit 1
 elif [ $undocumented -gt 0 ]; then
-echo undocumented objects found:
+echo "======================================"
+echo "ERROR: $undocumented undocumented objects found!"
+echo "======================================"
+echo ""
+echo "Full coverage report:"
 cat build/coverage/python.txt
+echo ""
+echo "======================================"
+echo "Undocumented modules/objects (lines after TOTAL):"
+tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
+echo "======================================"
+echo ""
 echo "Make sure you've updated relevant .rsts in docs/source!"
 echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
 exit 1
@@ -163,8 +163,13 @@ if [[ "$(uname)" != Darwin ]]; then
 MEMORY_LIMIT_MAX_JOBS=12
 NUM_CPUS=$(( $(nproc) - 2 ))
 
-# Defaults here for **binary** linux builds so they can be changed in one place
-export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
+if [[ "$(uname)" == Linux ]]; then
+# Defaults here for **binary** linux builds so they can be changed in one place
+export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
+else
+# For other builds
+export MAX_JOBS=${NUM_CPUS}
+fi
 
 cat >>"$envfile" <<EOL
 export MAX_JOBS="${MAX_JOBS}"
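For reference, the ${MAX_JOBS:-...} default combined with the arithmetic ternary above implements min(NUM_CPUS, MEMORY_LIMIT_MAX_JOBS). A standalone sketch of the same idiom with assumed values, separate from the CI script:

# Hedged sketch: clamp build parallelism to a memory-based ceiling.
MEMORY_LIMIT_MAX_JOBS=12
NUM_CPUS=$(( $(nproc) - 2 ))   # e.g. 30 on an assumed 32-core machine
export MAX_JOBS=${MAX_JOBS:-$(( NUM_CPUS > MEMORY_LIMIT_MAX_JOBS ? MEMORY_LIMIT_MAX_JOBS : NUM_CPUS ))}
echo "MAX_JOBS=${MAX_JOBS}"    # prints 12 whenever NUM_CPUS exceeds the limit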
.github/actionlint.yaml (vendored, 11 changed lines)
@@ -54,17 +54,12 @@ self-hosted-runner:
 - windows-11-arm64
 - windows-11-arm64-preview
 # Organization-wide AMD-hosted runners
-# MI2xx non-ARC runners
+# MI2xx runners
-- linux.rocm.gpu
-- linux.rocm.gpu.mi250
 - linux.rocm.gpu.2
 - linux.rocm.gpu.4
+- linux.rocm.gpu.mi250
 - linux.rocm.gpu.gfx1100
-# MI2xx ARC runners
-- linux.rocm.gpu.mi250.1
-- linux.rocm.gpu.mi250.2
-- linux.rocm.gpu.mi250.4
-# gfx942 ARC runners
+# gfx942 runners
 - linux.rocm.gpu.gfx942.1
 - linux.rocm.gpu.gfx942.2
 - linux.rocm.gpu.gfx942.4
@@ -79,9 +79,9 @@ jobs:
 runs-on: "windows-11-arm64-preview"
 {%- else %}
 {%- if branches == "nightly" %}
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 {%- else %}
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
 {%- endif %}
 {%- endif %}
 timeout-minutes: !{{ common.timeout_minutes_windows_binary }}
.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml (generated, vendored, 8 changed lines)
@@ -44,7 +44,7 @@ jobs:
 libtorch-cpu-shared-with-deps-debug-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -291,7 +291,7 @@ jobs:
 libtorch-cuda12_6-shared-with-deps-debug-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -541,7 +541,7 @@ jobs:
 libtorch-cuda12_8-shared-with-deps-debug-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -791,7 +791,7 @@ jobs:
 libtorch-cuda13_0-shared-with-deps-debug-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
.github/workflows/generated-windows-binary-libtorch-release-nightly.yml (generated, vendored, 8 changed lines)
@@ -44,7 +44,7 @@ jobs:
 libtorch-cpu-shared-with-deps-release-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -291,7 +291,7 @@ jobs:
 libtorch-cuda12_6-shared-with-deps-release-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -541,7 +541,7 @@ jobs:
 libtorch-cuda12_8-shared-with-deps-release-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -791,7 +791,7 @@ jobs:
 libtorch-cuda13_0-shared-with-deps-release-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
.github/workflows/generated-windows-binary-wheel-nightly.yml (generated, vendored, 70 changed lines)
@@ -44,7 +44,7 @@ jobs:
 wheel-py3_10-cpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -279,7 +279,7 @@ jobs:
 wheel-py3_10-cuda12_6-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -517,7 +517,7 @@ jobs:
 wheel-py3_10-cuda12_8-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -755,7 +755,7 @@ jobs:
 wheel-py3_10-cuda13_0-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -993,7 +993,7 @@ jobs:
 wheel-py3_10-xpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -1229,7 +1229,7 @@ jobs:
 wheel-py3_11-cpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -1464,7 +1464,7 @@ jobs:
 wheel-py3_11-cuda12_6-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -1702,7 +1702,7 @@ jobs:
 wheel-py3_11-cuda12_8-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -1940,7 +1940,7 @@ jobs:
 wheel-py3_11-cuda13_0-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -2178,7 +2178,7 @@ jobs:
 wheel-py3_11-xpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -2414,7 +2414,7 @@ jobs:
 wheel-py3_12-cpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -2649,7 +2649,7 @@ jobs:
 wheel-py3_12-cuda12_6-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -2887,7 +2887,7 @@ jobs:
 wheel-py3_12-cuda12_8-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -3125,7 +3125,7 @@ jobs:
 wheel-py3_12-cuda13_0-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -3363,7 +3363,7 @@ jobs:
 wheel-py3_12-xpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -3599,7 +3599,7 @@ jobs:
 wheel-py3_13-cpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -3834,7 +3834,7 @@ jobs:
 wheel-py3_13-cuda12_6-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -4072,7 +4072,7 @@ jobs:
 wheel-py3_13-cuda12_8-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -4310,7 +4310,7 @@ jobs:
 wheel-py3_13-cuda13_0-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -4548,7 +4548,7 @@ jobs:
 wheel-py3_13-xpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -4784,7 +4784,7 @@ jobs:
 wheel-py3_13t-cpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -5019,7 +5019,7 @@ jobs:
 wheel-py3_13t-cuda12_6-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -5257,7 +5257,7 @@ jobs:
 wheel-py3_13t-cuda12_8-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -5495,7 +5495,7 @@ jobs:
 wheel-py3_13t-cuda13_0-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -5733,7 +5733,7 @@ jobs:
 wheel-py3_13t-xpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -5969,7 +5969,7 @@ jobs:
 wheel-py3_14-cpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -6204,7 +6204,7 @@ jobs:
 wheel-py3_14-cuda12_6-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -6442,7 +6442,7 @@ jobs:
 wheel-py3_14-cuda12_8-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -6680,7 +6680,7 @@ jobs:
 wheel-py3_14-cuda13_0-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -6918,7 +6918,7 @@ jobs:
 wheel-py3_14-xpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -7154,7 +7154,7 @@ jobs:
 wheel-py3_14t-cpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -7389,7 +7389,7 @@ jobs:
 wheel-py3_14t-cuda12_6-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -7627,7 +7627,7 @@ jobs:
 wheel-py3_14t-cuda12_8-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -7865,7 +7865,7 @@ jobs:
 wheel-py3_14t-cuda13_0-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -8103,7 +8103,7 @@ jobs:
 wheel-py3_14t-xpu-build:
 if: ${{ github.repository_owner == 'pytorch' }}
 needs: get-label-type
-runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
+runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 timeout-minutes: 360
 env:
 PYTORCH_ROOT: ${{ github.workspace }}/pytorch
.github/workflows/rocm.yml (vendored, 12 changed lines)
@@ -36,12 +36,12 @@ jobs:
 sync-tag: rocm-build
 test-matrix: |
 { include: [
-{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
-{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
+{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
+{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
+{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
+{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
+{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
+{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
 ]}
 secrets: inherit
@@ -39,7 +39,7 @@ struct HostBlock {
 };
 
 template <typename B>
-struct alignas(64) FreeBlockList {
+struct alignas(hardware_destructive_interference_size) FreeBlockList {
 std::mutex mutex_;
 std::deque<B*> list_;
 };
@@ -122,7 +122,7 @@ struct TORCH_API HostStats {
 // Struct containing memory allocator summary statistics for host, as they
 // are staged for reporting. This is a temporary struct that is used to
 // avoid locking the allocator while collecting stats.
-struct alignas(64) HostStatsStaged {
+struct alignas(hardware_destructive_interference_size) HostStatsStaged {
 std::mutex timing_mutex_;
 // COUNT: total allocations (active + free)
 // LOCK: access to this stat is protected by the allocator's blocks_mutex_
@@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
 TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
 }
 
-alignas(64) std::mutex blocks_mutex_;
+alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
 ska::flat_hash_set<B*> blocks_; // block list
 ska::flat_hash_map<void*, B*> ptr_to_block_;
 
@@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
 // size. This allows us to quickly find a free block of the right size.
 // We use deque to store per size free list and guard the list with its own
 // mutex.
-alignas(64) std::vector<FreeBlockList<B>> free_list_ =
+alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
 std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
 
-alignas(64) std::mutex events_mutex_;
+alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
 std::deque<std::pair<E, B*>> events_; // event queue paired with block
 
 // Indicates whether the object is active.
 // Set to false in the destructor to signal background threads to stop.
 std::atomic<bool> active_{true};
 protected:
-alignas(64) HostStatsStaged stats_;
+alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
 };
 
 struct TORCH_API HostAllocator : public at::Allocator {
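The comments in this hunk describe one free list per size class, each guarded by its own mutex and padded to its own cache line so that threads touching different buckets neither contend on a lock nor false-share a line. A minimal standalone sketch of that layout, with illustrative names and no relation to the actual PyTorch types beyond the idea:

#include <cstddef>
#include <deque>
#include <mutex>
#include <new>
#include <vector>

#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;  // common cache-line size
#endif

struct Block;  // opaque block type for the sketch

// One free list per size class; alignas keeps each bucket (and its mutex)
// on its own cache line so threads using different buckets do not false-share.
struct alignas(hardware_destructive_interference_size) SizeBucket {
  std::mutex mutex_;
  std::deque<Block*> free_;
};

class BucketedFreeLists {
 public:
  explicit BucketedFreeLists(std::size_t num_buckets) : buckets_(num_buckets) {}

  void push(std::size_t bucket, Block* b) {
    std::lock_guard<std::mutex> g(buckets_[bucket].mutex_);  // lock only this bucket
    buckets_[bucket].free_.push_back(b);
  }

  Block* pop(std::size_t bucket) {
    std::lock_guard<std::mutex> g(buckets_[bucket].mutex_);
    if (buckets_[bucket].free_.empty()) return nullptr;
    Block* b = buckets_[bucket].free_.back();
    buckets_[bucket].free_.pop_back();
    return b;
  }

 private:
  std::vector<SizeBucket> buckets_;
};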
@@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
 try {
 mkldnn_matmul_i8i8i32(self, mat2, result);
 dispatched = true;
-} catch (const std::exception& e) {
+} catch ([[maybe_unused]] const std::exception& e) {
 TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
 }
 }
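The [[maybe_unused]] on the caught exception presumably covers build configurations where the warning macro expands to nothing, leaving e otherwise unreferenced. A small generic illustration of the attribute in a catch clause, unrelated to PyTorch's macros:

#include <exception>
#include <iostream>

void risky_operation();  // declared only; the sketch does not call it from main

void run() {
  try {
    risky_operation();
  } catch ([[maybe_unused]] const std::exception& e) {
#ifndef NDEBUG
    std::cerr << "falling back: " << e.what() << '\n';  // only path that reads e
#endif
    // In builds where the logging line is compiled out, e would otherwise
    // trigger an unused-variable warning; the attribute suppresses it.
  }
}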
@@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
 } else if (dtype == ScalarType::Half) {
 [&]() {
 using scalar_t =
-decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
+c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
 const auto exp = exp_scalar.to<scalar_t>();
 using Vec = Vectorized<scalar_t>;
 cpu_kernel_vec(iter,
@@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
 out_calc_t output_offset_calculator,
 loader_t loader,
 storer_t storer) {
-if (ret_t == rt_binary_specializations[arg_index][0] &&
-arg0_t == rt_binary_specializations[arg_index][1] &&
-arg1_t == rt_binary_specializations[arg_index][2])
+constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
+constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
+constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
+if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
+using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
+using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
+using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
 launch_vectorized_templated_kernel<
 func_t,
 array_t,
@@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
 out_calc_t,
 loader_t,
 storer_t,
-decltype(c10::impl::ScalarTypeToCPPType<
-rt_binary_specializations[arg_index][0]>::t),
-decltype(c10::impl::ScalarTypeToCPPType<
-rt_binary_specializations[arg_index][1]>::t),
-decltype(c10::impl::ScalarTypeToCPPType<
-rt_binary_specializations[arg_index][2]>::t)>(
+cret_t,
+carg0_t,
+carg1_t>(
 numel,
 f,
 data,
@@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
 output_offset_calculator,
 loader,
 storer);
+}
 }
 };
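The launcher rewrite above pulls the specialization table entries into constexpr ScalarType locals and uses the ScalarTypeToCPPTypeT alias, so the runtime comparison and the compile-time type selection read off the same constants. A simplified, generic sketch of that enum-to-type dispatch technique (illustrative only, not the PyTorch implementation):

#include <cstddef>
#include <cstdint>
#include <iostream>

enum class ScalarKind { Float, Half, Int };

// Map an enum value to a C++ type at compile time via specialization.
template <ScalarKind K> struct KindToCPPType;
template <> struct KindToCPPType<ScalarKind::Float> { using type = float; };
template <> struct KindToCPPType<ScalarKind::Int>   { using type = std::int32_t; };

template <ScalarKind K>
using KindToCPPTypeT = typename KindToCPPType<K>::type;

// A table of supported specializations, indexable at compile time.
constexpr ScalarKind specializations[] = {ScalarKind::Float, ScalarKind::Int};

template <std::size_t I>
void launch_if_match(ScalarKind runtime_kind) {
  constexpr ScalarKind k = specializations[I];   // compile-time constant
  if (runtime_kind == k) {                       // runtime check against it
    using T = KindToCPPTypeT<k>;                 // concrete element type
    std::cout << "launch with sizeof(T) = " << sizeof(T) << '\n';
  }
}

int main() {
  launch_if_match<0>(ScalarKind::Float);  // matches: prints sizeof(float)
  launch_if_match<1>(ScalarKind::Float);  // no match: prints nothing
}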
@@ -202,7 +202,6 @@ supported:
 - select_backward
 - _trilinear
 - linalg_pinv.atol_rtol_tensor
 - svd
 - logsumexp.out
 symint:
 - empty.memory_format
@@ -9,6 +9,7 @@
 
 #include <c10/core/Device.h>
 #include <c10/core/DeviceType.h>
+#include <c10/core/alignment.h>
 #include <c10/macros/Export.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
@@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
 // where we would like to support composite implicit kernels but not
 // explicit kernels therefore we manually add the key to the
 // math_dispatch_keyset
-DispatchKeySet{DispatchKey::NestedTensor};
+DispatchKeySet{DispatchKey::NestedTensor} |
+// Functionalize should always reuse CompositeImplicit decomps.
+DispatchKeySet{DispatchKey::Functionalize};
 
 constexpr DispatchKeySet nested_dispatch_keyset =
 DispatchKeySet(
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <cstddef>
+#include <new>
 
 namespace c10 {
 
@@ -18,4 +19,12 @@ constexpr size_t gPagesize = 4096;
 // since the default thp pagesize is 2MB, enable thp only
 // for buffers of size 2MB or larger to avoid memory bloating
 constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;
+
+// Cache line size used to avoid false sharing between threads. Falls back to 64
+// bytes if C++17 feature is unavailable.
+#ifdef __cpp_lib_hardware_interference_size
+using std::hardware_destructive_interference_size;
+#else
+constexpr std::size_t hardware_destructive_interference_size = 64;
+#endif
 } // namespace c10
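The new constant gives allocator code a portable cache-line size to feed into alignas, with the same feature-test fallback used here. A small self-contained sketch of why separating two hot counters onto different cache lines matters (illustrative, not part of the change):

#include <atomic>
#include <cstddef>
#include <new>
#include <thread>

#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;  // typical cache line
#endif

// Without alignas, both counters can land on one cache line and the two
// writer threads keep invalidating each other's line (false sharing).
struct Counters {
  alignas(hardware_destructive_interference_size) std::atomic<long> a{0};
  alignas(hardware_destructive_interference_size) std::atomic<long> b{0};
};

int main() {
  Counters c;
  std::thread t1([&] { for (int i = 0; i < 1000000; ++i) c.a.fetch_add(1, std::memory_order_relaxed); });
  std::thread t2([&] { for (int i = 0; i < 1000000; ++i) c.b.fetch_add(1, std::memory_order_relaxed); });
  t1.join();
  t2.join();
  return (c.a.load() == 1000000 && c.b.load() == 1000000) ? 0 : 1;
}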
@@ -941,7 +941,7 @@ class EventPool {
 
 private:
 struct PerDevicePool {
-alignas(64) std::mutex mutex_;
+alignas(hardware_destructive_interference_size) std::mutex mutex_;
 std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
 };
 std::vector<PerDevicePool> pools_;
@@ -3758,11 +3758,6 @@ static void uncached_delete(void* ptr) {
 static void local_raw_delete(void* ptr);
 thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
 thread_local std::string DeviceCachingAllocator::user_metadata;
-#ifdef __cpp_lib_hardware_interference_size
-using std::hardware_destructive_interference_size;
-#else
-static constexpr std::size_t hardware_destructive_interference_size = 64;
-#endif
 
 class NativeCachingAllocator : public CUDAAllocator {
 private:
@@ -554,7 +554,7 @@ static void local_raw_delete(void* ptr);
 
 class XPUAllocator : public DeviceAllocator {
 private:
-std::mutex mutex;
+alignas(hardware_destructive_interference_size) std::mutex mutex;
 ska::flat_hash_map<void*, Block*> allocated_blocks;
 
 void add_allocated_block(Block* block) {
@@ -16,7 +16,7 @@ find_path(vecLib_INCLUDE_DIR vecLib.h
 DOC "vecLib include directory"
 PATHS /System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
 /System/Library/${__veclib_include_suffix}
-/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
+/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
 ${CMAKE_OSX_SYSROOT}/System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
 NO_DEFAULT_PATH)
@@ -207,6 +207,42 @@ templates_path = [
 ]
 # TODO: document these and remove them from here.
 
+# Fixes the duplicated
+autosummary_filename_map = {
+"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
+"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
+"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
+"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
+"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
+"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
+"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
+"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
+"torch.optim.radam.radam": "torch.optim.radam.radam_function",
+"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
+"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
+"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
+"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
+"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
+"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
+"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
+"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
+"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
+"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
+"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
+"torch.optim.adam.adam": "torch.optim.adam.adam_function",
+"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
+"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
+"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
+"torch.mtia.stream": "torch.mtia.stream_function",
+"torch.mtia.Stream": "torch.mtia.Stream_class",
+"torch.cpu.stream": "torch.cpu.stream_function",
+"torch.cpu.Stream": "torch.cpu.Stream_class",
+"torch.cuda.stream": "torch.cuda.stream_function",
+"torch.cuda.Stream": "torch.cuda.Stream_class",
+"torch.xpu.stream": "torch.xpu.stream_function",
+"torch.xpu.Stream": "torch.xpu.Stream_class",
+}
+
 coverage_ignore_functions = [
 # torch
 "typename",
@@ -3193,6 +3229,11 @@ autodoc_type_aliases = {
 # Enable overriding of function signatures in the first line of the docstring.
 autodoc_docstring_signature = True
 
+# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
+autodoc_default_options = {
+"exclude-members": "from_bytes, to_bytes",
+}
+
 # -- katex javascript in header
 #
 # def setup(app):
@@ -253,7 +253,6 @@ regular full-precision tensor.
 .. autosummary::
 :toctree: generated
 :nosignatures:
 :template: classtemplate.rst
 
 view
 as_strided
@@ -15192,6 +15192,25 @@ graph():
 filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
 )
 
+def test_invalid_pytree_dynamo_graph_capture(self):
+class Block:
+def __init__(self, a, b):
+self.a = a
+self.b = b
+
+class Foo(torch.nn.Module):
+def forward(self, block):
+return block.a + block.b
+
+from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+
+with self.assertRaisesRegex(
+torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
+):
+_dynamo_graph_capture_for_export(Foo())(
+Block(torch.randn(4, 4), torch.randn(4, 4))
+)
+
 def test_enum_str(self):
 class TensorDim(str, enum.Enum):
 DDP = "ddp"
@@ -318,17 +318,19 @@ class inner_f(torch.nn.Module):
 super().__init__()
 self.linear = nn.Linear(3, 2)
 
-def forward(self, x, scale=1.0):
+def forward(self, x, *, scale):
 return self.linear(x) * scale
 
 model = ModuleWithKwargs()
 inputs = (torch.randn(4, 3),)
-kwargs = {"scale": 2.0}
+kwargs = {"scale": torch.tensor(2.0)}
 
+gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
+
 with ExitStack() as stack:
 # Export joint with descriptors
 joint_with_descriptors = aot_export_joint_with_descriptors(
-stack, model, inputs, kwargs, decompositions=decomposition_table
+stack, gm, inputs, kwargs, decompositions=decomposition_table
 )
 
 # Test the exported graph structure
@@ -336,9 +338,17 @@ class inner_f(torch.nn.Module):
 print_output=False, expanded_def=True
 )
 
+# For some reason PYTORCH_TEST_WITH_CROSSREF will add extra spaces.
+# I tried to fix this in normalize_gm but there are too many files
+# depending on that behavior..
+graph_code_str = normalize_gm(graph_code)
+graph_code_str = "\n".join(
+[line for line in graph_code_str.split("\n") if len(line.rstrip()) > 0]
+)
+
 # Expect test on the printed graph
 self.assertExpectedInline(
-normalize_gm(graph_code),
+graph_code_str,
 """\
 class inner_f(torch.nn.Module):
 def forward(
@@ -346,19 +356,20 @@ class inner_f(torch.nn.Module):
 primals,
 tangents,
 ):
-primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight')
-primals_2: "f32[2]" # ParamAOTInput(target='linear.bias')
+primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear_weight')
+primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear_bias')
 primals_3: "f32[4, 3]" # PlainAOTInput(idx=0)
 primals_4: "f32[]" # PlainAOTInput(idx=1)
 tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0))
-primals_1, primals_2, primals_3, primals_4 , tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
+primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
 transpose: "f32[3, 2]" = torch.ops.prims.transpose.default(primals_1, [1, 0]); primals_1 = None
 mm: "f32[4, 2]" = torch.ops.aten.mm.default(primals_3, transpose); transpose = None
 mul: "f32[4, 2]" = torch.ops.prims.mul.default(mm, 1.0); mm = None
 mul_1: "f32[2]" = torch.ops.prims.mul.default(primals_2, 1.0); primals_2 = None
 broadcast_in_dim: "f32[4, 2]" = torch.ops.prims.broadcast_in_dim.default(mul_1, [4, 2], [1]); mul_1 = None
 add: "f32[4, 2]" = torch.ops.prims.add.default(mul, broadcast_in_dim); mul = broadcast_in_dim = None
-mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, 2.0); add = None
-mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, 2.0); tangents_1 = None
+mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, primals_4); add = None
+mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, primals_4); tangents_1 = primals_4 = None
 transpose_1: "f32[2, 4]" = torch.ops.prims.transpose.default(mul_3, [1, 0])
 mm_1: "f32[2, 3]" = torch.ops.aten.mm.default(transpose_1, primals_3); transpose_1 = primals_3 = None
 transpose_2: "f32[3, 2]" = torch.ops.prims.transpose.default(mm_1, [1, 0]); mm_1 = None
@@ -368,12 +379,11 @@ class inner_f(torch.nn.Module):
 transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None
 return pytree.tree_unflatten([
 mul_2, # PlainAOTOutput(idx=0)
-transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight'))
-as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias'))
+transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_weight'))
+as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_bias'))
 None, # None
 None, # None
-], self._out_spec)
-""",
+], self._out_spec)""",
 )
 
 # Compile the result
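The updated test captures the module with the dynamo graph-capture helper first and only then exports the joint graph, which is why the expected parameter targets change to the L__self___* form. A stripped-down sketch of that flow, assembled only from the calls visible in the test; the import path of the joint-export helper and the omission of the decompositions argument are assumptions:

from contextlib import ExitStack

import torch
import torch.nn as nn
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
# Assumed import location for the joint-export helper; the test imports it at module scope.
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors

class ModuleWithKwargs(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x, *, scale):
        return self.linear(x) * scale

model = ModuleWithKwargs()
inputs = (torch.randn(4, 3),)
kwargs = {"scale": torch.tensor(2.0)}

# Capture with dynamo first, then export the joint (forward + backward) graph.
gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)

with ExitStack() as stack:
    joint = aot_export_joint_with_descriptors(stack, gm, inputs, kwargs)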
@@ -7356,6 +7356,7 @@ metadata incorrectly.
 aot_eager = torch.compile(backend="aot_eager")(fn)(x)
 self.assertEqual(eager, aot_eager, atol=0, rtol=0)
 
+@unittest.expectedFailure
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
 def test_rms_norm(self):
 # Only CUDA rms norm fails to be decomposed
@@ -1,74 +0,0 @@
-# Owner(s): ["module: inductor"]
-import tempfile
-import unittest
-import zipfile
-
-import torch
-import torch._inductor.config
-from torch._environment import is_fbcode
-from torch._inductor.test_case import TestCase
-from torch.testing._internal.common_utils import IS_CI
-from torch.testing._internal.inductor_utils import HAS_GPU, requires_gpu
-
-
-class Simple(torch.nn.Module):
-def __init__(self):
-super().__init__()
-self.fc1 = torch.nn.Linear(10, 16)
-self.relu = torch.nn.ReLU()
-self.fc2 = torch.nn.Linear(16, 1)
-self.sigmoid = torch.nn.Sigmoid()
-
-def forward(self, x):
-x = self.fc1(x)
-x = self.relu(x)
-x = self.fc2(x)
-x = self.sigmoid(x)
-return x
-
-
-class TestAOTInductorWindowsCrossCompilation(TestCase):
-@requires_gpu()
-def test_simple_so(self):
-if is_fbcode() or IS_CI:
-raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
-
-# TODO: enable in CI
-with torch.no_grad():
-device = "cuda"
-model = Simple().to(device=device)
-example_inputs = (torch.randn(8, 10, device=device),)
-batch_dim = torch.export.Dim("batch", min=1, max=1024)
-exported = torch.export.export(
-model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
-)
-package_path = torch._inductor.aoti_compile_and_package(
-exported,
-inductor_configs={
-"aot_inductor.model_name_for_generated_files": "model",
-"aot_inductor.cross_target_platform": "windows",
-"aot_inductor.link_libtorch": False,
-# TODO: need to add aoti_shim_library_path for CI
-"aot_inductor.aoti_shim_library": "executorch",
-# no fallback ops
-"max_autotune": True,
-"max_autotune_gemm_backends": "TRITON,CPP",
-"max_autotune_conv_backends": "TRITON,CPP",
-"aot_inductor.embed_kernel_binary": True,
-# simplify things for now
-"aot_inductor.precompile_headers": False,
-"aot_inductor.package_constants_on_disk_format": "binary_blob",
-"aot_inductor.package_constants_in_so": False,
-},
-)
-
-with tempfile.TemporaryDirectory() as tmpdir:
-with zipfile.ZipFile(package_path, "r") as zf:
-zf.extractall(tmpdir)
-
-
-if __name__ == "__main__":
-from torch._inductor.test_case import run_tests
-
-if HAS_GPU:
-run_tests(needs="filelock")
@@ -9,6 +9,7 @@ from typing import Any, Optional
 
 import torch
 import torch._inductor.config
+from torch._environment import is_fbcode
 from torch._inductor.test_case import TestCase
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu
 
@@ -77,6 +78,9 @@ class WindowsCrossCompilationTestFramework:
 "This test should run on Linux for cross-compilation"
 )
 
+if is_fbcode():
+raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
+
 self.assertTrue("WINDOWS_CUDA_HOME" in os.environ)
 
 with torch.no_grad():
@@ -128,6 +132,9 @@ class WindowsCrossCompilationTestFramework:
 if platform.system() != "Windows":
 raise unittest.SkipTest("This test should run on Windows")
 
+if is_fbcode():
+raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
+
 if not HAS_GPU:
 raise unittest.SkipTest("Test requires GPU")
 
@@ -1839,12 +1839,22 @@ class TestStandaloneCompile(TestCase):
 @parametrize("format", ("binary", "unpacked"))
 @parametrize("dynamic", (False, True))
 @parametrize("graph_partition", (False, True))
+@parametrize("is_aot", (False, True))
 def test_basic(
-self, device: str, format: str, dynamic: bool, graph_partition: bool
+self,
+device: str,
+format: str,
+dynamic: bool,
+graph_partition: bool,
+is_aot: bool,
 ) -> None:
 if device == GPU_TYPE and not HAS_GPU:
 raise unittest.SkipTest(f"requires {GPU_TYPE}")
 
+# AOT mode does not support unpacked format
+if is_aot and format == "unpacked":
+raise unittest.SkipTest("AOT mode does not support unpacked format")
+
 mod = torch.nn.Linear(1, 3, device=device)
 x = torch.randn(4, 1, device=device)
 if dynamic:
@@ -1869,7 +1879,9 @@ class TestStandaloneCompile(TestCase):
 gm, args, kwargs = self.capture(f)(x)
 assert not kwargs
 
-compiled_artifact = torch._inductor.standalone_compile(gm, args)
+compiled_artifact = torch._inductor.standalone_compile(
+gm, args, aot=is_aot
+)
 compiled_artifact.save(path=path, format=format)
 
 self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
@@ -1885,13 +1897,15 @@ class TestStandaloneCompile(TestCase):
 compiled_out = loaded(*concrete_args)
 self.assertEqual(eager_out, compiled_out)
 
-self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+if not is_aot:
+self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
 
 @config.patch({"fx_graph_cache": True})
 @config.patch({"fx_graph_remote_cache": False})
 @functorch_config.patch({"enable_autograd_cache": True})
 @parametrize("dynamic", (False, True))
-def test_call_in_backend(self, dynamic: bool) -> None:
+@parametrize("is_aot", (False, True))
+def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None:
 mod = torch.nn.Linear(1, 3)
 x = torch.randn(4, 1)
 if dynamic:
@@ -1904,7 +1918,7 @@ class TestStandaloneCompile(TestCase):
 eager_out = f(x)
 
 def backend(gm, args, **kwargs):
-return torch._inductor.standalone_compile(gm, args)
+return torch._inductor.standalone_compile(gm, args, aot=is_aot)
 
 with fresh_cache():
 compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
@@ -2055,7 +2069,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 @config.patch({"fx_graph_cache": True})
 @config.patch({"fx_graph_remote_cache": False})
 @functorch_config.patch({"enable_autograd_cache": True})
-def test_dynamic_shapes_from_graph(self):
+@parametrize("is_aot", (False, True))
+def test_dynamic_shapes_from_graph(self, is_aot: bool):
 def f(x):
 return x.shape[0] * x
 
@@ -2067,7 +2082,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 assert not kwargs
 
 compiled_artifact = torch._inductor.standalone_compile(
-gm, args, dynamic_shapes="from_graph"
+gm, args, dynamic_shapes="from_graph", aot=is_aot
 )
 x = torch.ones(4)
 (result,) = compiled_artifact(4, x)
@@ -2077,7 +2092,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 @config.patch({"fx_graph_remote_cache": False})
 @functorch_config.patch({"enable_autograd_cache": True})
 @functorch_config.patch({"autograd_cache_normalize_inputs": True})
-def test_split_module(self):
+@parametrize("is_aot", (False, True))
+def test_split_module(self, is_aot):
 class Mod(torch.nn.Module):
 def forward(self, x, a0, a1, b0, b1, c0, c1):
 x = x + (a0**2) + (a1 / 2)
@@ -2116,16 +2132,24 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 split = torch.fx.passes.split_module.split_module(gm, gm, split)
 
 # Each of the split graphs only has one output.
-ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1))
-ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1))
-ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1))
+ca0 = torch._inductor.standalone_compile(
+split.submod_0, (a0, x, a1), aot=is_aot
+)
+ca1 = torch._inductor.standalone_compile(
+split.submod_1, (b0, x, b1), aot=is_aot
+)
+ca2 = torch._inductor.standalone_compile(
+split.submod_2, (c0, x, c1), aot=is_aot
+)
 
 y = ca0(a0, x, a1)
 y = ca1(b0, y, b1)
 y = ca2(c0, y, c1)
-self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
+if not is_aot:
+# fx graph cache doesn't run in AOT mode
+self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
 # TODO: split_module causes ca1 and ca2 to have different type annotations
 # for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice
 self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
@@ -2138,8 +2162,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 @config.patch({"fx_graph_cache": True})
 @config.patch({"fx_graph_remote_cache": False})
 @functorch_config.patch({"enable_autograd_cache": True})
+@parametrize("is_aot", (False, True))
 @parametrize("config_patches", [True, False])
-def test_dynamic_shapes_from_example_inputs(self, config_patches):
+def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot):
 def f(x):
 return x.shape[0] * x
 
@@ -2161,6 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 (5, torch.ones(4)),
 dynamic_shapes="from_example_inputs",
 options={"config_patches": config_patches},
+aot=is_aot,
 )
 x = torch.ones(4)
 (result,) = compiled_artifact(3, x)
@@ -2175,8 +2201,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 @config.patch({"fx_graph_cache": True})
 @config.patch({"fx_graph_remote_cache": False})
 @functorch_config.patch({"enable_autograd_cache": True})
+@parametrize("is_aot", (True, False))
 @parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
-def test_static_shapes(self, dynamic_shapes):
+def test_static_shapes(self, dynamic_shapes, is_aot):
 def f(x):
 return x.shape[0] * x
 
@@ -2186,7 +2213,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
 assert not kwargs
 compiled_artifact = torch._inductor.standalone_compile(
-static_gm, [static_x], dynamic_shapes=dynamic_shapes
+static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot
 )
 x = torch.randn(3)
 (result,) = compiled_artifact(x)
@@ -2198,8 +2225,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 @config.patch({"fx_graph_cache": True})
 @config.patch({"fx_graph_remote_cache": False})
 @functorch_config.patch({"enable_autograd_cache": True})
+@parametrize("is_aot", (True, False))
 @parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
-def test_backend(self, dynamic_shapes):
+def test_backend(self, dynamic_shapes, is_aot):
 def f(x):
 return x.shape[0] * x
 
@@ -2208,7 +2236,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 
 def backend(gm, args, **kwargs):
 compiled_artifact = torch._inductor.standalone_compile(
-gm, args, dynamic_shapes=dynamic_shapes
+gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot
 )
 y = torch.randn(4)
 (result,) = compiled_artifact(4, y)
@@ -2221,7 +2249,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 @config.patch({"fx_graph_cache": True})
 @config.patch({"fx_graph_remote_cache": False})
 @functorch_config.patch({"enable_autograd_cache": True})
-def test_backend_dynamic_shapes_from_example_inputs(self):
+@parametrize("is_aot", (True, False))
+def test_backend_dynamic_shapes_from_example_inputs(self, is_aot):
 def f(x):
 return x.shape[0] * x
 
@@ -2230,7 +2259,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
 
 def backend(gm, args, **kwargs):
 compiled_artifact = torch._inductor.standalone_compile(
-gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
+gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot
 )
 y = torch.ones(4)
 (result,) = compiled_artifact(4, y)
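The tests above exercise torch._inductor.standalone_compile with the new aot flag. A condensed usage sketch assembled only from the calls visible in this diff; the symbolic_trace capture step and the save path are stand-ins (the tests capture the graph with a dynamo-based helper instead):

import torch

def f(x):
    return x * 2 + 1

# Stand-in capture step: the tests use an internal dynamo-based helper here.
gm = torch.fx.symbolic_trace(f)
example_inputs = [torch.randn(4, 4)]

# aot=True selects the ahead-of-time path; the fx-graph cache counters are only
# asserted when aot is False, since that cache does not run in AOT mode.
artifact = torch._inductor.standalone_compile(gm, example_inputs, aot=True)
(out,) = artifact(torch.randn(4, 4))
artifact.save(path="/tmp/standalone_artifact", format="binary")  # "unpacked" is unsupported with aot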
@@ -1543,26 +1543,22 @@ class CPUReproTests(TestCase):
 with config.patch({"cpp.simdlen": None}):
 torch._dynamo.reset()
 metrics.reset()
-inputs = (
+self.common(
+fn,
+(
 x,
 scale,
 zero_point,
 use_dequant,
 use_quant,
 quant_min,
 quant_max,
 dtype,
 dequant_out_dtype,
+),
 )
-self.common(fn, inputs)
 check_metrics_vec_kernel_count(1)
-
-# Check that both main and tail loops are vectorized
-if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-compiled_fn = torch.compile(fn)
-_, code = run_and_get_cpp_code(compiled_fn, *inputs)
-FileCheck().check_count("loadu", 2, exactly=True).run(code)
 
 @requires_vectorization
 def test_dequant_quant_lowering_uint8(self):
 self._test_dequant_quant_lowering_helper(torch.uint8)
@@ -4814,22 +4810,6 @@ class CPUReproTests(TestCase):
 self.common(fn, (x,))
 check_metrics_vec_kernel_count(1)
 
-# Tail vectorization case
-x = torch.randn((22, 22), dtype=torch.double)
-torch._dynamo.reset()
-metrics.reset()
-with torch.no_grad():
-expected = fn(x)
-compiled_fn = torch.compile(fn)
-actual, code = run_and_get_cpp_code(compiled_fn, x)
-self.assertEqual(expected, actual)
-# 1 generated vec kernel
-self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
-# Check that both main and tail loops are vectorized
-FileCheck().check_count(
-"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
-).run(code)
-
 def test_double_reduction_vec(self):
 def fn(x):
 return x.sum(dim=1)
@@ -4839,22 +4819,6 @@ class CPUReproTests(TestCase):
 self.common(fn, (x,))
 check_metrics_vec_kernel_count(1)
 
-# Tail vectorization case
-x = torch.randn((22, 22), dtype=torch.double)
-torch._dynamo.reset()
-metrics.reset()
-with torch.no_grad():
-expected = fn(x)
-compiled_fn = torch.compile(fn)
-actual, code = run_and_get_cpp_code(compiled_fn, x)
-self.assertEqual(expected, actual)
-# 1 generated vec kernel
-self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
-# Check that both main and tail loops are vectorized
-FileCheck().check_count(
-"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
-).run(code)
-
 def test_convert_fp32_to_double_vec(self):
 def fn(x):
 return x.to(torch.double)
@@ -4864,22 +4828,6 @@ class CPUReproTests(TestCase):
 self.common(fn, (x,))
 check_metrics_vec_kernel_count(1)
 
-# Tail vectorization case
-x = torch.randn(22, 22)
-torch._dynamo.reset()
-metrics.reset()
-with torch.no_grad():
-expected = fn(x)
-compiled_fn = torch.compile(fn)
-actual, code = run_and_get_cpp_code(compiled_fn, x)
-self.assertEqual(expected, actual)
-# 1 generated vec kernel
-self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
-# Check that both main and tail loops are vectorized
-FileCheck().check_count(
-"at::vec::convert<double,2,float,1>", 2, exactly=True
-).run(code)
-
 def test_convert_double_to_fp32_vec(self):
 def fn(x):
 return x.to(torch.float32)
@@ -4889,22 +4837,6 @@ class CPUReproTests(TestCase):
 self.common(fn, (x,))
 check_metrics_vec_kernel_count(1)
 
-# Tail vectorization case
-x = torch.randn((22, 22), dtype=torch.double)
-torch._dynamo.reset()
-metrics.reset()
-with torch.no_grad():
-expected = fn(x)
-compiled_fn = torch.compile(fn)
-actual, code = run_and_get_cpp_code(compiled_fn, x)
-self.assertEqual(expected, actual)
-# 1 generated vec kernel
-self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
-# Check that both main and tail loops are vectorized
-FileCheck().check_count(
-"at::vec::convert<float,1,double,2>", 2, exactly=True
-).run(code)
-
 def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
 # https://github.com/pytorch/pytorch/issues/115260
 p0 = torch.tensor([1.0879], dtype=torch.float16)
@@ -85,7 +85,6 @@ def init_lists():
"linalg_inv_ex",
"linalg_pinv.atol_rtol_tensor",
"logsumexp",
"svd",
}
# For some ops, we don't support all variants. Here we use formatted_name
# to uniquely identify the variant.
@@ -221,15 +220,20 @@ class TestLazyOpInfo(TestCase):
torch._lazy.wait_device_ops()
prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
metrics = remove_suffixes(torch._lazy.metrics.counter_names())
cands = [f"{prefix}::{op.name}{symint_suffix}"]
# check aliases
for alias in op.aliases:
cands.append(f"{prefix}::{alias.name}{symint_suffix}")

self.assertTrue(
any(c in metrics for c in cands), f"none of {cands} not found in {metrics}"
found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(
torch._lazy.metrics.counter_names()
)
# check aliases
if not found:
for alias in op.aliases:
alias_found = (
f"{prefix}::{alias.name}{symint_suffix}"
in remove_suffixes(torch._lazy.metrics.counter_names())
)
found = found or alias_found
if found:
break
self.assertTrue(found)

@ops(
[
@@ -1258,10 +1258,11 @@ class DecompOneOffTests(TestCase):
)

# check RMSNorm was fused with sinh
self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0])
self.assertTrue(
"triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul"
in generated_codes[1]
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
)
self.assertTrue(
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
)
@@ -17,7 +17,7 @@ from unittest.mock import patch, MagicMock, ANY
import math
import itertools
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, largeTensorTest
from torch.testing._internal.common_device_type import expectedFailureMPS, instantiate_device_type_tests, onlyCUDA, largeTensorTest
from typing import Optional
import torch.utils.cpp_extension
from torch.testing._internal.common_nn import NNTestCase
@@ -2022,6 +2022,7 @@ class TestSDPA(NNTestCase):
for both cpu and cuda. If you're test is only applicable to cuda,
add it to TestSDPACudaOnly.
"""
@expectedFailureMPS # No double support
@parametrize("contiguous_inputs", [True, False])
def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool):

@@ -4625,13 +4626,13 @@ class TestAttnBias(NNTestCase):
scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)

if NOTEST_CPU:
device_types = ("cuda", )
device_types = ("cuda", "mps")
else:
device_types = ("cpu", "cuda")
device_types = ("cpu", "cuda", "mps")

instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types, allow_mps=True)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types, allow_mps=True)
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
@@ -754,6 +754,10 @@ def align_trace_from_beginning(
# Rank 3: [0, 1, 2, 3, 4, 5, None]
# Then we should start from collective 2 not 0 because any collective before,
# we don't have complete records from all ranks so we need to ignore them.
# If we don't have any trace from some ranks, ignore them
# as well.
if len(entries[rank]) == 0:
continue
first_record_id = entries[rank][0]["record_id"]
maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)
@@ -404,7 +404,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.max_unpool3d,
aten.mish,
aten.mish_,
aten.mish_backward,
aten.mse_loss,
aten.mse_loss_backward,
aten.multi_margin_loss,
@@ -420,7 +419,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
aten._fused_rms_norm,
aten._fused_rms_norm_backward,
aten.new_empty,
aten.new_full,
@@ -477,7 +475,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.silu,
aten.silu_,
aten.silu_backward.grad_input,
aten.silu_backward,
aten.sinc,
aten.sinc_,
aten.slice_backward,
@@ -1757,61 +1757,6 @@ def native_layer_norm_backward_out(
return grad_input


@register_decomposition(aten._fused_rms_norm.default)
def _fused_rms_norm(
input: Tensor,
normalized_shape: list[int],
weight: Optional[Tensor],
eps: Optional[float],
) -> tuple[Tensor, Tensor]:
dims_to_reduce: list[int] = []
for i in range(len(normalized_shape)):
dims_to_reduce.append(input.dim() - i - 1)

# upcast is needed for fp16 and bf16
computation_dtype = utils.get_computation_dtype(input.dtype)
upcasted_input = input.to(computation_dtype)

# computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
if eps is None:
if computation_dtype in (torch.float32, torch.complex64):
eps_val = torch.finfo(torch.float32).eps
else:
eps_val = torch.finfo(torch.float64).eps
else:
eps_val = eps

rqrst_input = torch.rsqrt(
# NB: don't inplace here, will violate functional IR invariant
# NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp
torch.ops.aten.add.Scalar(
torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val
)
)

upcasted_result = upcasted_input.mul(rqrst_input)

if weight is not None:
upcasted_result = upcasted_result.mul(weight)

# NB: nested should be dead here, just here for fidelity
is_nested = input.is_nested or (weight is not None and weight.is_nested)
memory_format = utils.suggest_memory_format(input)
is_channels_last = memory_format in (
torch.channels_last,
torch.channels_last_3d,
)

if not is_nested and not is_channels_last:
upcasted_result = upcasted_result.contiguous()
rqrst_input = rqrst_input.contiguous()

# Cast normalized result back to original input type
result = upcasted_result.type_as(input)

return result, rqrst_input


@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
grad_out: Tensor,
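The Python decomposition removed above implements standard RMSNorm. A minimal reference sketch of the same computation, not part of the diff (it assumes a single trailing normalized dimension and omits the nested-tensor and channels-last handling):

    import torch

    def rms_norm_ref(x, weight=None, eps=1e-6):
        # normalize by the root mean square over the last dimension, computed in fp32
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
        y = x32 * inv_rms
        if weight is not None:
            y = y * weight.float()
        # cast back to the input dtype, return inv_rms for the backward pass
        return y.type_as(x), inv_rms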
@@ -1,4 +1,3 @@
import abc
import dataclasses
import importlib
import inspect
@@ -15,6 +14,10 @@ from torch._dynamo.graph_utils import _graph_device_type
from torch._dynamo.package import SystemInfo

from . import convert_frame
from .aot_compile_types import (
BundledAOTAutogradSerializableCallable,
SerializableCallable,
)
from .hooks import Hooks


@@ -26,18 +29,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)


class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass

@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass


def bind_locals(
signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
@@ -149,53 +140,6 @@ class AOTCompiledFunction:
self._guard_check_enabled = False


class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.

TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""

def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn

def __getattr__(self, attr: Any) -> Any:
if hasattr(self, attr):
return getattr(super(), attr)
else:
return getattr(self.compiled_fn, attr)

@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result

@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)

entry = pickle.loads(data)

compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)


def aot_compile_fullgraph(
model: Any,
example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
torch/_dynamo/aot_compile_types.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import abc
import pickle
from typing import Any

import torch


class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass

@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass


class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.

TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""

def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn

def __getattr__(self, attr: Any) -> Any:
return getattr(self.compiled_fn, attr)

@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result

@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)

entry = pickle.loads(data)

compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
@@ -1707,6 +1707,39 @@ def check_signature_rewritable(graph: torch.fx.GraphModule) -> None:
)


def check_user_input_output(flat_values: list[Any], error_type: UserErrorType) -> None:
supported_types = [
torch.Tensor,
torch.SymInt,
torch.SymFloat,
torch.SymBool,
torch._C.ScriptObject,
_IntWrapper,
] + list(common_constant_types)

def is_supported_type(val: Any) -> bool:
return isinstance(val, tuple(supported_types))

value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
# We only check that the outputs are not None. Inputs can be None.
for v in flat_values:
if not is_supported_type(v):
if error_type == UserErrorType.INVALID_INPUT and v is None:
continue

raise UserError(
error_type,
f"It looks like one of the {value_type}s with type `{type(v)}` "
"is not supported or pytree-flattenable. \n"
f"Exported graphs {value_type}s can only contain the "
f"following supported types: {supported_types}. \n"
"If you are using a custom class object, "
"please register a pytree_flatten/unflatten function "
"using `torch.utils._pytree.register_pytree_node` or "
"`torch.export.register_dataclass`.",
)


def rewrite_signature(
f_sig: inspect.Signature,
graph: torch.fx.GraphModule,
@@ -1721,40 +1754,6 @@ def rewrite_signature(
) -> torch.fx.GraphModule:
orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)

def check_user_input_output(
flat_values: list[Any], error_type: UserErrorType
) -> None:
supported_types = [
torch.Tensor,
torch.SymInt,
torch.SymFloat,
torch.SymBool,
torch._C.ScriptObject,
_IntWrapper,
] + list(common_constant_types)

def is_supported_type(val: Any) -> bool:
return isinstance(val, tuple(supported_types))

value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
# We only check that the outputs are not None. Inputs can be None.
for v in flat_values:
if not is_supported_type(v):
if error_type == UserErrorType.INVALID_INPUT and v is None:
continue

raise UserError(
error_type,
f"It looks like one of the {value_type}s with type `{type(v)}` "
"is not supported or pytree-flattenable. \n"
f"Exported graphs {value_type}s can only contain the "
f"following supported types: {supported_types}. \n"
"If you are using a custom class object, "
"please register a pytree_flatten/unflatten function "
"using `torch.utils._pytree.register_pytree_node` or "
"`torch.export.register_dataclass`.",
)

check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)
@@ -10,7 +10,8 @@ import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
from torch._dynamo.eval_frame import argument_names
from torch._dynamo.eval_frame import argument_names, check_user_input_output
from torch._dynamo.exc import UserErrorType
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._export.utils import _compiling_state_context
from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
@@ -479,6 +480,7 @@ def _dynamo_graph_capture_for_export(
# This sets the is_exporting flag when building guards.
with _compiling_state_context():
flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
check_user_input_output(flat_inputs, UserErrorType.INVALID_INPUT)
module_to_trace = ModuleToTrace(mod, in_spec)
orig_callable = mod.forward if isinstance(mod, torch.nn.Module) else mod
@@ -200,10 +200,9 @@ class SuperVariable(VariableTracker):
and not (args or kwargs)
):
with do_not_convert_to_tracable_parameter():
fn_vt = VariableTracker.build(
tx, unpatched_nn_module_init, source=source
)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented_v2(
gb_type="Unsupported super().__init__() call",
@@ -231,8 +230,9 @@ class SuperVariable(VariableTracker):
elif isinstance(inner_fn, staticmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(
inner_fn.__func__, source=source
).call_function(tx, args, kwargs)
elif isinstance(inner_fn, classmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
@@ -255,13 +255,13 @@ class SuperVariable(VariableTracker):
tx, self.objvar.value_type, cls_source
)

fn_vt = VariableTracker.build(
tx, inner_fn.__func__, source=AttrSource(source, "__func__")
)
return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
return variables.UserFunctionVariable(
inner_fn.__func__, source=AttrSource(source, "__func__")
).call_function(tx, [cls_variable, *args], kwargs)
elif isinstance(inner_fn, types.FunctionType):
fn_vt = VariableTracker.build(tx, inner_fn, source=source)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
return variables.UserFunctionVariable(
inner_fn, source=source
).call_function(tx, [self.objvar] + args, kwargs)
elif isinstance(inner_fn, types.MethodType):
return variables.UserMethodVariable(
inner_fn.__func__, self.objvar, source=source
@@ -574,8 +574,10 @@ class ComptimeVariable(VariableTracker):
from ..comptime import comptime

# To support the comptime.print_graph convenience accessors
return VariableTracker.build(
tx, getattr(comptime, name), source=AttrSource(self.source, name)
from .functions import UserFunctionVariable

return UserFunctionVariable(
getattr(comptime, name), source=AttrSource(self.source, name)
)

def call_function(
@@ -769,8 +771,9 @@ class AutogradFunctionVariable(VariableTracker):
sig = inspect.signature(fn)
if len(args) - 1 == len(sig._parameters):
args = args[1:] # Don't use context
fn_vt = VariableTracker.build(tx, fn, source=source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(fn, source=source).call_function(
tx, args, kwargs
)
elif isinstance(fn, types.MethodType):
return variables.UserMethodVariable(
fn.__func__,
@@ -796,8 +799,9 @@ class AutogradFunctionVariable(VariableTracker):
assert isinstance(fn, types.FunctionType)

fn_source = AttrSource(self.source, "backward")
fn_vt = VariableTracker.build(tx, fn, source=fn_source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx, args, kwargs
)

def call_function(self, tx: "InstructionTranslator", args, kwargs):
return AutogradFunctionVariable(self.fn_cls)
@@ -1022,12 +1026,10 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
assert tx.one_graph or tx.error_on_graph_break, (
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
)
fn_vt = VariableTracker.build(
tx,
return variables.UserFunctionVariable(
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
source=self.source,
)
return fn_vt.call_function(
).call_function(
tx,
(tx.output.side_effects.get_ca_final_callbacks_var(), *args),
kwargs,
@@ -1347,6 +1347,15 @@ def create_functional_call(
maybe_disable_thunkify(),
):
if isinstance(mod, torch.fx.GraphModule):
if kwargs:
# Handle **kwargs. FX only natively supports positional
# arguments (through placeholders).
arg_list = list(args[params_len:])
arg_list.extend(list(kwargs.values()))
args = tuple(arg_list)
else:
args = args[params_len:]

with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Anomaly Detection has been enabled."
@@ -1355,9 +1364,7 @@ def create_functional_call(
fake_mode = detect_fake_mode()
assert fake_mode is not None
fake_mode.epoch += 1
out = PropagateUnbackedSymInts(mod).run(
*args[params_len:], **kwargs
)
out = PropagateUnbackedSymInts(mod).run(*args)
else:
out = mod(*args[params_len:], **kwargs)
@@ -391,6 +391,7 @@ def standalone_compile(
"from_example_inputs", "from_tracing_context", "from_graph"
] = "from_graph",
options: Optional[dict[str, Any]] = None,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Precompilation API for inductor.
@@ -422,5 +423,5 @@ def standalone_compile(

options = options if options else {}
return standalone_compile(
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot
)
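A minimal usage sketch of the new flag, not taken from the diff: it assumes a GraphModule `gm` and example inputs captured the same way the tests above capture them, and the import path used for reloading is an assumption.

    import torch

    x = torch.randn(3)
    # aot=True routes compilation through BundledAOTAutogradCache and
    # returns a single bundled AOT artifact instead of a cache-backed one
    artifact = torch._inductor.standalone_compile(
        gm, [x], dynamic_shapes="from_graph", aot=True
    )
    (out,) = artifact(x)

    # save to a single binary blob; reloading goes through CompiledArtifact.load
    artifact.save(path="compiled_fn.bin", format="binary")
    # from torch._inductor.standalone_compile import CompiledArtifact
    # restored = CompiledArtifact.load(path="compiled_fn.bin", format="binary")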
@@ -159,14 +159,11 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
]

MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
torch.float64,
torch.float,
torch.bfloat16,
torch.float16,
torch.uint8,
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
]
@@ -5,10 +5,12 @@ import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING

import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
@@ -30,9 +32,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)


class CompiledArtifact:
class CompiledArtifact(ABC):
"""
CompiledArtifact class represents the precompiled inductor artifact that
CompiledArtifact class represents the inductor cache artifacts that
can be invoked in order to avoid repeated compilation.

CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
@@ -45,11 +47,68 @@ class CompiledArtifact:
binary or unpacked data.

Finally, the CompiledArtifact can be invoked via the __call__ method
to execute the precompiled artifact.
to execute the cached artifact.
"""

_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts

@abstractmethod
def __call__(self, *args: Any) -> Any: ...

@abstractmethod
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None: ...

@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
# If format is unpacked, it must be a CacheCompiledArtifact
return CacheCompiledArtifact.load(path=path, format=format)

assert format == "binary"
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader

from .codecache import torch_key

result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
if header == AOTCompiledArtifact.AOT_HEADER:
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
# Otherwise, it's in the CacheCompiledArtifact format
elif header == CacheCompiledArtifact.CACHE_HEADER:
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return CacheCompiledArtifact._load_impl(nullcontext(), key)
else:
raise RuntimeError(
"Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
+ header.decode("utf-8")
)


class CacheCompiledArtifact(CompiledArtifact):
"""
CompiledArtifact that depends on torch.compiler.save_cache_artifacts
"""

CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")

def __init__(
self,
@@ -83,6 +142,7 @@ class CompiledArtifact:
from .codecache import torch_key

writer = BytesWriter()
writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
@@ -116,9 +176,51 @@ class CompiledArtifact:
log.info("Output code written to: %s", output_file)

@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
def _load_impl(
cache_dir_ctx: AbstractContextManager[Any], key: str
) -> CompiledArtifact:
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)

result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)

assert result is not None
(entry, _) = result

from .compile_fx import _CompileFxKwargs

fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)

context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)

@staticmethod
def _prepare_load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> tuple[str, AbstractContextManager[Any]]:
"""
Do format specific prep and loads, return a context manager and key
"""
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
@@ -137,8 +239,7 @@ class CompiledArtifact:
assert reader.is_finished()

torch.compiler.load_cache_artifacts(artifact_bytes)

cache_dir_ctx: AbstractContextManager[None] = nullcontext()
return key, nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
@@ -148,43 +249,105 @@ class CompiledArtifact:
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
return key, cache_dir_ctx

with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
path=path, format=format
)
return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)

result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)

assert result is not None
(entry, _) = result
class AOTCompiledArtifact(CompiledArtifact):
"""
Similar to CompiledArtifact, but the object is a single, bundled precompiled function.
This object is always a serializable callable function.

from .compile_fx import _CompileFxKwargs
This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which
is used by torch._dynamo.aot_compile for AOT Precompilation.
"""

fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")

context = torch._guards.TracingContext(
FakeTensorMode(shape_env=ShapeEnv())
)
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
def __init__(
self,
compiled_fn: Callable[..., Any],
):
self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
self._artifacts = (
None # We don't need artifacts, the inner object handles everything
)

@staticmethod
def from_bundled_callable(
bundled_fn: BundledAOTAutogradSerializableCallable,
) -> AOTCompiledArtifact:
return AOTCompiledArtifact(bundled_fn.compiled_fn)

def __call__(self, *args: Any) -> Any:
return self.inner_fn(*args)

def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
result_bytes = self.serialize()
from torch.utils._appending_byte_serializer import BytesWriter

from .codecache import torch_key

writer = BytesWriter()
writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
writer.write_bytes(torch_key())
writer.write_bytes(result_bytes)

from torch._inductor.codecache import write_atomic

# Save a sentinel file to indicate that this is AOT
write_atomic(path, writer.to_bytes())

def serialize(self) -> bytes:
return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
self.inner_fn
)

@staticmethod
def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
deserialized = (
BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
result_bytes
)
)
assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
return AOTCompiledArtifact.from_bundled_callable(deserialized)

@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader

from .codecache import torch_key

result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
assert header == AOTCompiledArtifact.AOT_HEADER
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)


def standalone_compile(
@@ -193,7 +356,11 @@ def standalone_compile(
*,
dynamic_shapes: Any,
options: Any,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Implementation of torch.inductor.standalone_compile
"""
from torch.compiler._cache import CacheArtifactManager

from .compile_fx import compile_fx
@@ -249,6 +416,7 @@ def standalone_compile(
torch._guards.tracing(context),
CacheArtifactManager.with_fresh_cache(),
config.patch("triton.autotune_at_compile_time", True),
torch._functorch.config.patch("bundled_autograd_cache", aot),
):
# compile_fx can mutate gm
gm = copy.deepcopy(gm)
@@ -256,7 +424,12 @@ def standalone_compile(
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)

if aot:
if not hasattr(compiled_fn, "serialize"):
raise RuntimeError(
"Compiled function should have serialize method when aot=True"
)
return AOTCompiledArtifact(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(
@@ -264,4 +437,4 @@ def standalone_compile(
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)

return CompiledArtifact(compiled_fn, artifacts)
return CacheCompiledArtifact(compiled_fn, artifacts)
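Both artifact formats above share one on-disk layout: a header tag, then torch_key(), then the payload. A condensed sketch of that round trip, not part of the diff (names are taken from the code above; the cache format additionally writes the cache key with write_str, which is omitted here):

    from torch.utils._appending_byte_serializer import BytesReader, BytesWriter
    from torch._inductor.codecache import torch_key

    def write_blob(header: bytes, payload: bytes) -> bytes:
        writer = BytesWriter()
        writer.write_bytes(header)       # AOT_HEADER or CACHE_HEADER
        writer.write_bytes(torch_key())  # rejects blobs from a different torch build
        writer.write_bytes(payload)
        return writer.to_bytes()

    def read_blob(blob: bytes) -> tuple[bytes, bytes]:
        reader = BytesReader(blob)
        header = reader.read_bytes()     # dispatch on this tag at load time
        assert reader.read_bytes() == torch_key()
        payload = reader.read_bytes()
        assert reader.is_finished()
        return header, payload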
@@ -15,7 +15,6 @@ from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
_detect_infra_mode,
_disable_infra_mode,
autograd_would_have_decomposed,
return_and_correct_aliasing,
TorchDispatchMode,
)
@@ -410,13 +409,8 @@ class FunctionalTensorMode(TorchDispatchMode):
return False
return True

# in normal torch.compile IR, we only decompose an op if autograd
# would have decomposed it (NB: autograd may have been skipped if
# we are in inference mode)
# TODO: the flatten here can potentially be deduped with the
# unwrapping pytree_map later
flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
return autograd_would_have_decomposed(func, flat_args_kwargs)
# in normal torch.compile IR, we decompose functional composite ops
return True

if (
func not in FunctionalTensor.metadata_fns
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/runtime/logging.h>

#include <c10/util/Exception.h>
#include <atomic>
#include <chrono>
#include <mutex>
@@ -33,7 +34,7 @@ int64_t LockingLogger::getCounterValue(const std::string& name) const {
return raw_counter.sum / raw_counter.count;
} break;
}
throw std::runtime_error("Unknown aggregation type!");
TORCH_CHECK(false, "Unknown aggregation type!");
}

void LockingLogger::setAggregationType(
@@ -11,6 +11,7 @@
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <limits>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>

namespace torch::jit {
@@ -112,20 +113,17 @@ void listRemove<at::Tensor>(Stack& stack) {
}

void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) {
if (t.requires_grad()) {
throw std::runtime_error(
"Cannot input a tensor that requires grad as a scalar argument");
}
if (!t.sizes().empty()) {
throw std::runtime_error(
"Cannot input a tensor of dimension other than 0 as a scalar argument");
}
if (toInt && !isIntegralType(t.scalar_type(), /*includeBool=*/false)) {
std::stringstream ss;
ss << "Cannot input a tensor of type " << t.scalar_type()
<< " as an integral argument";
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
!t.requires_grad(),
"Cannot input a tensor that requires grad as a scalar argument");
TORCH_CHECK(
t.sizes().empty(),
"Cannot input a tensor of dimension other than 0 as a scalar argument");
TORCH_CHECK(
!toInt || isIntegralType(t.scalar_type(), /*includeBool=*/false),
"Cannot input a tensor of type ",
t.scalar_type(),
" as an integral argument");
}

void checkDoubleInRange(double a) {
@@ -1,5 +1,6 @@
#include <ATen/autocast_mode.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
@@ -159,9 +160,8 @@ void sort_op(Stack& stack) {

if (!g_list.empty()) {
std::stringstream error_str;
if (!isSortableListOfObjectsOrTuples(g_list, error_str)) {
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
isSortableListOfObjectsOrTuples(g_list, error_str), error_str.str());

c10::IValueComparator comparator;
if (reverse) {
@@ -254,9 +254,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
int64_t lo = 0, hi = 0, step = 0;
pop(stack, lo, hi, step);
// error handling when step_val = 0 during runtime
if (step == 0) {
throw std::runtime_error("range() arg 3 must not be zero");
}
TORCH_CHECK(step != 0, "range() arg 3 must not be zero");
if (step > 0 && lo < hi) {
push(stack, 1 + (hi - 1 - lo) / step);
} else if (step < 0 && lo > hi) {
@@ -382,14 +380,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
int64_t val = static_cast<int64_t>(std::stoll(s->string(), &sz));
if (sz == s->string().size()) {
push(stack, val);
} else {
std::stringstream error_str;
error_str << "invalid literal for int() "
<< "with base 10: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"invalid literal for int() ",
"with base 10: '",
s->string(),
"'");
push(stack, val);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
@@ -436,14 +433,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
double b = std::stod(s->string(), &sz);
if (sz == s->string().size()) {
push(stack, b);
} else {
std::stringstream error_str;
error_str << "could not convert string "
<< "to float: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"could not convert string ",
"to float: '",
s->string(),
"'");
push(stack, b);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
@@ -1793,10 +1789,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
}

const std::string& separator = ivalue.toStringRef();

if (separator.empty()) {
throw std::runtime_error("ValueError: empty separator");
}
TORCH_CHECK(!separator.empty(), "ValueError: empty separator");

auto count = 0;

@@ -1919,11 +1912,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
if (string.size() > static_cast<std::string::size_type>(width)) {
push(stack, string);
return;
@@ -2092,9 +2083,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string substr = pop(stack).toStringRef();
std::string string = pop(stack).toStringRef();
auto result = stringFindImpl(string, substr, start, end);
if (result < 0) {
throw std::runtime_error("ValueError: substring not found");
}
TORCH_CHECK(result >= 0, "ValueError: substring not found");
push(stack, result);
},
aliasAnalysisFromSchema()),
@@ -2107,9 +2096,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string substr = pop(stack).toStringRef();
std::string string = pop(stack).toStringRef();
auto result = stringFindImpl(string, substr, start, end, true);
if (result < 0) {
throw std::runtime_error("ValueError: substring not found");
}
TORCH_CHECK(result >= 0, "ValueError: substring not found");
push(stack, result);
},
aliasAnalysisFromSchema()),
@@ -2183,11 +2170,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
auto to_append =
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));

@@ -2207,11 +2192,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
auto to_append =
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));

@@ -3358,10 +3341,8 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
int64_t a = 0, b = 0;
lldiv_t divresult = {};
pop(stack, a, b);
if (b == 0) {
throw std::runtime_error(
"ZeroDivisionError: integer division or modulo by zero");
}
TORCH_CHECK(
b != 0, "ZeroDivisionError: integer division or modulo by zero");
divresult = lldiv(a, b);
if (divresult.rem && (a < 0) != (b < 0)) {
divresult.quot -= 1;
@@ -3379,9 +3360,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
[](Stack& stack) {
double a = 0, b = 0;
pop(stack, a, b);
if (b == 0) {
throw std::runtime_error("ZeroDivisionError: float divmod()");
}
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()");
double rem = fmod(a, b);
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
if (rem && (a < 0) != (b < 0)) {
@@ -3426,9 +3405,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
type_a a; \
type_b b; \
pop(stack, a, b); \
if (b == 0) { \
throw std::runtime_error("ZeroDivisionError: float divmod()"); \
} \
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()"); \
double quot = floor(a / b); \
double rem = a - (quot * b); \
push(stack, quot, rem); \
@@ -466,14 +466,6 @@ at::Tensor LazyNativeFunctions::linalg_pinv(
linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::svd(
const at::Tensor& self,
bool some,
bool compute_uv) {
return at::functionalization::functionalize_aten_op<ATEN_OP(svd)>::call(
self, some, compute_uv);
}

// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy and mutations
// back to the inputs.
@@ -21,10 +21,6 @@ backends are ready, this list allows opt-in one at a time.
PRESERVED_ATEN_CIA_OPS = {
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
# NB: don't use the C++ decomp, because it is not functional!
torch.ops.aten.silu_backward.default,
torch.ops.aten.mish_backward.default,
torch.ops.aten._fused_rms_norm.default,
}
@@ -63,7 +63,6 @@ from torch.utils._python_dispatch import (
_disable_infra_mode,
_push_mode,
_unset_infra_mode,
autograd_would_have_decomposed,
TorchDispatchMode,
)
from torch.utils._stats import count
@@ -1033,16 +1032,11 @@ def proxy_call(
return r

# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
if (
not pre_dispatch
and func
not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]
and autograd_would_have_decomposed(func, flat_args_kwargs)
):
if not pre_dispatch and func not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]:
with proxy_mode:
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
@@ -43,6 +43,8 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix):


def _compile_submod(gm, prefix):
from torch._inductor.standalone_compile import AOTCompiledArtifact

for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith(prefix):
fake_inputs = []
@@ -56,13 +58,12 @@ def _compile_submod(gm, prefix):

submod = getattr(gm, node.target)

# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(
torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context"
)
compiled_fn = torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True
)

assert isinstance(compiled_fn, AOTCompiledArtifact)
# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(compiled_fn)
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
compiled_submod, args=node.args, kwargs=node.kwargs
@@ -63,15 +63,15 @@ struct dummy_int1_7_t {};
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(c10::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
_(c10::BFloat16, BFloat16) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn)

// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
@@ -81,19 +81,19 @@ struct dummy_int1_7_t {};
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(c10::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<c10::Half>, ComplexHalf) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
_(c10::BFloat16, BFloat16) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn) \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu)

// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
@@ -103,7 +103,7 @@ struct dummy_int1_7_t {};
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(c10::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
@@ -113,7 +113,7 @@ struct dummy_int1_7_t {};
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
@@ -176,24 +176,19 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE>, SCALARTYPE)

#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, SCALARTYPE2)

#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
@@ -203,53 +198,41 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND7( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
SCALARTYPE7, \
|
||||
_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE4>::t), \
|
||||
SCALARTYPE4) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE5>::t), \
|
||||
SCALARTYPE5) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE6>::t), \
|
||||
SCALARTYPE6) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE7>::t), \
|
||||
SCALARTYPE7)
|
||||
#define AT_FORALL_SCALAR_TYPES_AND7( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
SCALARTYPE7, \
|
||||
_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
|
||||
SCALARTYPE1) \
|
||||
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
|
||||
SCALARTYPE2) \
|
||||
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, \
|
||||
SCALARTYPE3) \
|
||||
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE4>, \
|
||||
SCALARTYPE4) \
|
||||
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE5>, \
|
||||
SCALARTYPE5) \
|
||||
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE6>, \
|
||||
SCALARTYPE6) \
|
||||
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE7>, SCALARTYPE7)
|
||||
|
||||
#define AT_FORALL_QINT_TYPES(_) \
|
||||
_(c10::qint8, QInt8) \
|
||||
@ -258,12 +241,12 @@ struct dummy_int1_7_t {};
|
||||
_(c10::quint4x2, QUInt4x2) \
|
||||
_(c10::quint2x4, QUInt2x4)
|
||||
|
||||
#define AT_FORALL_FLOAT8_TYPES(_) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
#define AT_FORALL_FLOAT8_TYPES(_) \
|
||||
_(c10::Float8_e5m2, Float8_e5m2) \
|
||||
_(c10::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(c10::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
#define AT_FORALL_COMPLEX_TYPES(_) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
@ -298,7 +281,12 @@ struct ScalarTypeToCPPType;
|
||||
/* can't pick between at::detail and at::cuda::detail. */ \
|
||||
/* For repro example, please see: */ \
|
||||
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
|
||||
/* TODO: remove once the bug is fixed. */ \
|
||||
/* UPDATE: while the CUDA bug is fixed, we cannot remove the */ \
|
||||
/* workaround as it is BC breaking. However, it is recommended to */ \
|
||||
/* update any code that contains */ \
|
||||
/* decltype(ScalarTypeToCPPType<T>::t) */ \
|
||||
/* with */ \
|
||||
/* ScalarTypeToCPPTypeT<T> */ \
|
||||
static type t; \
|
||||
};
|
||||
|
||||
|
@ -38,8 +38,7 @@ class BytesWriter:
digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
4, byteorder="big", signed=False
)
if len(digest) != CHECKSUM_DIGEST_SIZE:
raise AssertionError("Computed checksum digest has unexpected size")
assert len(digest) == CHECKSUM_DIGEST_SIZE
self._data[0:CHECKSUM_DIGEST_SIZE] = digest
return bytes(self._data)

@ -47,13 +46,11 @@ class BytesReader:
class BytesReader:
def __init__(self, data: bytes) -> None:
# Check for data corruption
if len(data) < CHECKSUM_DIGEST_SIZE:
raise AssertionError("Input data is too short to contain checksum")
assert len(data) >= CHECKSUM_DIGEST_SIZE
digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
4, byteorder="big", signed=False
)
if len(digest) != CHECKSUM_DIGEST_SIZE:
raise AssertionError("Computed checksum digest has unexpected size")
assert len(digest) == CHECKSUM_DIGEST_SIZE
if data[0:CHECKSUM_DIGEST_SIZE] != digest:
raise RuntimeError(
"Bytes object is corrupted, checksum does not match. "
@ -123,11 +120,7 @@ class AppendingByteSerializer(Generic[T]):
@staticmethod
def to_list(data: bytes, *, deserialize_fn: Callable[[BytesReader], T]) -> list[T]:
reader = BytesReader(data)
if reader.read_uint64() != _ENCODING_VERSION:
raise AssertionError(
f"Encoding version mismatch in AppendingByteSerializer.to_list, \
got {reader.read_uint64()}"
)
assert reader.read_uint64() == _ENCODING_VERSION

result: list[T] = []
while not reader.is_finished():
@ -85,16 +85,12 @@ class _Config(Generic[T]):
|
||||
)
|
||||
|
||||
if self.alias is not None:
|
||||
if (
|
||||
self.default is not _UNSET_SENTINEL
|
||||
or self.justknob is not None
|
||||
or self.env_name_default is not None
|
||||
or self.env_name_force is not None
|
||||
):
|
||||
raise AssertionError(
|
||||
"if alias is set, none of {default, justknob, \
|
||||
env_name_default and env_name_force} can be set"
|
||||
)
|
||||
assert (
|
||||
self.default is _UNSET_SENTINEL
|
||||
and self.justknob is None
|
||||
and self.env_name_default is None
|
||||
and self.env_name_force is None
|
||||
), "if alias is set, none of {default, justknob and env var} can be set"
|
||||
|
||||
@staticmethod
|
||||
def string_or_list_of_string_to_list(
|
||||
@ -104,8 +100,7 @@ class _Config(Generic[T]):
|
||||
return None
|
||||
if isinstance(val, str):
|
||||
return [val]
|
||||
if not isinstance(val, list):
|
||||
raise AssertionError(f"val is not a list, got {type(val)}")
|
||||
assert isinstance(val, list)
|
||||
return val
|
||||
|
||||
|
||||
@ -198,10 +193,7 @@ def install_config_module(module: ModuleType) -> None:
|
||||
if dest is module:
|
||||
delattr(module, key)
|
||||
elif isinstance(value, type):
|
||||
if value.__module__ != module.__name__:
|
||||
raise AssertionError(
|
||||
f"subconfig class {value} must be defined in module {module.__name__}"
|
||||
)
|
||||
assert value.__module__ == module.__name__
|
||||
# a subconfig with `class Blah:` syntax
|
||||
proxy = SubConfigProxy(module, f"{name}.")
|
||||
visit(value, proxy, f"{name}.")
|
||||
@ -242,8 +234,10 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str
|
||||
prev_name = ""
|
||||
maybe_current = token.string.strip()
|
||||
if COMPILE_IGNORED_MARKER in maybe_current:
|
||||
if current_comment != ("", -1):
|
||||
raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
|
||||
assert current_comment == (
|
||||
"",
|
||||
-1,
|
||||
), f"unconsumed {COMPILE_IGNORED_MARKER}"
|
||||
current_comment = maybe_current, token.start[0]
|
||||
elif token.type == tokenize.NAME:
|
||||
# Only accept the first name token, to handle if you have
|
||||
@ -260,8 +254,7 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str
|
||||
assignments.add(prev_name)
|
||||
current_comment = "", -1 # reset
|
||||
prev_name = ""
|
||||
if current_comment != ("", -1):
|
||||
raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
|
||||
assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}"
|
||||
return assignments
|
||||
|
||||
|
||||
@ -313,22 +306,20 @@ class _ConfigEntry:
|
||||
|
||||
# Ensure justknobs and envvars are allowlisted types
|
||||
if self.justknob is not None and self.default is not None:
|
||||
if not isinstance(self.default, bool):
|
||||
raise AssertionError(
|
||||
f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
)
|
||||
assert isinstance(self.default, bool), (
|
||||
f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
)
|
||||
if self.value_type is not None and (
|
||||
config.env_name_default is not None or config.env_name_force is not None
|
||||
):
|
||||
if self.value_type not in (
|
||||
assert self.value_type in (
|
||||
bool,
|
||||
str,
|
||||
Optional[bool],
|
||||
Optional[str],
|
||||
):
|
||||
raise AssertionError(
|
||||
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
|
||||
)
|
||||
), (
|
||||
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
|
||||
)
|
||||
|
||||
|
||||
class ConfigModule(ModuleType):
|
||||
@ -426,10 +417,7 @@ class ConfigModule(ModuleType):
|
||||
|
||||
def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None:
|
||||
data = self._get_alias_module_and_name(entry)
|
||||
if data is None:
|
||||
raise AssertionError(
|
||||
"alias data should not be None when setting alias value"
|
||||
)
|
||||
assert data is not None
|
||||
module, constant_name = data
|
||||
setattr(module, constant_name, val)
|
||||
|
||||
@ -654,32 +642,19 @@ class ConfigModule(ModuleType):
|
||||
changes: dict[str, Any]
|
||||
if arg1 is not None:
|
||||
if arg2 is not None:
|
||||
if not isinstance(arg1, str):
|
||||
raise AssertionError(
|
||||
"first argument must be a string when passing 2 positional args to patch"
|
||||
)
|
||||
assert isinstance(arg1, str)
|
||||
# patch("key", True) syntax
|
||||
changes = {arg1: arg2}
|
||||
else:
|
||||
if not isinstance(arg1, dict):
|
||||
raise AssertionError(
|
||||
"first argument must be a dict when passing a single positional arg to patch"
|
||||
)
|
||||
assert isinstance(arg1, dict)
|
||||
# patch({"key": True}) syntax
|
||||
changes = arg1
|
||||
if kwargs:
|
||||
raise AssertionError(
|
||||
"cannot pass both positional and keyword arguments to patch"
|
||||
)
|
||||
assert not kwargs
|
||||
else:
|
||||
# patch(key=True) syntax
|
||||
changes = kwargs
|
||||
if arg2 is not None:
|
||||
raise AssertionError(
|
||||
"second positional argument is only valid when first argument is a key string"
|
||||
)
|
||||
if not isinstance(changes, dict):
|
||||
raise AssertionError(f"expected `dict` got {type(changes)}")
|
||||
assert arg2 is None
|
||||
assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
|
||||
prior: dict[str, Any] = {}
|
||||
config = self
|
||||
|
||||
@ -688,10 +663,7 @@ class ConfigModule(ModuleType):
|
||||
self.changes = changes
|
||||
|
||||
def __enter__(self) -> None:
|
||||
if prior:
|
||||
raise AssertionError(
|
||||
"prior should be empty when entering ConfigPatch"
|
||||
)
|
||||
assert not prior
|
||||
for key in self.changes.keys():
|
||||
# KeyError on invalid entry
|
||||
prior[key] = config.__getattr__(key)
|
||||
|
@ -21,8 +21,7 @@ This file should be imported into any file that uses install_config_module like
Note that the import should happen before the call to install_config_module(), otherwise runtime errors may occur.
"""

if not TYPE_CHECKING: # noqa: PYI002
raise AssertionError("Do not use at runtime") # noqa: W291
assert TYPE_CHECKING, "Do not use at runtime"

def save_config() -> bytes: ...
def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...

@ -217,10 +217,7 @@ class ContentStoreReader:
weights_only=True,
map_location=device,
)._untyped_storage
if s is None:
raise AssertionError(
f"expected storage for hash {h} in {os.path.join(self.loc, 'storages')}, got None"
)
assert s is not None
if self.storage_cache is not None:
self.storage_cache[device][h] = StorageWeakRef(s)
return s

@ -86,14 +86,13 @@ def context_decorator(ctx, func):
be a multi-shot context manager that can be directly invoked multiple times)
or a callable that produces a context manager.
"""
if callable(ctx) and hasattr(ctx, "__enter__"):
raise AssertionError(
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
"context_decorator(lambda: ctx()); if you intended to pass a context "
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
)
assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
"context_decorator(lambda: ctx()); if you intended to pass a context "
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
)

if not callable(ctx):

@ -933,10 +933,7 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
if not _is_pytreespec_instance(treespec):
raise AssertionError(
f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}"
)
assert _is_pytreespec_instance(treespec)
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)

@ -87,18 +87,12 @@ class DeviceContext(TorchFunctionMode):
# or else someone else has popped it!
for _ in range(_len_torch_function_stack() - 1):
mode = _pop_mode()
if isinstance(mode, DeviceContext):
raise AssertionError(
"Found nested DeviceContext on the mode stack where none expected"
)
assert not isinstance(mode, DeviceContext)
cur_stack.append(mode)

if _len_torch_function_stack() > 0:
mode = _pop_mode()
if not isinstance(mode, DeviceContext):
raise AssertionError(
"Expected a DeviceContext at the bottom of the mode stack"
)
assert isinstance(mode, DeviceContext)

for mode in reversed(cur_stack):
_push_mode(mode)

@ -31,8 +31,7 @@ def cache_method(

@functools.wraps(f)
def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T:
if kwargs:
raise AssertionError("cache_method does not accept keyword arguments")
assert not kwargs
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)
@ -1,12 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import warnings
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
|
||||
from typing import Optional, overload, Protocol, Union
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch
|
||||
@ -21,10 +20,6 @@ from torch._C import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
|
||||
# - We need a better user-facing api for _DisableTorchDispatch that
|
||||
# is able to selectively disable __torch_dispatch__ of a particular class.
|
||||
@ -88,8 +83,7 @@ class TorchDispatchMode:
|
||||
|
||||
def __init__(self, _dispatch_key=None):
|
||||
if _dispatch_key is not None:
|
||||
if not isinstance(_dispatch_key, torch._C.DispatchKey):
|
||||
raise AssertionError("_dispatch_key must be a torch._C.DispatchKey")
|
||||
assert isinstance(_dispatch_key, torch._C.DispatchKey)
|
||||
self.__dict__["_dispatch_key"] = _dispatch_key
|
||||
|
||||
self.old_dispatch_mode_flags: deque[bool] = deque()
|
||||
@ -218,24 +212,16 @@ def _get_current_dispatch_mode() -> Optional[TorchDispatchMode]:
|
||||
|
||||
|
||||
def _detect_infra_mode(key):
|
||||
if key not in (
|
||||
assert key in [
|
||||
torch._C._TorchDispatchModeKey.FUNCTIONAL,
|
||||
torch._C._TorchDispatchModeKey.PROXY,
|
||||
):
|
||||
raise AssertionError(
|
||||
f"key must be either FUNCTIONAL ({torch._C._TorchDispatchModeKey.FUNCTIONAL}) \
|
||||
or PROXY ({torch._C._TorchDispatchModeKey.PROXY}) _TorchDispatchModeKey, \
|
||||
got {key}"
|
||||
)
|
||||
]
|
||||
from torch._ops import _get_dispatch_mode_pre_dispatch
|
||||
|
||||
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
|
||||
post_dispatch_mode = torch._C._get_dispatch_mode(key)
|
||||
|
||||
if pre_dispatch_mode is not None and post_dispatch_mode is not None:
|
||||
raise AssertionError(
|
||||
"At most one of pre_dispatch_mode and post_dispatch_mode may be active"
|
||||
)
|
||||
assert (pre_dispatch_mode is None) or (post_dispatch_mode is None)
|
||||
|
||||
if pre_dispatch_mode is None:
|
||||
return post_dispatch_mode
|
||||
@ -261,13 +247,10 @@ def _unset_infra_mode(key):
|
||||
|
||||
|
||||
def _disable_infra_mode(key):
|
||||
if key not in (
|
||||
assert key in (
|
||||
torch._C._TorchDispatchModeKey.FUNCTIONAL,
|
||||
torch._C._TorchDispatchModeKey.PROXY,
|
||||
):
|
||||
raise AssertionError(
|
||||
"key must be either FUNCTIONAL or PROXY _TorchDispatchModeKey"
|
||||
)
|
||||
)
|
||||
mode_unset = _unset_infra_mode(key)
|
||||
try:
|
||||
yield mode_unset
|
||||
@ -288,10 +271,7 @@ def _get_current_dispatch_mode_stack() -> list[TorchDispatchMode]:
|
||||
|
||||
def _push_mode(mode: TorchDispatchMode):
|
||||
k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None
|
||||
if k is not None and k != torch._C.DispatchKey.PreDispatch:
|
||||
raise AssertionError(
|
||||
"mode._dispatch_key must be None or DispatchKey.PreDispatch"
|
||||
)
|
||||
assert k is None or k == torch._C.DispatchKey.PreDispatch
|
||||
if k is None:
|
||||
_push_on_torch_dispatch_stack(mode)
|
||||
return
|
||||
@ -434,7 +414,7 @@ class TensorWithFlatten(Protocol):
|
||||
@overload
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch._prims_common.DeviceLikeType] = None,
|
||||
device: Optional["torch._prims_common.DeviceLikeType"] = None,
|
||||
dtype: Optional[torch.types._dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
copy: bool = False,
|
||||
@ -529,16 +509,14 @@ def transform_subclass(t, callback, outer_size=None, outer_stride=None):
|
||||
# NB: Purposefully guard here to simplify the inner / outer symbols.
|
||||
# Using sym_eq() for symbolic comparison can result in an expression that's too
|
||||
# difficult to guard on, so we use == here.
|
||||
if sub.shape != outer_size:
|
||||
raise AssertionError(
|
||||
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
|
||||
f"shape equal to {outer_size}, but got: {sub.shape}"
|
||||
)
|
||||
if sub.stride() != outer_stride:
|
||||
raise AssertionError(
|
||||
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
|
||||
f"stride equal to {outer_stride}, but got: {sub.stride()}"
|
||||
)
|
||||
assert sub.shape == outer_size, (
|
||||
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
|
||||
f"shape equal to {outer_size}, but got: {sub.shape}"
|
||||
)
|
||||
assert sub.stride() == outer_stride, (
|
||||
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
|
||||
f"stride equal to {outer_stride}, but got: {sub.stride()}"
|
||||
)
|
||||
|
||||
return sub
|
||||
|
||||
@ -555,12 +533,9 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
||||
It does this by unsafely overwriting the storage field of the output tensor
|
||||
to be the same storage as the input.
|
||||
"""
|
||||
if not isinstance(func, torch._ops.OpOverload):
|
||||
raise AssertionError(f"func must be an OpOverload, got {type(args)}")
|
||||
if not isinstance(args, tuple):
|
||||
raise AssertionError(f"args must be a tuple, got {type(args)}")
|
||||
if not isinstance(outs, (list, tuple)):
|
||||
raise AssertionError(f"outs must be a list or tuple, got {type(args)}")
|
||||
assert isinstance(func, torch._ops.OpOverload)
|
||||
assert isinstance(args, tuple)
|
||||
assert isinstance(outs, (list, tuple))
|
||||
|
||||
def alias_non_inplace_storage(arg, ret):
|
||||
# This is hopefully a reasonable assert:
|
||||
@ -581,11 +556,10 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
||||
):
|
||||
ret_list = ret if isinstance(ret, list) else [ret]
|
||||
for r in ret_list:
|
||||
if type(arg) is not type(r):
|
||||
raise AssertionError(
|
||||
f"Called {str(func)} with input of type {type(arg)}\n"
|
||||
f"and output of type {type(ret)}. But expected types to match."
|
||||
)
|
||||
assert type(arg) is type(
|
||||
r
|
||||
), f"""Called {str(func)} with input of type {type(arg)}
|
||||
and output of type {type(ret)}. But expected types to match."""
|
||||
# Need to call a non-dispatcher helper, because we explicitly do **not**
|
||||
# want our subclass to intercept the set_() call.
|
||||
# instead, our subclass should directly have its storage swapped out.
|
||||
@ -601,8 +575,7 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
||||
for r in ret:
|
||||
torch._functionalize_unsafe_set(r, arg)
|
||||
else:
|
||||
if not isinstance(ret, torch.Tensor):
|
||||
raise AssertionError(f"expected torch.Tensor, got {type(ret)}")
|
||||
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
|
||||
torch._functionalize_unsafe_set(ret, arg)
|
||||
|
||||
for arg_idx, schema_arg in enumerate(schema_info.args):
|
||||
@ -646,10 +619,7 @@ def get_alias_info(func) -> SchemaInfo:
|
||||
# properly for some ops that output tensorlists)
|
||||
if func.namespace == "aten":
|
||||
torchgen_schema_str = str(func._schema)
|
||||
if not torchgen_schema_str.startswith("aten::"):
|
||||
raise AssertionError(
|
||||
"Expected torchgen schema string to start with 'aten::'"
|
||||
)
|
||||
assert torchgen_schema_str.startswith("aten::")
|
||||
# remove the aten:: namespace, which is added by the torchscript parser,
|
||||
# and torchgen doesn't know how to handle
|
||||
torchgen_schema_str = torchgen_schema_str[6:]
|
||||
@ -712,64 +682,6 @@ def get_alias_info(func) -> SchemaInfo:
|
||||
return schema_info
|
||||
|
||||
|
||||
def autograd_would_have_decomposed(
|
||||
func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
|
||||
) -> bool:
|
||||
"""
|
||||
Suppose that an operator has CompositeImplicitAutograd decomp registered.
|
||||
Would autograd have used this decomposition? It will only use it if there
|
||||
isn't an explicit backend registration for the device as well. This function
|
||||
will tell if this would have occurred.
|
||||
|
||||
Why do we need to apply these decompositions later? When inference mode is
|
||||
on, the autograd key is bypassed entirely, so a lower level mode cannot rely
|
||||
on the decomposition have been applied. It's easy to accidentally never apply
|
||||
the decomposition, resulting in an operator showing up in a graph that
|
||||
is unexpected.
|
||||
|
||||
Why do we need to AVOID applying the decomposition when autograd wouldn't
|
||||
have decomposed? If autograd doesn't decompose, this means in eager mode
|
||||
we would have run the fused kernel. It must be possible to trace this
|
||||
fused kernel directly into the graph for fidelity with eager (NB: a user
|
||||
has the option of then further decomposing at proxy tensor mode via
|
||||
decomposition table, but we must preserve it to proxy mode to have the
|
||||
choice.)
|
||||
|
||||
Why does functionalization need to also perform the test here? This is
|
||||
because some CompositeImplicitAutograd decompositions are not functional.
|
||||
If we are eventually going to decompose, we need to do this while we can
|
||||
still turn functionalization back on, so those decompositions get functionalized.
|
||||
So an early decomposition in functionalization may still be necessary. Note that
|
||||
if proxy tensor decomposition process could turn functionalization back on, this
|
||||
wouldn't be necessary, and maybe that is a useful thing to do anyway because
|
||||
the decomposition table is user specified and a user could violate the functional
|
||||
decomp requirement with a bad decomp. If this happened, then you could always
|
||||
pass through functionalization.
|
||||
"""
|
||||
has_backend_registration = False
|
||||
for a in flat_args:
|
||||
if isinstance(a, torch.Tensor):
|
||||
backend_key = torch._C._parse_dispatch_key(
|
||||
torch._C._dispatch_key_for_device(a.device.type)
|
||||
)
|
||||
if backend_key is None:
|
||||
raise AssertionError(
|
||||
"Failed to infer backend dispatch key from tensor device"
|
||||
)
|
||||
# TODO: use func.has_kernel_for_dispatch_key(backend_key)
|
||||
# but this one checks py_impl and CompositeImplicitAutograd
|
||||
# incorrectly shows up as has backend reg here
|
||||
has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
|
||||
func.name(), backend_key
|
||||
)
|
||||
|
||||
# in theory we should take all backend keys and take the highest priority one
|
||||
# to properly mimic the dispatcher,
|
||||
# this just grabs the first tensor and takes its device key
|
||||
break
|
||||
return not has_backend_registration
|
||||
|
||||
|
||||
# See NOTE[SchemaInfo int_tags] above.
|
||||
_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload]
|
||||
|
||||
@ -799,8 +711,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||
if not alias_set or not x.is_write:
|
||||
return None
|
||||
# torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
|
||||
if len(alias_set) != 1:
|
||||
raise AssertionError("Expected alias_set to contain exactly one element")
|
||||
assert len(alias_set) == 1
|
||||
# timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
|
||||
# set of size 1 on Python 3.13.
|
||||
return next(iter(alias_set))
|
||||
@ -814,10 +725,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||
i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set
|
||||
]
|
||||
# For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments.
|
||||
if len(arg_indices) != 1:
|
||||
raise AssertionError(
|
||||
"Expected exactly one argument index for the given output alias"
|
||||
)
|
||||
assert len(arg_indices) == 1
|
||||
idx = arg_indices[0]
|
||||
arg_info = schema_info.args[idx]
|
||||
if arg_info.name is not None and arg_info.name in new_kwargs:
|
||||
@ -843,10 +751,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||
]
|
||||
# Assumption: we have a very small number of inplace_view ops that follow a strict schema:
|
||||
# there is only a single argument that gets its metadata mutated.
|
||||
if len(mutated_args) != 1:
|
||||
raise AssertionError(
|
||||
"expected exactly one mutated arg for inplace_view ops"
|
||||
)
|
||||
assert len(mutated_args) == 1
|
||||
# This check exists because we generally *do* want to update the metadata of any wrapper subclasses,
|
||||
# but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor.
|
||||
# so we don't actually need to update the metadata (and attempting to do so causes errors)
|
||||
|
@ -476,8 +476,7 @@ def _is_constant_holder(spec: "TreeSpec") -> bool:
|
||||
|
||||
def _retrieve_constant(spec: "TreeSpec") -> Any:
|
||||
"""Given a spec from a pytree registered with register_constant, retrieves the constant"""
|
||||
if not _is_constant_holder(spec):
|
||||
raise AssertionError("spec does not correspond to a registered constant pytree")
|
||||
assert _is_constant_holder(spec)
|
||||
return tree_unflatten([], spec)
|
||||
|
||||
|
||||
@ -900,25 +899,17 @@ def _defaultdict_serialize(context: Context) -> DumpableContext:
|
||||
|
||||
|
||||
def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
|
||||
if not isinstance(dumpable_context, dict):
|
||||
raise AssertionError("dumpable_context must be a dict")
|
||||
|
||||
expected_keys = {
|
||||
assert isinstance(dumpable_context, dict)
|
||||
assert set(dumpable_context) == {
|
||||
"default_factory_module",
|
||||
"default_factory_name",
|
||||
"dict_context",
|
||||
}
|
||||
if set(dumpable_context) != expected_keys:
|
||||
raise AssertionError(
|
||||
f"dumpable_context keys must be {expected_keys}, got {set(dumpable_context)}"
|
||||
)
|
||||
|
||||
default_factory_module = dumpable_context["default_factory_module"]
|
||||
default_factory_name = dumpable_context["default_factory_name"]
|
||||
if not isinstance(default_factory_module, str):
|
||||
raise AssertionError("default_factory_module must be a string")
|
||||
if not isinstance(default_factory_name, str):
|
||||
raise AssertionError("default_factory_name must be a string")
|
||||
assert isinstance(default_factory_module, str)
|
||||
assert isinstance(default_factory_name, str)
|
||||
module = importlib.import_module(default_factory_module)
|
||||
default_factory = getattr(module, default_factory_name)
|
||||
|
||||
@ -1742,8 +1733,7 @@ def _broadcast_to_and_flatten(
|
||||
treespec: TreeSpec,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Optional[list[Any]]:
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise AssertionError("treespec must be a TreeSpec")
|
||||
assert isinstance(treespec, TreeSpec)
|
||||
|
||||
if tree_is_leaf(tree, is_leaf=is_leaf):
|
||||
return [tree] * treespec.num_leaves
|
||||
|
@ -206,8 +206,7 @@ class CapturedTraceback:
import torch._C._profiler

if script or cpp:
if skip != 0:
raise AssertionError("skip with script/cpp NYI")
assert skip == 0, "skip with script/cpp NYI"

return CapturedTraceback(
torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),

@ -430,8 +430,9 @@ def _get_custom_mod_func(func_name: str):
it is marked as private. It is a convenience function for backend implementers to
more easily call the hooks into their backend extensions.
"""
if not isinstance(func_name, str):
raise AssertionError(f"func_name must be `str`, but got `{type(func_name)}`.")
assert isinstance(func_name, str), (
f"func_name must be `str`, but got `{type(func_name)}`."
)
backend_name = _get_privateuse1_backend_name()
custom_device_mod = getattr(torch, backend_name, None)
function = getattr(custom_device_mod, func_name, None)

@ -119,12 +119,10 @@ def bundle_inputs(
# Fortunately there is a function in _recursive that does exactly that conversion.
cloned_module = wrap_cpp_module(clone)
if isinstance(inputs, dict):
if not isinstance(info, dict) and info is not None:
raise AssertionError("If inputs is a dict, info must be a dict or None")
assert isinstance(info, dict) or info is None
augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
else:
if not isinstance(info, list) and info is not None:
raise AssertionError("If inputs is a list, info must be a list or None")
assert isinstance(info, list) or info is None
augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
return cloned_module
|
||||
|
@ -1034,10 +1034,8 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
|
||||
out += f"{line['filename']}:{line['line']}:{line['name']}\n"
|
||||
out += "\n\n"
|
||||
return out
|
||||
if capture_logs_fwd.logs is None:
|
||||
raise AssertionError("capture_logs_fwd.logs is None")
|
||||
if capture_logs_recompute.logs is None:
|
||||
raise AssertionError("capture_logs_recompute.logs is None")
|
||||
assert capture_logs_fwd.logs is not None
|
||||
assert capture_logs_recompute.logs is not None
|
||||
raise CheckpointError(
|
||||
_checkpoint_error_template.format(
|
||||
forward_traces=get_str_tb("original", capture_logs_fwd),
|
||||
@ -1075,14 +1073,12 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
|
||||
def pack_hook(x):
|
||||
x = x.detach() if x.requires_grad else x
|
||||
target_frame = target_frame_ref()
|
||||
if target_frame is None:
|
||||
raise AssertionError("Internal error: target_frame reference is None")
|
||||
assert target_frame is not None # appease mypy
|
||||
recomp_idx = target_frame.recomp_counter[gid]
|
||||
target_frame.recomp_counter[gid] += 1
|
||||
|
||||
if recomp_idx >= len(target_frame.weak_holders):
|
||||
if target_frame.early_stop:
|
||||
raise AssertionError("Unexpected state: target_frame.early_stop is set")
|
||||
assert not target_frame.early_stop
|
||||
if not target_frame.forward_completed:
|
||||
# We run into this case when early stop is not enabled and do
|
||||
# grad within checkpoint.
|
||||
@ -1519,14 +1515,12 @@ def _checkpoint_without_reentrant_generator(
|
||||
device_module = _get_device_module(device_type)
|
||||
forward_context, recompute_context = context_fn()
|
||||
if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
|
||||
if (
|
||||
not isinstance(forward_context, TorchDispatchMode)
|
||||
or not isinstance(recompute_context, TorchDispatchMode)
|
||||
):
|
||||
raise AssertionError(
|
||||
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` "
|
||||
"must generate a tuple of two `TorchDispatchMode`s."
|
||||
)
|
||||
assert (
|
||||
isinstance(forward_context, TorchDispatchMode) and
|
||||
isinstance(recompute_context, TorchDispatchMode)
|
||||
), \
|
||||
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
|
||||
"must generate a tuple of two `TorchDispatchMode`s."
|
||||
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
|
||||
device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)
|
||||
|
||||
|
@ -290,8 +290,7 @@ def _get_icpx_version() -> str:
|
||||
match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip())
|
||||
version = ['0', '0', '0'] if match is None else list(match.groups())
|
||||
version = list(map(int, version))
|
||||
if len(version) != 3:
|
||||
raise AssertionError("Failed to parse DPC++ compiler version")
|
||||
assert len(version) == 3, "Failed to parse DPC++ compiler version"
|
||||
# Aligning version format with what torch.version.xpu() returns
|
||||
return f"{version[0]}{version[1]:02}{version[2]:02}"
|
||||
|
||||
@ -325,8 +324,7 @@ def _get_sycl_device_flags(cflags):
|
||||
# We need last occurrence of -fsycl-targets as it will be the one taking effect.
|
||||
# So searching in reversed list.
|
||||
flags = [f for f in reversed(cflags) if f.startswith('-fsycl-targets=')]
|
||||
if not flags:
|
||||
raise AssertionError("bug: -fsycl-targets should have been amended to cflags")
|
||||
assert flags, "bug: -fsycl-targets should have been amended to cflags"
|
||||
|
||||
arch_list = _get_sycl_arch_list()
|
||||
if arch_list != '':
|
||||
@ -664,8 +662,7 @@ class BuildExtension(build_ext):
|
||||
extension = next(extension_iter, None)
|
||||
|
||||
if sycl_ext:
|
||||
if not self.use_ninja:
|
||||
raise AssertionError("ninja is required to build sycl extensions.")
|
||||
assert self.use_ninja, "ninja is required to build sycl extensions."
|
||||
|
||||
if cuda_ext and not IS_HIP_EXTENSION:
|
||||
_check_cuda_version(compiler_name, compiler_version)
|
||||
@ -697,10 +694,7 @@ class BuildExtension(build_ext):
|
||||
self._define_torch_extension_name(extension)
|
||||
|
||||
if 'nvcc_dlink' in extension.extra_compile_args:
|
||||
if not self.use_ninja:
|
||||
raise AssertionError(
|
||||
f"With dlink=True, ninja is required to build cuda extension {extension.name}."
|
||||
)
|
||||
assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
|
||||
|
||||
# Register .cu, .cuh, .hip, .mm and .sycl as valid source extensions.
|
||||
# NOTE: At the moment .sycl is not a standard extension for SYCL supported
|
||||
@ -2659,11 +2653,9 @@ def _import_module_from_library(module_name, path, is_python_module):
|
||||
if is_python_module:
|
||||
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
|
||||
spec = importlib.util.spec_from_file_location(module_name, filepath)
|
||||
if spec is None:
|
||||
raise AssertionError(f"Failed to create spec for module {module_name} at {filepath}")
|
||||
assert spec is not None
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
if not isinstance(spec.loader, importlib.abc.Loader):
|
||||
raise AssertionError("spec.loader is not a valid importlib Loader")
|
||||
assert isinstance(spec.loader, importlib.abc.Loader)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
else:
|
||||
@ -2865,10 +2857,8 @@ e.
|
||||
ldflags = sanitize_flags(ldflags)
|
||||
|
||||
# Sanity checks...
|
||||
if len(sources) != len(objects):
|
||||
raise AssertionError("sources and objects lists must be the same length")
|
||||
if len(sources) == 0:
|
||||
raise AssertionError("At least one source is required to build a library")
|
||||
assert len(sources) == len(objects)
|
||||
assert len(sources) > 0
|
||||
|
||||
compiler = get_cxx_compiler()
|
||||
|
||||
|
@ -133,8 +133,9 @@ def from_dlpack(
if device is not None:
if isinstance(device, str):
device = torch.device(device)
if not isinstance(device, torch.device):
raise AssertionError(f"from_dlpack: unsupported device type: {type(device)}")
assert isinstance(device, torch.device), (
f"from_dlpack: unsupported device type: {type(device)}"
)
kwargs["dl_device"] = torch._C._torchDeviceToDLDevice(device)

ext_device = ext_tensor.__dlpack_device__()
@ -162,10 +163,10 @@ def from_dlpack(
dlpack = ext_tensor.__dlpack__(**kwargs)

else:
if device is not None or copy is not None:
raise AssertionError(
"device and copy kwargs not supported when ext_tensor is already a DLPack capsule."
)
assert device is None and copy is None, (
"device and copy kwargs not supported when ext_tensor is "
"already a DLPack capsule."
)
# Old versions just call the converter
dlpack = ext_tensor
return torch._C._from_dlpack(dlpack)
@ -62,8 +62,7 @@ def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
# Inputs contains the shapes of two matrices.
m, k = a_shape
k2, n = b_shape
if k != k2:
raise AssertionError(f"matmul: inner dimensions must match (k == k2), got {k} and {k2}")
assert k == k2
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
return m * n * 2 * k

@ -79,10 +78,8 @@ def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
# Inputs contains the shapes of two tensor.
b, m, k = a_shape
b2, k2, n = b_shape
if b != b2:
raise AssertionError(f"bmm: batch dimensions must match (b == b2), got {b} and {b2}")
if k != k2:
raise AssertionError(f"bmm: inner dimensions must match (k == k2), got {k} and {k2}")
assert b == b2
assert k == k2
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
flop = b * m * n * 2 * k
return flop
@ -269,8 +266,7 @@ def sdpa_flop_count(query_shape, key_shape, value_shape):
|
||||
b, h, s_q, d_q = query_shape
|
||||
_b2, _h2, s_k, _d2 = key_shape
|
||||
_b3, _h3, _s3, d_v = value_shape
|
||||
if not b == _b2 == _b3 or not h == _h2 == _h3 or not d_q == _d2 or not s_k == _s3 or not d_q == _d2:
|
||||
raise AssertionError("sdpa_flop_count: query/key/value shapes are incompatible")
|
||||
assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2
|
||||
total_flops = 0
|
||||
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
|
||||
total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
|
||||
@ -324,21 +320,15 @@ def _unpack_flash_attention_nested_shapes(
|
||||
# In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
|
||||
# To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
|
||||
# So the flops calculation in this case is an overestimate of the actual flops.
|
||||
if len(key.shape) != 3:
|
||||
raise AssertionError("sdpa_flop_count: expected key.shape to be 3-dimensional")
|
||||
if len(value.shape) != 3:
|
||||
raise AssertionError("sdpa_flop_count: expected value.shape to be 3-dimensional")
|
||||
if grad_out is not None and grad_out.shape != query.shape:
|
||||
raise AssertionError("sdpa_flop_count: grad_out.shape must match query.shape when provided")
|
||||
assert len(key.shape) == 3
|
||||
assert len(value.shape) == 3
|
||||
assert grad_out is None or grad_out.shape == query.shape
|
||||
_, h_q, d_q = query.shape
|
||||
_, h_k, d_k = key.shape
|
||||
_, h_v, d_v = value.shape
|
||||
if cum_seq_q is None:
|
||||
raise AssertionError("sdpa_flop_count: cum_seq_q must not be None")
|
||||
if cum_seq_k is None:
|
||||
raise AssertionError("sdpa_flop_count: cum_seq_k must not be None")
|
||||
if cum_seq_q.shape != cum_seq_k.shape:
|
||||
raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape")
|
||||
assert cum_seq_q is not None
|
||||
assert cum_seq_k is not None
|
||||
assert cum_seq_q.shape == cum_seq_k.shape
|
||||
seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
|
||||
seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
|
||||
for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
|
||||
@ -378,22 +368,15 @@ def _unpack_efficient_attention_nested_shapes(
|
||||
# In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
|
||||
# To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
|
||||
# So the flops calculation in this case is an overestimate of the actual flops.
|
||||
if len(key.shape) != 4:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: expected key.shape to be 4-dimensional")
|
||||
if len(value.shape) != 4:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: expected value.shape to be 4-dimensional")
|
||||
if grad_out is not None and grad_out.shape != query.shape:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: grad_out.shape must match query.shape when provided")
|
||||
assert len(key.shape) == 4
|
||||
assert len(value.shape) == 4
|
||||
assert grad_out is None or grad_out.shape == query.shape
|
||||
_, _, h_q, d_q = query.shape
|
||||
_, _, h_k, d_k = key.shape
|
||||
_, _, h_v, d_v = value.shape
|
||||
if cu_seqlens_q is None:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_q must not be None")
|
||||
if cu_seqlens_k is None:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_k must not be None")
|
||||
if cu_seqlens_q.shape != cu_seqlens_k.shape:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: "
|
||||
"cu_seqlens_q and cu_seqlens_k must have the same shape")
|
||||
assert cu_seqlens_q is not None
|
||||
assert cu_seqlens_k is not None
|
||||
assert cu_seqlens_q.shape == cu_seqlens_k.shape
|
||||
seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
|
||||
seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
|
||||
for len_q, len_k in zip(seqlens_q, seqlens_k):
|
||||
@ -477,10 +460,8 @@ def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape
|
||||
_b2, _h2, s_k, _d2 = key_shape
|
||||
_b3, _h3, _s3, d_v = value_shape
|
||||
_b4, _h4, _s4, _d4 = grad_out_shape
|
||||
if not b == _b2 == _b3 == _b4 or not h == _h2 == _h3 == _h4 or not d_q == _d2:
|
||||
raise AssertionError("sdpa_backward_flop_count: batch/heads/dimension mismatch among tensors")
|
||||
if not d_v == _d4 or not s_k == _s3 or not s_q == _s4:
|
||||
raise AssertionError("sdpa_backward_flop_count: grad_out/value/key/query shapes are incompatible")
|
||||
assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
|
||||
assert d_v == _d4 and s_k == _s3 and s_q == _s4
|
||||
total_flops = 0
|
||||
# Step 1: We recompute the scores matrix.
|
||||
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
|
||||
@ -761,8 +742,7 @@ class FlopCounterMode:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
if self.mode is None:
|
||||
raise AssertionError("Internal error: FlopCounter.__exit__ called but mode is None")
|
||||
assert self.mode is not None
|
||||
b = self.mode.__exit__(*args)
|
||||
self.mode = None # break cycles
|
||||
self.mod_tracker.__exit__()
|
||||
|
@ -238,8 +238,7 @@ class BackwardHook:
self.grad_outputs = None

if local_grad_outputs is not None:
if self.output_tensors_index is None:
raise AssertionError("output_tensors_index should not be None when grad_outputs is not None")
assert self.output_tensors_index is not None # mypy
return tuple(local_grad_outputs[i] for i in self.output_tensors_index)

grad_fn.register_hook(hook)

@ -137,12 +137,9 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
def __init__(self, dense_module):
super().__init__()

if dense_module.training:
raise AssertionError("Only support eval mode batchnorm for mkldnn path now")
if not dense_module.track_running_stats:
raise AssertionError("Only support track_running_stats=True for mkldnn path now")
if not dense_module.affine:
raise AssertionError("Only support affine=True for mkldnn path now")
assert not dense_module.training
assert dense_module.track_running_stats
assert dense_module.affine

if dense_module.momentum is None:
self.exponential_average_factor = 0.0
@ -207,9 +204,8 @@ class MkldnnPrelu(torch.jit.ScriptModule):
return y

def to_mkldnn(module, dtype=torch.float):
if dtype not in (torch.float, torch.bfloat16, torch.half):
raise AssertionError("MKLDNN only support float, bfloat16, and half path now")

assert dtype in [torch.float, torch.bfloat16, torch.half], \
"MKLDNN only support float, bfloat16, and half path now"

def m_fn(m, d):
if isinstance(m, torch.nn.Linear):

@ -5,8 +5,7 @@ import torch._C

def format_time(time_us=None, time_ms=None, time_s=None):
"""Define time formatting."""
if time_us is not None or time_ms is not None or time_s is not None:
raise AssertionError("Expected at least one of time_us, time_ms, time_s is not None.")
assert sum([time_us is not None, time_ms is not None, time_s is not None]) == 1

US_IN_SECOND = 1e6
US_IN_MS = 1e3
@ -33,12 +33,12 @@ function version_space() {
|
||||
};
|
||||
}
|
||||
|
||||
function Segment(addr, size, stream, frames, version) {
|
||||
return {addr, size, stream, version, frames};
|
||||
function Segment(addr, size, stream, frames, version, user_metadata) {
|
||||
return {addr, size, stream, version, frames, user_metadata};
|
||||
}
|
||||
|
||||
function Block(addr, size, requested_size, frames, free_requested, version) {
|
||||
return {addr, size, requested_size, frames, free_requested, version};
|
||||
function Block(addr, size, requested_size, frames, free_requested, version, user_metadata) {
|
||||
return {addr, size, requested_size, frames, free_requested, version, user_metadata};
|
||||
}
|
||||
|
||||
function EventSelector(outer, events, stack_info, memory_view) {
|
||||
@ -140,7 +140,9 @@ function eventStack(e, allocated, reserved) {
|
||||
reserved,
|
||||
)} reserved)\n${event}`;
|
||||
}
|
||||
return event + '\n' + format_frames(e.frames);
|
||||
const user_metadata_str = format_user_metadata(e.user_metadata);
|
||||
const frames_str = format_frames(e.frames);
|
||||
return event + '\n' + (user_metadata_str ? user_metadata_str + '\n' : '') + frames_str;
|
||||
}
|
||||
|
||||
function hashCode(num) {
|
||||
@ -216,6 +218,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
seg.stream,
|
||||
seg.frames || [],
|
||||
seg.version,
|
||||
seg.user_metadata,
|
||||
),
|
||||
);
|
||||
for (const b of seg.blocks) {
|
||||
@ -229,6 +232,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
b.frames,
|
||||
b.state === 'active_pending_free',
|
||||
b.version,
|
||||
b.user_metadata,
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -307,6 +311,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
event.frames,
|
||||
false,
|
||||
event.version,
|
||||
event.user_metadata,
|
||||
);
|
||||
break;
|
||||
case 'free_requested':
|
||||
@ -320,6 +325,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
event.frames,
|
||||
true,
|
||||
event.version,
|
||||
event.user_metadata,
|
||||
);
|
||||
break;
|
||||
case 'alloc':
|
||||
@ -335,6 +341,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
event.stream,
|
||||
event.frames,
|
||||
event.version,
|
||||
event.user_metadata,
|
||||
),
|
||||
);
|
||||
break;
|
||||
@ -348,6 +355,7 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
event.stream,
|
||||
event.frames,
|
||||
event.version,
|
||||
event.user_metadata,
|
||||
),
|
||||
);
|
||||
break;
|
||||
@ -426,13 +434,17 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
if (t.internal_free > 0) {
|
||||
internal = ` (${(t.internal_free / free) * 100}% internal)`;
|
||||
}
|
||||
const user_metadata_str = format_user_metadata(t.user_metadata);
|
||||
const frames_str = format_frames(t.frames);
|
||||
return (
|
||||
`s${t.addr.toString(16)}_${t.version}: segment ${formatSize(
|
||||
t.size,
|
||||
)} allocated, ` +
|
||||
`${formatSize(free)} free${internal} (stream ${
|
||||
t.stream
|
||||
})\n${format_frames(t.frames)}`
|
||||
})\n` +
|
||||
(user_metadata_str ? user_metadata_str + '\n' : '') +
|
||||
frames_str
|
||||
);
|
||||
},
|
||||
d => {
|
||||
@ -493,12 +505,15 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
if (t.free_requested) {
|
||||
requested = ' (block freed but waiting due to record_stream)';
|
||||
}
|
||||
const user_metadata_str = format_user_metadata(t.user_metadata);
|
||||
const frames_str = format_frames(t.frames);
|
||||
return (
|
||||
`b${t.addr.toString(16)}_${t.version} ` +
|
||||
`${formatSize(t.requested_size)} allocation${requested} (stream ${
|
||||
t.segment.stream
|
||||
})\n` +
|
||||
format_frames(t.frames)
|
||||
(user_metadata_str ? user_metadata_str + '\n' : '') +
|
||||
frames_str
|
||||
);
|
||||
},
|
||||
removeStroke,
|
||||
@ -524,12 +539,15 @@ function MemoryView(outer, stack_info, snapshot, device) {
|
||||
d => {
|
||||
addStroke(d);
|
||||
const t = d.datum();
|
||||
const user_metadata_str = format_user_metadata(t.user_metadata);
|
||||
const frames_str = format_frames(t.frames);
|
||||
return (
|
||||
`Free space lost due to rounding ${formatSize(
|
||||
t.size - t.requested_size,
|
||||
)}` +
|
||||
` (stream ${t.segment.stream})\n` +
|
||||
format_frames(t.frames)
|
||||
(user_metadata_str ? user_metadata_str + '\n' : '') +
|
||||
frames_str
|
||||
);
|
||||
},
|
||||
removeStroke,
|
||||
@ -760,6 +778,23 @@ function frameFilter({name, filename}) {
|
||||
return true;
|
||||
}
|
||||
|
||||
function format_user_metadata(user_metadata) {
|
||||
if (!user_metadata) {
|
||||
return '';
|
||||
}
|
||||
// Handle string metadata
|
||||
if (typeof user_metadata === 'string') {
|
||||
return `User Metadata:\n ${user_metadata}`;
|
||||
}
|
||||
// Handle object metadata
|
||||
if (typeof user_metadata === 'object' && Object.keys(user_metadata).length === 0) {
|
||||
return '';
|
||||
}
|
||||
const metadata_lines = Object.entries(user_metadata)
|
||||
.map(([key, value]) => ` ${key}: ${value}`);
|
||||
return 'User Metadata:\n' + metadata_lines.join('\n');
|
||||
}
|
||||
|
||||
function format_frames(frames) {
|
||||
if (frames.length === 0) {
|
||||
return (
|
||||
@ -992,6 +1027,10 @@ function process_alloc_data(snapshot, device, plot_segments, max_entries) {
|
||||
if (!elem.action.includes('alloc')) {
|
||||
text = `${text}\nalloc not recorded, stack trace for free:`;
|
||||
}
|
||||
const user_metadata_str = format_user_metadata(elem.user_metadata);
|
||||
if (user_metadata_str) {
|
||||
text = `${text}\n${user_metadata_str}`;
|
||||
}
|
||||
text = `${text}\n${format_frames(elem.frames)}`;
|
||||
return text;
|
||||
},
|
||||
|
@ -351,16 +351,14 @@ class TensorWeakRef:
ref: WeakRef[Tensor]

def __init__(self, tensor: Tensor):
if not isinstance(tensor, Tensor):
raise AssertionError(f"expected torch.Tensor, got {type(tensor)}.")
assert isinstance(tensor, Tensor)
self.ref = weakref.ref(tensor)

def __call__(self):
out = self.ref()
if out is None:
return out
if not isinstance(out, Tensor):
raise AssertionError(f"expected torch.Tensor, got {type(out)}.")
assert isinstance(out, Tensor)
# TODO, add _fix_weakref type binding
out._fix_weakref() # type: ignore[attr-defined]
return out
@ -1024,22 +1024,8 @@ def gen_functionalization_registration(
) -> list[str]:
@with_native_function
def emit_registration_helper(f: NativeFunction) -> str:
if f.has_composite_implicit_autograd_kernel:
metadata = composite_implicit_autograd_index.get_kernel(f)
assert metadata is not None
native_api_name = metadata.kernel
sig = NativeSignature(f.func, symint=metadata.supports_symint())
# Note [Composite view ops in the functionalization pass]
# We don't need to worry about implemententing functionalization kernels for views with
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
# We can't just opt the entire Functionalization dispatch key into the composite keyset though,
# because we don't want to decompose non-view ops that are composite, like `at::ones`.
registration_str = (
f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
)
else:
# non-composite view ops (and inplace ops) get a normal registration.
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
assert not f.has_composite_implicit_autograd_kernel
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
return f'm.impl("{f.func.name}", {registration_str});'

# Don't generate kernels in mobile build
@ -1052,8 +1038,12 @@
if str(g.view.func.name) == "lift_fresh":
return []
view_str = []
view_str.append(emit_registration_helper(g.view))
if g.view_inplace is not None:
if not g.view.has_composite_implicit_autograd_kernel:
view_str.append(emit_registration_helper(g.view))
if (
g.view_inplace is not None
and not g.view_inplace.has_composite_implicit_autograd_kernel
):
assert g.view_inplace.is_view_op
view_str.append(emit_registration_helper(g.view_inplace))
return view_str