mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 22:25:10 +08:00
Compare commits
115 Commits
yiming/add
...
annotate_f
| Author | SHA1 | Date | |
|---|---|---|---|
| 98826fd37b | |||
| 585b9dbb5e | |||
| d795fb225a | |||
| 7df9aca529 | |||
| d4a713cd9c | |||
| 5daef30b26 | |||
| 6dedd34c31 | |||
| a303d6dda9 | |||
| 7669ac9402 | |||
| 86fd4fc23e | |||
| 99097b6d89 | |||
| a214371008 | |||
| 7d87d7052e | |||
| 1a34ff4e04 | |||
| fe5ccb1a74 | |||
| 85586d7efc | |||
| e1d71a6b35 | |||
| d61a9b88cf | |||
| 99b32a6750 | |||
| 783da8b8e7 | |||
| ed74dc054d | |||
| f33c7e1a43 | |||
| 219fb6aafc | |||
| 515b5ff539 | |||
| 608a6d4a26 | |||
| 03e5dbb26e | |||
| 7ee45f7503 | |||
| e6d9d68598 | |||
| 1a5b7eca7b | |||
| 8573574b32 | |||
| e6033f6efb | |||
| 9272437cde | |||
| f06e669f6c | |||
| 69b05913fb | |||
| d73c283c3a | |||
| eaeaa08e3a | |||
| d0c32971b4 | |||
| d7ffa8b8a2 | |||
| 00afa06800 | |||
| 5d0b22008d | |||
| ab6014a903 | |||
| f6daffc54d | |||
| 66b75693ae | |||
| 21697feff2 | |||
| 12fa4192c5 | |||
| 23fb7e9f4b | |||
| 5e480b8ecf | |||
| 19ba506ca3 | |||
| 003dd13073 | |||
| c2bd41ac9f | |||
| ca8bd5dbed | |||
| 26f3803433 | |||
| 48064acf37 | |||
| e5a9c247bc | |||
| 36371b8ec7 | |||
| 7e6721fb0a | |||
| 901bbcba12 | |||
| febb603230 | |||
| 568d2f3ae7 | |||
| b54e466fd0 | |||
| 53f9ae0e50 | |||
| b42fe389b9 | |||
| 66ea76ec44 | |||
| e787d532b6 | |||
| b3f6d49b69 | |||
| bc1f2108d7 | |||
| f071f17911 | |||
| fa1539594b | |||
| dfc8a1c5dd | |||
| 7f9b745494 | |||
| 83f9baf413 | |||
| ffc7552e01 | |||
| 78f5a1ec60 | |||
| 2b71b62045 | |||
| 8c4b528403 | |||
| 066f818eea | |||
| 14af1dc3da | |||
| 2395d7d7da | |||
| 0aa7ebaf03 | |||
| 7a97832585 | |||
| 84d141e910 | |||
| 7c6c5d04fe | |||
| b509fb9b5d | |||
| 331b7cc054 | |||
| 815d641599 | |||
| ffe3cb226a | |||
| 7ae123d72c | |||
| 7719cb75bf | |||
| 712f54d453 | |||
| f58f301313 | |||
| 5c583e2573 | |||
| 0c14f55de6 | |||
| 8e510e1095 | |||
| 59d30d1b75 | |||
| 3915898c22 | |||
| 3044e1a460 | |||
| b11593c31b | |||
| 36871622f1 | |||
| b4fd47179e | |||
| 4f400ab520 | |||
| 839f6facdb | |||
| ca65023b90 | |||
| 132ae8e6dd | |||
| a20afb6100 | |||
| 47524dcc48 | |||
| 9ffba8a2f9 | |||
| 3681312ce0 | |||
| 7778a58e7c | |||
| e7091a47da | |||
| bcfea48ab7 | |||
| d2e1dbc8f2 | |||
| 89298ada83 | |||
| c467e59cb0 | |||
| bbb902c8dd | |||
| e6f766c7d7 |
@ -187,19 +187,22 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
|
||||
export USE_CUFILE=0
|
||||
else
|
||||
DEPS_LIST+=(
|
||||
"/usr/local/cuda/lib64/libnvToolsExt.so.1"
|
||||
"/usr/local/cuda/lib64/libcublas.so.12"
|
||||
"/usr/local/cuda/lib64/libcublasLt.so.12"
|
||||
"/usr/local/cuda/lib64/libcudart.so.12"
|
||||
"/usr/local/cuda/lib64/libnvrtc.so.12"
|
||||
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
|
||||
DEPS_SONAME+=(
|
||||
"libnvToolsExt.so.1"
|
||||
"libcublas.so.12"
|
||||
"libcublasLt.so.12"
|
||||
"libcudart.so.12"
|
||||
"libnvrtc.so.12"
|
||||
"libcupti.so.12")
|
||||
|
||||
if [[ $CUDA_VERSION != 12.9* ]]; then
|
||||
DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
|
||||
DEPS_SONAME+=("libnvToolsExt.so.1")
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Using nvidia libs from pypi."
|
||||
|
||||
@ -1615,6 +1615,7 @@ test_operator_benchmark() {
|
||||
TEST_REPORTS_DIR=$(pwd)/test/test-reports
|
||||
mkdir -p "$TEST_REPORTS_DIR"
|
||||
TEST_DIR=$(pwd)
|
||||
ARCH=$(uname -m)
|
||||
|
||||
test_inductor_set_cpu_affinity
|
||||
|
||||
@ -1629,7 +1630,7 @@ test_operator_benchmark() {
|
||||
pip_install pandas
|
||||
python check_perf_csv.py \
|
||||
--actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \
|
||||
--expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
|
||||
--expected "${ARCH}_expected_ci_operator_benchmark_eager_float32_cpu.csv"
|
||||
}
|
||||
|
||||
test_operator_microbenchmark() {
|
||||
|
||||
@ -65,7 +65,7 @@ runs:
|
||||
cd .ci/lumen_cli
|
||||
python3 -m pip install -e .
|
||||
)
|
||||
MAX_JOBS="$(nproc --ignore=6)"
|
||||
MAX_JOBS="$(nproc --ignore=10)"
|
||||
export MAX_JOBS
|
||||
|
||||
# Split the comma-separated list and build each target
|
||||
|
||||
2
.github/ci_commit_pins/audio.txt
vendored
2
.github/ci_commit_pins/audio.txt
vendored
@ -1 +1 @@
|
||||
8ad2aa5d354d1bf432339113860185d5a5d1abbd
|
||||
1b013f5b5a87a1882eb143c26d79d091150d6a37
|
||||
|
||||
2
.github/ci_commit_pins/vision.txt
vendored
2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
|
||||
f5c6c2ec6490455e86f67b2a25c10390d60a27f7
|
||||
faffd5cf673615583da6517275e361cb3dbc77e6
|
||||
|
||||
12
.github/scripts/generate_binary_build_matrix.py
vendored
12
.github/scripts/generate_binary_build_matrix.py
vendored
@ -241,7 +241,11 @@ def generate_libtorch_matrix(
|
||||
arches += CUDA_ARCHES
|
||||
arches += ROCM_ARCHES
|
||||
elif os == "windows":
|
||||
arches += CUDA_ARCHES
|
||||
# TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
|
||||
# in 2.10
|
||||
windows_cuda_arches = CUDA_ARCHES.copy()
|
||||
windows_cuda_arches.remove("12.9")
|
||||
arches += windows_cuda_arches
|
||||
if libtorch_variants is None:
|
||||
libtorch_variants = [
|
||||
"shared-with-deps",
|
||||
@ -305,7 +309,11 @@ def generate_wheels_matrix(
|
||||
if os == "linux":
|
||||
arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES
|
||||
elif os == "windows":
|
||||
arches += CUDA_ARCHES + XPU_ARCHES
|
||||
# TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
|
||||
# in 2.10
|
||||
windows_cuda_arches = CUDA_ARCHES.copy()
|
||||
windows_cuda_arches.remove("12.9")
|
||||
arches += windows_cuda_arches + XPU_ARCHES
|
||||
elif os == "linux-aarch64":
|
||||
# Separate new if as the CPU type is different and
|
||||
# uses different build/test scripts
|
||||
|
||||
2
.github/workflows/_linux-build.yml
vendored
2
.github/workflows/_linux-build.yml
vendored
@ -37,7 +37,7 @@ on:
|
||||
runner:
|
||||
required: false
|
||||
type: string
|
||||
default: "linux.2xlarge"
|
||||
default: "linux.c7i.2xlarge"
|
||||
description: |
|
||||
Label of the runner this job should run on.
|
||||
test-matrix:
|
||||
|
||||
19
.github/workflows/build-vllm-wheel.yml
vendored
19
.github/workflows/build-vllm-wheel.yml
vendored
@ -27,9 +27,8 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: [ '3.12' ]
|
||||
# TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
|
||||
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
|
||||
device: [ 'cu128', 'cu129' ]
|
||||
device: [ 'cu128', 'cu129', 'cu130' ]
|
||||
include:
|
||||
- platform: manylinux_2_28_x86_64
|
||||
device: cu128
|
||||
@ -39,6 +38,10 @@ jobs:
|
||||
device: cu129
|
||||
manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
|
||||
runner: linux.12xlarge.memory
|
||||
- platform: manylinux_2_28_x86_64
|
||||
device: cu130
|
||||
manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
|
||||
runner: linux.12xlarge.memory
|
||||
- platform: manylinux_2_28_aarch64
|
||||
device: cu128
|
||||
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
|
||||
@ -47,6 +50,11 @@ jobs:
|
||||
device: cu129
|
||||
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
|
||||
runner: linux.arm64.r7g.12xlarge.memory
|
||||
exclude:
|
||||
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
|
||||
# xformers is update to support 13.0
|
||||
- platform: manylinux_2_28_aarch64
|
||||
device: cu130
|
||||
name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
|
||||
runs-on: ${{ matrix.runner }}
|
||||
timeout-minutes: 480
|
||||
@ -169,7 +177,12 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
|
||||
device: [ 'cu128', 'cu129' ]
|
||||
device: [ 'cu128', 'cu129', 'cu130' ]
|
||||
exclude:
|
||||
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
|
||||
# xformers is update to support 13.0
|
||||
- platform: manylinux_2_28_aarch64
|
||||
device: cu130
|
||||
env:
|
||||
PLATFORM: ${{ matrix.platform }}
|
||||
BUILD_DEVICE: ${{ matrix.device }}
|
||||
|
||||
250
.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
generated
vendored
250
.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
generated
vendored
@ -788,256 +788,6 @@ jobs:
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
libtorch-cuda12_9-shared-with-deps-debug-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: debug
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Build PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
|
||||
- uses: actions/upload-artifact@v4.4.0
|
||||
if: always()
|
||||
with:
|
||||
name: libtorch-cuda12_9-shared-with-deps-debug
|
||||
retention-days: 14
|
||||
if-no-files-found: error
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
|
||||
libtorch-cuda12_9-shared-with-deps-debug-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-cuda12_9-shared-with-deps-debug-build
|
||||
- get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: debug
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: libtorch-cuda12_9-shared-with-deps-debug
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Test PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
libtorch-cuda12_9-shared-with-deps-debug-upload: # Uploading
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
needs: libtorch-cuda12_9-shared-with-deps-debug-test
|
||||
with:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
LIBTORCH_CONFIG: debug
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
build_name: libtorch-cuda12_9-shared-with-deps-debug
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
libtorch-cuda13_0-shared-with-deps-debug-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
|
||||
250
.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
generated
vendored
250
.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
generated
vendored
@ -788,256 +788,6 @@ jobs:
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
libtorch-cuda12_9-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Build PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
|
||||
- uses: actions/upload-artifact@v4.4.0
|
||||
if: always()
|
||||
with:
|
||||
name: libtorch-cuda12_9-shared-with-deps-release
|
||||
retention-days: 14
|
||||
if-no-files-found: error
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
|
||||
libtorch-cuda12_9-shared-with-deps-release-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-cuda12_9-shared-with-deps-release-build
|
||||
- get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: libtorch-cuda12_9-shared-with-deps-release
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Test PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
needs: libtorch-cuda12_9-shared-with-deps-release-test
|
||||
with:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
build_name: libtorch-cuda12_9-shared-with-deps-release
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
libtorch-cuda13_0-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
|
||||
1666
.github/workflows/generated-windows-binary-wheel-nightly.yml
generated
vendored
1666
.github/workflows/generated-windows-binary-wheel-nightly.yml
generated
vendored
File diff suppressed because it is too large
Load Diff
4
.github/workflows/lint.yml
vendored
4
.github/workflows/lint.yml
vendored
@ -118,9 +118,9 @@ jobs:
|
||||
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
|
||||
echo "Running all other linters"
|
||||
if [ "$CHANGED_FILES" = '*' ]; then
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
|
||||
else
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
|
||||
fi
|
||||
|
||||
quick-checks:
|
||||
|
||||
37
.github/workflows/operator_benchmark.yml
vendored
37
.github/workflows/operator_benchmark.yml
vendored
@ -7,9 +7,11 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
test_mode:
|
||||
required: false
|
||||
type: string
|
||||
default: 'short'
|
||||
type: choice
|
||||
options:
|
||||
- 'short'
|
||||
- 'long'
|
||||
- 'all'
|
||||
description: tag filter for operator benchmarks, options from long, short, all
|
||||
schedule:
|
||||
# Run at 07:00 UTC every Sunday
|
||||
@ -28,38 +30,25 @@ permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
opbenchmark-build:
|
||||
x86-opbenchmark-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: opbenchmark-build
|
||||
name: x86-opbenchmark-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
build-environment: linux-jammy-py3.10-gcc11-build
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
|
||||
{ config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
opbenchmark-on-demand-build:
|
||||
if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
|
||||
name: opbenchmark-on-demand-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
build-environment: linux-jammy-py3.10-gcc11-build
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
opbenchmark-test:
|
||||
name: opbenchmark-test
|
||||
x86-opbenchmark-test:
|
||||
name: x86-opbenchmark-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: opbenchmark-build
|
||||
needs: x86-opbenchmark-build
|
||||
with:
|
||||
build-environment: linux-jammy-py3.10-gcc11-build
|
||||
docker-image: ${{ needs.opbenchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.opbenchmark-build.outputs.test-matrix }}
|
||||
docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
@ -209,6 +209,46 @@ command = [
|
||||
'@{{PATHSFILE}}'
|
||||
]
|
||||
|
||||
|
||||
[[linter]]
|
||||
code = 'PYREFLY'
|
||||
include_patterns = [
|
||||
'torch/**/*.py',
|
||||
'torch/**/*.pyi',
|
||||
'torchgen/**/*.py',
|
||||
'torchgen/**/*.pyi',
|
||||
'functorch/**/*.py',
|
||||
'functorch/**/*.pyi',
|
||||
]
|
||||
exclude_patterns = []
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pyrefly_linter.py',
|
||||
'--config=pyrefly.toml',
|
||||
]
|
||||
init_command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'numpy==2.1.0 ; python_version >= "3.12"',
|
||||
'expecttest==0.3.0',
|
||||
'pyrefly==0.36.2',
|
||||
'sympy==1.13.3',
|
||||
'types-requests==2.27.25',
|
||||
'types-pyyaml==6.0.2',
|
||||
'types-tabulate==0.8.8',
|
||||
'types-protobuf==5.29.1.20250403',
|
||||
'types-setuptools==79.0.0.20250422',
|
||||
'types-jinja2==2.11.9',
|
||||
'types-colorama==0.4.6',
|
||||
'filelock==3.18.0',
|
||||
'junitparser==2.1.1',
|
||||
'rich==14.1.0',
|
||||
'optree==0.17.0',
|
||||
'types-openpyxl==3.1.5.20250919',
|
||||
'types-python-dateutil==2.9.0.20251008'
|
||||
]
|
||||
|
||||
[[linter]]
|
||||
code = 'CLANGTIDY'
|
||||
include_patterns = [
|
||||
|
||||
@ -256,6 +256,7 @@ endif()
|
||||
IF(USE_FBGEMM_GENAI)
|
||||
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
|
||||
set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
|
||||
|
||||
if(USE_CUDA)
|
||||
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
|
||||
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
|
||||
@ -292,58 +293,64 @@ IF(USE_FBGEMM_GENAI)
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
|
||||
)
|
||||
|
||||
target_include_directories(fbgemm_genai PUBLIC
|
||||
target_include_directories(fbgemm_genai PRIVATE
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${fbgemm_genai_mx8mx8bf16_grouped}
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
else()
|
||||
if(USE_ROCM)
|
||||
# Only include the kernels we want to build to avoid increasing binary size.
|
||||
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
|
||||
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
|
||||
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
|
||||
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
||||
|
||||
# Add additional HIPCC compiler flags for performance
|
||||
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
|
||||
-mllvm
|
||||
-amdgpu-coerce-illegal-types=1
|
||||
-mllvm
|
||||
-enable-post-misched=0
|
||||
-mllvm
|
||||
-greedy-reverse-local-assignment=1
|
||||
-fhip-new-launch-api)
|
||||
# Add FBGEMM_GENAI include directories for torch_ops.h
|
||||
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
|
||||
elseif(USE_ROCM)
|
||||
# Only include the kernels we want to build to avoid increasing binary size.
|
||||
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
|
||||
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
|
||||
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
|
||||
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
||||
|
||||
# Only compile for gfx942 for now.
|
||||
# This is rather hacky, I could not figure out a clean solution :(
|
||||
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
|
||||
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
|
||||
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
|
||||
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
|
||||
endif()
|
||||
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
|
||||
# Add additional HIPCC compiler flags for performance
|
||||
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
|
||||
-mllvm
|
||||
-amdgpu-coerce-illegal-types=1
|
||||
-mllvm
|
||||
-enable-post-misched=0
|
||||
-mllvm
|
||||
-greedy-reverse-local-assignment=1
|
||||
-fhip-new-launch-api)
|
||||
|
||||
hip_add_library(
|
||||
fbgemm_genai STATIC
|
||||
${fbgemm_genai_native_rocm_hip}
|
||||
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
|
||||
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
|
||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
|
||||
|
||||
target_include_directories(fbgemm_genai PUBLIC
|
||||
# FBGEMM version of Composable Kernel is used due to some customizations
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/include
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
# Only compile for gfx942 for now.
|
||||
# This is rather hacky, I could not figure out a clean solution :(
|
||||
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
|
||||
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
|
||||
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
|
||||
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
|
||||
endif()
|
||||
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
|
||||
|
||||
hip_add_library(
|
||||
fbgemm_genai STATIC
|
||||
${fbgemm_genai_native_rocm_hip}
|
||||
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
|
||||
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
|
||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
|
||||
|
||||
target_include_directories(fbgemm_genai PRIVATE
|
||||
# FBGEMM version of Composable Kernel is used due to some customizations
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/include
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
|
||||
# Add FBGEMM_GENAI include directories for torch_ops.h
|
||||
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -692,12 +699,6 @@ if(USE_CUDA AND NOT USE_ROCM)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
|
||||
|
||||
# Add FBGEMM_GENAI include directories for torch_ops.h
|
||||
if(USE_FBGEMM_GENAI)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
|
||||
endif()
|
||||
|
||||
if($ENV{ATEN_STATIC_CUDA})
|
||||
if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
|
||||
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
#include <mutex>
|
||||
#include <ATen/CachedTensorUtils.h>
|
||||
#include <c10/core/GradMode.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
|
||||
namespace at::autocast {
|
||||
@ -36,10 +37,29 @@ namespace {
|
||||
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
|
||||
using val_type = std::tuple<weakref_type, Tensor>;
|
||||
|
||||
ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
|
||||
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
|
||||
return cached_casts;
|
||||
// We maintain separate caches for gradient-enabled and gradient-disabled modes.
|
||||
// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
|
||||
// are not incorrectly reused in gradient-enabled contexts.
|
||||
// This fixes issue #158232 while maintaining optimal performance for both modes.
|
||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
|
||||
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
|
||||
return cached_casts_grad_enabled;
|
||||
}
|
||||
|
||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
|
||||
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
|
||||
return cached_casts_grad_disabled;
|
||||
}
|
||||
|
||||
// Helper function to get the appropriate cache based on current gradient mode.
|
||||
// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
|
||||
// preventing incorrect cache hits when gradient mode changes.
|
||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
|
||||
return at::GradMode::is_enabled() ?
|
||||
get_cached_casts_grad_enabled() :
|
||||
get_cached_casts_grad_disabled();
|
||||
}
|
||||
|
||||
std::mutex cached_casts_mutex;
|
||||
|
||||
|
||||
@ -86,7 +106,9 @@ thread_local bool cache_enabled = true;
|
||||
|
||||
void clear_cache() {
|
||||
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
|
||||
get_cached_casts().clear();
|
||||
// Clear both caches to ensure consistent behavior regardless of current gradient mode
|
||||
get_cached_casts_grad_enabled().clear();
|
||||
get_cached_casts_grad_disabled().clear();
|
||||
}
|
||||
|
||||
int increment_nesting() {
|
||||
@ -121,6 +143,11 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
|
||||
if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
|
||||
// Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
|
||||
// See cached_casts declaration above for detailed strategy.
|
||||
//
|
||||
// We maintain separate caches for gradient-enabled and gradient-disabled modes
|
||||
// (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
|
||||
// with torch.autocast(), while maintaining optimal performance for both training and inference.
|
||||
// This fixes issue #158232 without any performance regression.
|
||||
bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
|
||||
arg.scalar_type() == at::kFloat && arg.requires_grad() &&
|
||||
arg.is_leaf() && !arg.is_view() && cache_enabled &&
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
||||
#include <ATen/cuda/tunable/TunableOp.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
|
||||
BLASType = "unknown";
|
||||
}
|
||||
return BLASType;
|
||||
|
||||
}
|
||||
|
||||
// Similar to Compute Type in GemmRocblas.h
|
||||
@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio
|
||||
|
||||
namespace detail {
|
||||
|
||||
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
|
||||
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
|
||||
|
||||
if (!config.enabled) {
|
||||
return true; // skip when disabled
|
||||
}
|
||||
|
||||
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
|
||||
// comparison done as 1D tensor
|
||||
at::Tensor ref = at::from_blob(c, {size}, options);
|
||||
at::Tensor oth = at::from_blob(other_c, {size}, options);
|
||||
at::Tensor ref_float = ref.to(at::kFloat);
|
||||
at::Tensor oth_float = oth.to(at::kFloat);
|
||||
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
|
||||
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
|
||||
double last_succeed_atol = 1;
|
||||
double last_succeed_rtol = 1;
|
||||
for (auto& atol : atols) {
|
||||
for (auto& rtol : rtols) {
|
||||
if (at::allclose(ref_float, oth_float, rtol, atol)) {
|
||||
last_succeed_atol = atol;
|
||||
last_succeed_rtol = rtol;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_succeed_atol == 1) {
|
||||
return false;
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
|
||||
}
|
||||
|
||||
return true;
|
||||
const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
|
||||
if (ok) {
|
||||
TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
|
||||
} else {
|
||||
TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
|
||||
}
|
||||
return ok;
|
||||
}
|
||||
|
||||
}
|
||||
@ -355,8 +349,10 @@ struct GemmParams : OpParams {
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(GemmParams<T> *other) {
|
||||
auto* ctx = getTuningContext();
|
||||
auto cfg = ctx->GetNumericalCheckConfig();
|
||||
auto c_dtype = c10::CppTypeToScalarType<T>::value;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
||||
}
|
||||
|
||||
char transa{};
|
||||
@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
|
||||
auto* ctx = getTuningContext();
|
||||
auto cfg = ctx->GetNumericalCheckConfig();
|
||||
auto c_dtype = c10::CppTypeToScalarType<T>::value;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
||||
}
|
||||
|
||||
char transa{};
|
||||
@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
|
||||
auto* ctx = getTuningContext();
|
||||
auto cfg = ctx->GetNumericalCheckConfig();
|
||||
auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
||||
}
|
||||
|
||||
char transa{};
|
||||
@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
|
||||
auto* ctx = getTuningContext();
|
||||
auto cfg = ctx->GetNumericalCheckConfig();
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
||||
}
|
||||
|
||||
char transa{};
|
||||
|
||||
@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
|
||||
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
|
||||
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
|
||||
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
|
||||
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
|
||||
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
|
||||
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
|
||||
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
|
||||
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
|
||||
@ -173,10 +173,9 @@ All python APIs exist in the `torch.cuda.tunable` module.
|
||||
| get_max_tuning_iterations() -> int | |
|
||||
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
|
||||
| get_filename() -> str | |
|
||||
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5.
|
||||
| get_results() -> Tuple[str, str, str, float] | |
|
||||
| get_validators() -> Tuple[str, str] | |
|
||||
| write_file_on_exit(val: bool) -> None | Default is True. |
|
||||
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
|
||||
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
|
||||
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
|
||||
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None: -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
|
||||
|
||||
@ -107,14 +107,30 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
|
||||
}
|
||||
|
||||
void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
|
||||
std::scoped_lock l{lock_};
|
||||
bool is_new = false;
|
||||
ResultEntry inserted = ResultEntry::Null();
|
||||
|
||||
auto it = results_.find(op_signature);
|
||||
if (it == results_.end()) {
|
||||
it = results_.insert({op_signature, {}}).first;
|
||||
// ---- mutate maps under results lock ----
|
||||
{
|
||||
std::scoped_lock l{lock_};
|
||||
auto& km = results_[op_signature]; // creates if missing
|
||||
is_new = (km.find(params_signature) == km.end());
|
||||
AddImpl(op_signature, params_signature, std::move(best), km);
|
||||
if (is_new) {
|
||||
inserted = km.at(params_signature); // snapshot for I/O after unlocking
|
||||
}
|
||||
}
|
||||
if (!is_new) return; // only write once per unique (op, params)
|
||||
|
||||
TuningContext* ctx = getTuningContext();
|
||||
if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
|
||||
InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());
|
||||
|
||||
if (is_new && realtime_out_ && realtime_out_->good()) {
|
||||
AppendResultLine(op_signature, params_signature, inserted);
|
||||
}
|
||||
}
|
||||
|
||||
AddImpl(op_signature, params_signature, std::move(best), it->second);
|
||||
}
|
||||
|
||||
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
|
||||
@ -150,6 +166,77 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
|
||||
}
|
||||
}
|
||||
|
||||
void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
|
||||
std::scoped_lock fl{realtime_file_mutex_};
|
||||
|
||||
if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (realtime_out_ && realtime_filename_ != filename) {
|
||||
realtime_out_->flush();
|
||||
realtime_out_->close();
|
||||
realtime_out_.reset();
|
||||
validators_written_ = false;
|
||||
}
|
||||
|
||||
bool file_exists = false;
|
||||
bool file_empty = true;
|
||||
|
||||
{
|
||||
std::ifstream check_file(filename);
|
||||
if (check_file.good()) {
|
||||
file_exists = true;
|
||||
file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
|
||||
}
|
||||
}
|
||||
|
||||
realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
|
||||
|
||||
if (!realtime_out_->good()) {
|
||||
TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
|
||||
realtime_out_.reset();
|
||||
return;
|
||||
}
|
||||
|
||||
if(!file_exists || file_empty) {
|
||||
for(const auto& [key, val] : validators) {
|
||||
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
|
||||
realtime_out_->flush();
|
||||
}
|
||||
validators_written_ = true;
|
||||
|
||||
TUNABLE_LOG2("Wrote validators to realtime output file");
|
||||
}
|
||||
|
||||
realtime_filename_ = filename;
|
||||
}
|
||||
|
||||
void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
|
||||
std::scoped_lock fl{realtime_file_mutex_};
|
||||
|
||||
if(!realtime_out_ || !realtime_out_->good()) {
|
||||
return;
|
||||
}
|
||||
|
||||
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
|
||||
realtime_out_->flush(); //ensure immediate write to disk
|
||||
|
||||
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
|
||||
}
|
||||
|
||||
void TuningResultsManager::CloseRealtimeAppend() {
|
||||
std::scoped_lock fl{realtime_file_mutex_};
|
||||
|
||||
|
||||
if(realtime_out_) {
|
||||
realtime_out_->flush();
|
||||
realtime_out_->close();
|
||||
realtime_out_.reset();
|
||||
TUNABLE_LOG2("Closed realtime output file");
|
||||
}
|
||||
}
|
||||
|
||||
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
|
||||
std::scoped_lock l{lock_};
|
||||
|
||||
@ -396,7 +483,6 @@ TuningContext::TuningContext() :
|
||||
tuning_enable_{true},
|
||||
record_untuned_enable_{false},
|
||||
manager_initialized_{false},
|
||||
write_file_on_exit_{true},
|
||||
numerics_check_enable_{false},
|
||||
max_tuning_duration_ms_{30},
|
||||
max_tuning_iterations_{100},
|
||||
@ -417,20 +503,8 @@ TuningContext::~TuningContext() {
|
||||
// but doesn't do any computation itself.
|
||||
return;
|
||||
}
|
||||
auto filename = GetFilename();
|
||||
if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
|
||||
if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
|
||||
if (results_count_from_input_file_ > 0) {
|
||||
TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG1("writing file ", filename);
|
||||
}
|
||||
if (!WriteFile(filename)) {
|
||||
TUNABLE_LOG1("failed to write file ", filename);
|
||||
}
|
||||
}
|
||||
}
|
||||
TUNABLE_LOG1("Closing File");
|
||||
GetTuningResultsManager().CloseRealtimeAppend(); // Since, we do instant logging by default now.
|
||||
|
||||
if (untuned_file_.good()) {
|
||||
untuned_file_.close();
|
||||
@ -511,20 +585,54 @@ std::ofstream& TuningContext::GetUntunedFile(){
|
||||
return untuned_file_;
|
||||
}
|
||||
|
||||
void TuningContext::WriteFileOnExit(bool value) {
|
||||
write_file_on_exit_ = value;
|
||||
}
|
||||
|
||||
void TuningContext::EnableNumericsCheck(bool value) {
|
||||
numerics_check_enable_ = value;
|
||||
}
|
||||
|
||||
bool TuningContext::IsNumericsCheckEnabled() const {
|
||||
const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
|
||||
if (env == "1") {
|
||||
return true;
|
||||
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
|
||||
const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
|
||||
|
||||
if (!env_opt.has_value()) {
|
||||
return numerics_cfg_;
|
||||
}
|
||||
return numerics_check_enable_;
|
||||
|
||||
const std::string& env = env_opt.value();
|
||||
|
||||
if (env == "0") {
|
||||
return NumericalCheckConfig(false, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
const size_t underscore = env.find('_');
|
||||
|
||||
TORCH_CHECK(
|
||||
underscore != std::string::npos,
|
||||
"Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
|
||||
"Expected 'atol_rtol', got: ",
|
||||
env);
|
||||
|
||||
double atol = 0.0;
|
||||
double rtol = 0.0;
|
||||
|
||||
try {
|
||||
atol = std::stod(env.substr(0, underscore));
|
||||
rtol = std::stod(env.substr(underscore + 1));
|
||||
} catch (const std::exception& e) {
|
||||
TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
|
||||
}
|
||||
|
||||
TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
|
||||
return NumericalCheckConfig(true, atol, rtol);
|
||||
}
|
||||
|
||||
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
|
||||
TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
|
||||
numerics_cfg_ = {enabled, atol, rtol};
|
||||
}
|
||||
|
||||
bool TuningContext::IsNumericsCheckEnabled() const {
|
||||
const auto cfg = GetNumericalCheckConfig();
|
||||
return cfg.enabled || numerics_check_enable_;
|
||||
}
|
||||
|
||||
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
|
||||
@ -634,11 +742,6 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
|
||||
auto filename = GetFilename();
|
||||
if (!filename.empty() && !IsRecordUntunedEnabled()) {
|
||||
ReadFile(filename);
|
||||
// attempt immediately to open file for writing to catch errors early
|
||||
std::ofstream file(filename, std::ios::out | std::ios::app);
|
||||
if (!file.good()) {
|
||||
TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
|
||||
}
|
||||
}
|
||||
});
|
||||
return manager_;
|
||||
@ -744,27 +847,6 @@ bool TuningContext::ReadFile(const std::string& filename_) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TuningContext::WriteFile(const std::string& filename_) {
|
||||
std::string filename = filename_.empty() ? GetFilename() : filename_;
|
||||
std::ofstream file(filename, std::ios::out | std::ios::trunc);
|
||||
if (!file.good()) {
|
||||
TUNABLE_LOG1("error opening tuning results file for writing ", filename);
|
||||
return false;
|
||||
}
|
||||
auto validators = GetTuningResultsValidator().GetAllValidators();
|
||||
for (const auto& [key, val] : validators) {
|
||||
file << "Validator," << key << "," << val << std::endl;
|
||||
}
|
||||
auto results = GetTuningResultsManager().Dump();
|
||||
for (const auto& [op_sig, kernelmap] : results) {
|
||||
for (const auto& [param_sig, result] : kernelmap) {
|
||||
file << op_sig << "," << param_sig << "," << result << std::endl;
|
||||
}
|
||||
}
|
||||
file.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct MaybeDelete {
|
||||
|
||||
@ -103,10 +103,24 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
|
||||
|
||||
void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
|
||||
const std::string& params_signature, const std::string& blas_signature);
|
||||
|
||||
void InitRealtimeAppend(
|
||||
const std::string& filename,
|
||||
const std::unordered_map<std::string, std::string>& validators);
|
||||
|
||||
void AppendResultLine(const std::string& op_sig,
|
||||
const std::string& param_sig,
|
||||
const ResultEntry& result);
|
||||
|
||||
void CloseRealtimeAppend(); // For clean shutdown
|
||||
private:
|
||||
std::mutex lock_;
|
||||
std::mutex realtime_file_mutex_;
|
||||
std::unique_ptr<std::ofstream> realtime_out_;
|
||||
std::string realtime_filename_;
|
||||
ResultsMap results_;
|
||||
UntunedMap untuned_results_;
|
||||
bool validators_written_ = false;
|
||||
|
||||
};
|
||||
|
||||
@ -134,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
|
||||
GetValidateFuncs validators_;
|
||||
};
|
||||
|
||||
struct NumericalCheckConfig {
|
||||
bool enabled{false};
|
||||
double atol{1e-5};
|
||||
double rtol{1e-5};
|
||||
|
||||
NumericalCheckConfig() = default;
|
||||
NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
|
||||
};
|
||||
|
||||
|
||||
class TORCH_CUDA_CPP_API TuningContext {
|
||||
public:
|
||||
TuningContext();
|
||||
@ -155,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
|
||||
void EnableNumericsCheck(bool value);
|
||||
bool IsNumericsCheckEnabled() const;
|
||||
void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
|
||||
NumericalCheckConfig GetNumericalCheckConfig() const;
|
||||
|
||||
void SetMaxTuningDurationMs(int max_duration_ms);
|
||||
int GetMaxTuningDurationMs() const;
|
||||
@ -185,10 +211,7 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
|
||||
std::string GetFilename() const;
|
||||
|
||||
void WriteFileOnExit(bool value);
|
||||
|
||||
bool ReadFile(const std::string& filename={});
|
||||
bool WriteFile(const std::string& filename={});
|
||||
|
||||
template<class... Types>
|
||||
void Log(int level, Types... args) {
|
||||
@ -207,7 +230,6 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
bool tuning_enable_;
|
||||
bool record_untuned_enable_;
|
||||
bool manager_initialized_;
|
||||
bool write_file_on_exit_;
|
||||
bool numerics_check_enable_;
|
||||
int max_tuning_duration_ms_;
|
||||
int max_tuning_iterations_;
|
||||
@ -222,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
std::ofstream untuned_file_;
|
||||
size_t results_count_from_input_file_;
|
||||
bool is_shutting_down_;
|
||||
|
||||
NumericalCheckConfig numerics_cfg_{};
|
||||
};
|
||||
|
||||
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
|
||||
|
||||
@ -267,27 +267,10 @@ class TunableOp {
|
||||
for (size_t i = 0; i < op_names_.size(); i++) {
|
||||
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
|
||||
|
||||
if (do_numerics_check) {
|
||||
ParamsT* numerical_params = params->DeepCopy(false);
|
||||
auto status = candidate->Call(numerical_params);
|
||||
if (status != OK) {
|
||||
numerical_params->Delete();
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
status = reference_params->NumericalCheck(numerical_params);
|
||||
numerical_params->Delete();
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto status = candidate->Call(reusable_params[0]);
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
auto status = candidate->Call(reusable_params[0]);
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
|
||||
// collect a small profile
|
||||
@ -310,6 +293,22 @@ class TunableOp {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (do_numerics_check) {
|
||||
ParamsT* numerical_params = params->DeepCopy(false);
|
||||
auto status = candidate->Call(numerical_params);
|
||||
if (status != OK) {
|
||||
numerical_params->Delete();
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
status = reference_params->NumericalCheck(numerical_params);
|
||||
numerical_params->Delete();
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// for warmup does user set max duration, max iters, or both?
|
||||
// warmup is skipped by default, i.e. warmup_iter = 0
|
||||
// warmup will be set to the non-zero value of max_warmup_duration
|
||||
|
||||
@ -213,40 +213,22 @@ static cudnn_grid_sample_backward_batch_rule(
|
||||
return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
|
||||
}
|
||||
|
||||
// TODO: replace with targetable functionalization
|
||||
// uses functional formulation for one_hot under vmap to be compatible with
|
||||
// fakeTensor/dynamic shapes and compiled functorch transforms.
|
||||
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
|
||||
// but requires explicit positive num_classes under vmap to avoid
|
||||
// data-dependent output shapes.
|
||||
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
|
||||
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
|
||||
auto shape = self.sym_sizes().vec();
|
||||
|
||||
// empty tensor could be converted to one hot representation,
|
||||
// but shape inference is not possible.
|
||||
if (self.sym_numel() == 0) {
|
||||
if (num_classes <= 0) {
|
||||
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
|
||||
} else {
|
||||
shape.emplace_back(num_classes);
|
||||
return at::empty_symint(shape, self.options());
|
||||
}
|
||||
}
|
||||
|
||||
// disallow implicit inference under vmap; this would be data-dependent
|
||||
// and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
|
||||
TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
|
||||
"provide an explicit positive num_classes argument.");
|
||||
|
||||
// Disabling all of the following checks. This is OK because scatter has checks too.
|
||||
// Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
|
||||
// // non-empty tensor
|
||||
// if (self.device().type() != at::kCUDA) {
|
||||
// //for cuda, rely on device assert thrown by scatter
|
||||
// TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
|
||||
// }
|
||||
// if (self.device().type() != at::kCUDA) {
|
||||
// //rely on device asserts from scatter to avoid sync here
|
||||
// TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
|
||||
// }
|
||||
|
||||
shape.emplace_back(num_classes);
|
||||
Tensor ret = at::zeros_symint(shape, self.options());
|
||||
return ret.scatter(-1, self.unsqueeze(-1), 1);
|
||||
const auto options = self.options();
|
||||
at::Tensor index = at::arange(num_classes, options);
|
||||
return at::eq(self.unsqueeze(-1), index).to(at::kLong);
|
||||
}
|
||||
|
||||
template <typename A, A a, typename C>
|
||||
|
||||
@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
|
||||
}
|
||||
}
|
||||
|
||||
auto shape = self.sizes().vec();
|
||||
auto shape = self.sym_sizes().vec();
|
||||
|
||||
// empty tensor could be converted to one hot representation,
|
||||
// but shape inference is not possible.
|
||||
if (self.numel() == 0) {
|
||||
if (self.sym_numel() == 0) {
|
||||
if (num_classes <= 0) {
|
||||
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
|
||||
} else {
|
||||
shape.push_back(num_classes);
|
||||
return at::empty(shape, self.options());
|
||||
shape.emplace_back(num_classes);
|
||||
return at::empty_symint(shape, self.options());
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
|
||||
}
|
||||
}
|
||||
|
||||
shape.push_back(num_classes);
|
||||
Tensor ret = at::zeros(shape, self.options());
|
||||
shape.emplace_back(num_classes);
|
||||
Tensor ret = at::zeros_symint(shape, self.options());
|
||||
ret.scatter_(-1, self.unsqueeze(-1), 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -1906,11 +1906,9 @@ Tensor& index_fill_(
|
||||
"This also applies to advanced indexing e.g. tensor[mask] = scalar");
|
||||
}
|
||||
|
||||
if (!self.is_complex() && source.isComplex()) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
self.is_complex() || !source.isComplex(),
|
||||
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
|
||||
|
||||
// Handle the case when `self` is 0-dim
|
||||
Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;
|
||||
|
||||
@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
|
||||
} else if (dtype == ScalarType::Half) {
|
||||
[&]() {
|
||||
using scalar_t =
|
||||
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
|
||||
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
|
||||
const auto exp = exp_scalar.to<scalar_t>();
|
||||
using Vec = Vectorized<scalar_t>;
|
||||
cpu_kernel_vec(iter,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
|
||||
out_calc_t output_offset_calculator,
|
||||
loader_t loader,
|
||||
storer_t storer) {
|
||||
if (ret_t == rt_binary_specializations[arg_index][0] &&
|
||||
arg0_t == rt_binary_specializations[arg_index][1] &&
|
||||
arg1_t == rt_binary_specializations[arg_index][2])
|
||||
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
|
||||
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
|
||||
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
|
||||
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
|
||||
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
|
||||
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
|
||||
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
|
||||
launch_vectorized_templated_kernel<
|
||||
func_t,
|
||||
array_t,
|
||||
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
|
||||
out_calc_t,
|
||||
loader_t,
|
||||
storer_t,
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][0]>::t),
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][1]>::t),
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][2]>::t)>(
|
||||
cret_t,
|
||||
carg0_t,
|
||||
carg1_t>(
|
||||
numel,
|
||||
f,
|
||||
data,
|
||||
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
|
||||
output_offset_calculator,
|
||||
loader,
|
||||
storer);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -38,12 +38,41 @@ __device__ inline int min(int a, int b) {
|
||||
#define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched
|
||||
#endif
|
||||
|
||||
static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
|
||||
return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
|
||||
template <typename index_t>
|
||||
static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) {
|
||||
const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);
|
||||
return (size + pad < kernel_extent) ? index_t(0) : (size + pad - kernel_extent) / stride + 1;
|
||||
}
|
||||
|
||||
static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) {
|
||||
return min((size + pad) / stride + 1, pooled_size);
|
||||
template <typename index_t>
|
||||
static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) {
|
||||
return std::min((size + pad) / stride + 1, pooled_size);
|
||||
}
|
||||
|
||||
static inline bool can_use_int32_nhwc(
|
||||
int64_t nbatch, int64_t channels,
|
||||
int64_t height, int64_t width,
|
||||
int64_t pooled_height, int64_t pooled_width,
|
||||
int64_t in_stride_n, int64_t in_stride_c,
|
||||
int64_t in_stride_h, int64_t in_stride_w)
|
||||
{
|
||||
constexpr int64_t int_max = std::numeric_limits<int>::max();
|
||||
|
||||
int64_t max_intra_batch =
|
||||
(height ? (height - 1) * in_stride_h : 0) +
|
||||
(width ? (width - 1) * in_stride_w : 0) +
|
||||
(channels? (channels - 1) * in_stride_c : 0);
|
||||
|
||||
int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch;
|
||||
|
||||
if (max_input_offset > int_max) return false;
|
||||
|
||||
int64_t out_batch_stride = pooled_height * pooled_width * channels;
|
||||
if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false;
|
||||
|
||||
if (height * width > int_max) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// kernels borrowed from Caffe
|
||||
@ -85,21 +114,25 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename index_t>
|
||||
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
|
||||
__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch,
|
||||
const int64_t channels, const int64_t height,
|
||||
const int64_t width, const int pooled_height, const int pooled_width,
|
||||
const int kernel_h, const int kernel_w, const int stride_h,
|
||||
const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int in_stride_n, const int in_stride_c,
|
||||
const int in_stride_h, const int in_stride_w,
|
||||
const int kernel_stride_C, const int kernel_size_C,
|
||||
scalar_t* top_data, int64_t* top_mask) {
|
||||
extern __shared__ int smem[];
|
||||
int *out_mask_cached = smem;
|
||||
scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]);
|
||||
__global__ void max_pool_forward_nhwc(
|
||||
const scalar_t* bottom_data,
|
||||
const int nbatch,
|
||||
const index_t channels, const index_t height, const index_t width,
|
||||
const index_t pooled_height, const index_t pooled_width,
|
||||
const int kernel_h, const int kernel_w, const int stride_h,
|
||||
const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const index_t in_stride_n, const index_t in_stride_c,
|
||||
const index_t in_stride_h, const index_t in_stride_w,
|
||||
const int kernel_stride_C, const int kernel_size_C,
|
||||
scalar_t* top_data, int64_t* top_mask) {
|
||||
|
||||
extern __shared__ unsigned char smem_raw[];
|
||||
index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw);
|
||||
scalar_t *out_cached = reinterpret_cast<scalar_t*>(
|
||||
out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z);
|
||||
|
||||
// flattening cta for pre-computation & smem initialization;
|
||||
int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
|
||||
@ -118,26 +151,26 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
|
||||
int channel_id = blockIdx.x / nbatch;
|
||||
int channel_offset = threadIdx.x + channel_id * blockDim.x;
|
||||
|
||||
top_data = top_data + batch_id * pooled_height * pooled_width * channels;
|
||||
top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
|
||||
bottom_data = bottom_data + batch_id * in_stride_n;
|
||||
top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
|
||||
top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
|
||||
bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n;
|
||||
|
||||
out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
|
||||
out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
|
||||
out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
|
||||
out_mask_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
|
||||
|
||||
int oH = (pooled_height + gridDim.z-1) / gridDim.z;
|
||||
int oW = (pooled_width + gridDim.y-1) / gridDim.y;
|
||||
int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z;
|
||||
int oW = (static_cast<int>(pooled_width) + gridDim.y - 1) / gridDim.y;
|
||||
int ostartH = threadIdx.z + blockIdx.z*oH;
|
||||
int oendH = ::min(ostartH+oH, pooled_height);
|
||||
int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height));
|
||||
int ostartW = threadIdx.y + blockIdx.y*oW;
|
||||
int oendW = ::min(ostartW+oW, pooled_width);
|
||||
int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width));
|
||||
|
||||
for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
|
||||
int hstart = oh * stride_h - pad_h;
|
||||
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h;
|
||||
index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height);
|
||||
for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
|
||||
int wstart = ow * stride_w - pad_w;
|
||||
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w;
|
||||
index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width);
|
||||
while(hstart < 0)
|
||||
hstart += dilation_h;
|
||||
while(wstart < 0)
|
||||
@ -185,12 +218,12 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
|
||||
// Else do it Non-Prefetch...
|
||||
else
|
||||
#endif
|
||||
for (int ih = hstart; ih < hend; ih += dilation_h) {
|
||||
for (int iw = wstart; iw < wend; iw += dilation_w) {
|
||||
for (index_t ih = hstart; ih < hend; ih += dilation_h) {
|
||||
for (index_t iw = wstart; iw < wend; iw += dilation_w) {
|
||||
int cached_index = threadIdx.x;
|
||||
const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
|
||||
for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
|
||||
scalar_t val = ptr_input[c*in_stride_c];
|
||||
for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
|
||||
scalar_t val = ptr_input[c * in_stride_c];
|
||||
if ((val > out_cached[cached_index]) || at::_isnan(val)) {
|
||||
out_cached[cached_index] = val;
|
||||
out_mask_cached[cached_index] = ih * width + iw;
|
||||
@ -200,15 +233,15 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
|
||||
}
|
||||
}
|
||||
|
||||
scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels;
|
||||
int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels;
|
||||
scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
|
||||
int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
|
||||
|
||||
int cached_index = threadIdx.x;
|
||||
for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
|
||||
for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
|
||||
ptr_output_data[c] = out_cached[cached_index];
|
||||
ptr_output_mask[c] = out_mask_cached[cached_index];
|
||||
ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]);
|
||||
out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
|
||||
out_mask_cached[cached_index] = 0;
|
||||
out_mask_cached[cached_index] = index_t(0);
|
||||
cached_index += blockDim.x;
|
||||
}
|
||||
}
|
||||
@ -462,6 +495,11 @@ const Tensor& indices) {
|
||||
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
|
||||
const dim3 block(block_x, block_y, block_z);
|
||||
|
||||
bool use_int32 = can_use_int32_nhwc(
|
||||
nbatch, nInputPlane, inputHeight, inputWidth,
|
||||
outputHeight, outputWidth,
|
||||
in_stride_n, in_stride_c, in_stride_h, in_stride_w);
|
||||
|
||||
int kernel_stride_C = ceil_div(
|
||||
safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
|
||||
int kernel_size_C = ceil_div(
|
||||
@ -476,18 +514,41 @@ const Tensor& indices) {
|
||||
ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD));
|
||||
const dim3 grid(grid_x, grid_y, grid_z);
|
||||
|
||||
size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
|
||||
AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
|
||||
size_t shmem_size;
|
||||
size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z;
|
||||
|
||||
max_pool_forward_nhwc<scalar_t>
|
||||
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
|
||||
input_data, nbatch,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
in_stride_n, in_stride_c,
|
||||
in_stride_h, in_stride_w,
|
||||
kernel_stride_C, kernel_size_C,
|
||||
output_data, indices_data);
|
||||
if (use_int32) {
|
||||
shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t));
|
||||
TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
|
||||
"shared memory too small");
|
||||
max_pool_forward_nhwc<scalar_t, int32_t>
|
||||
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
|
||||
input_data, static_cast<int>(nbatch),
|
||||
static_cast<int32_t>(nInputPlane),
|
||||
static_cast<int32_t>(inputHeight),
|
||||
static_cast<int32_t>(inputWidth),
|
||||
static_cast<int32_t>(outputHeight),
|
||||
static_cast<int32_t>(outputWidth),
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
static_cast<int32_t>(in_stride_n),
|
||||
static_cast<int32_t>(in_stride_c),
|
||||
static_cast<int32_t>(in_stride_h),
|
||||
static_cast<int32_t>(in_stride_w),
|
||||
kernel_stride_C, kernel_size_C,
|
||||
output_data, indices_data);
|
||||
} else {
|
||||
shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t));
|
||||
TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
|
||||
"shared memory too small");
|
||||
max_pool_forward_nhwc<scalar_t, int64_t>
|
||||
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
|
||||
input_data, static_cast<int>(nbatch),
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
in_stride_n, in_stride_c, in_stride_h, in_stride_w,
|
||||
kernel_stride_C, kernel_size_C,
|
||||
output_data, indices_data);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
break;
|
||||
}
|
||||
|
||||
@ -655,8 +655,14 @@ struct ReduceOp {
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
|
||||
// matching Triton, etc.
|
||||
// todo for AMD
|
||||
#ifdef USE_ROCM
|
||||
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
||||
#else
|
||||
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
arg_t other = ops.warp_shfl_down(value[i], offset);
|
||||
|
||||
@ -77,8 +77,8 @@ struct nansum_functor_complex {
|
||||
#if AT_USE_JITERATOR()
|
||||
void operator()(TensorIterator& iter) {
|
||||
std::string func = jiterator_stringify(
|
||||
arg_t combine(arg_t a, scalar_t b) {
|
||||
return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
|
||||
arg_t combine(arg_t a, arg_t b) {
|
||||
return a + (std::isnan(b) ? arg_t{0.} : b);
|
||||
}
|
||||
);
|
||||
jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(
|
||||
|
||||
@ -464,6 +464,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
}
|
||||
#endif
|
||||
int32_t trailingSize;
|
||||
int nDimsLocal = nDims;
|
||||
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
|
||||
if (isInOutAligned) {
|
||||
// in this case we can and should flatten the tensors after the cat dim
|
||||
@ -477,7 +478,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
// and divide all strides except last by elems_per_vec (last stride is 1 always)
|
||||
// for input, we will fix up the sizes and strides in the kernel directly
|
||||
kernelOutputParam = outputParam;
|
||||
nDims = dimension + 1;
|
||||
nDimsLocal = dimension + 1;
|
||||
constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
|
||||
auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
|
||||
kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
|
||||
@ -494,7 +495,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
cat_dim = nDims - cat_dim;
|
||||
cat_dim = nDimsLocal - cat_dim;
|
||||
break;
|
||||
default:
|
||||
cat_dim--;
|
||||
@ -525,7 +526,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
|
||||
}\
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
switch (nDims) {
|
||||
switch (nDimsLocal) {
|
||||
case 1:
|
||||
HANDLE_CASE(1);
|
||||
break;
|
||||
|
||||
@ -21,9 +21,15 @@ namespace {
|
||||
struct offset_t {
|
||||
int stride;
|
||||
int begin;
|
||||
__device__ int operator[](int i) {
|
||||
__device__ int operator[](int i) const {
|
||||
return stride * (begin + i);
|
||||
}
|
||||
#if CCCL_VERSION >= 3001000
|
||||
__device__ offset_t& operator+=(int i) {
|
||||
begin += i;
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
// Segmented sort by full sort algorithm:.
|
||||
// Say we are sorting a (2, 3) tensor. We have in flattened form:
|
||||
|
||||
@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// Helper function to compute output pixel range that can contribute to input pixel
|
||||
template <typename accscalar_t>
|
||||
__device__ __forceinline__ void compute_output_range(
|
||||
int input_pos,
|
||||
accscalar_t scale,
|
||||
int output_size,
|
||||
bool align_corners,
|
||||
int& min_output,
|
||||
int& max_output) {
|
||||
accscalar_t lo, hi;
|
||||
if (align_corners) {
|
||||
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
|
||||
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
|
||||
} else {
|
||||
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
}
|
||||
min_output = max(0, static_cast<int>(ceil(lo)));
|
||||
max_output = min(output_size - 1, static_cast<int>(floor(hi)));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_LAUNCH_BOUNDS_1(1024)
|
||||
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
const bool align_corners,
|
||||
scalar_t* __restrict__ idata,
|
||||
const scalar_t* __restrict__ odata) {
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
#ifdef USE_ROCM
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
// Decode input pixel coordinates
|
||||
size_t index_temp = index;
|
||||
const int w1 = index_temp % width1;
|
||||
index_temp /= width1;
|
||||
const int h1 = index_temp % height1;
|
||||
const size_t nc_idx = index_temp / height1;
|
||||
|
||||
accscalar_t grad_sum = 0;
|
||||
|
||||
// Find range of output pixels that could interpolate from this input pixel
|
||||
int h2_min, h2_max, w2_min, w2_max;
|
||||
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
|
||||
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
|
||||
|
||||
// Iterate over potential output pixels
|
||||
for (int h2 = h2_min; h2 <= h2_max; h2++) {
|
||||
for (int w2 = w2_min; w2 <= w2_max; w2++) {
|
||||
// Compute source coordinates for this output pixel
|
||||
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rheight, h2, align_corners, /*cubic=*/false);
|
||||
const int h1_base = (int)h1r;
|
||||
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
|
||||
const accscalar_t h1lambda = h1r - h1_base;
|
||||
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
|
||||
|
||||
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rwidth, w2, align_corners, /*cubic=*/false);
|
||||
const int w1_base = (int)w1r;
|
||||
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
|
||||
const accscalar_t w1lambda = w1r - w1_base;
|
||||
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
|
||||
|
||||
// Check if our input pixel participates in this interpolation and accumulate all weights
|
||||
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
|
||||
// to the same pixel, so we need to accumulate weights from all matching positions
|
||||
accscalar_t weight = 0;
|
||||
|
||||
// Check all four interpolation positions and accumulate weights
|
||||
if (h1 == h1_base && w1 == w1_base) {
|
||||
weight += h0lambda * w0lambda; // top-left
|
||||
}
|
||||
if (h1 == h1_base && w1 == w1_base + w1p) {
|
||||
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base) {
|
||||
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
|
||||
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
|
||||
}
|
||||
|
||||
if (weight > 0) {
|
||||
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
|
||||
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write accumulated gradient (no atomics needed)
|
||||
idata[index] = static_cast<scalar_t>(grad_sum);
|
||||
}
|
||||
#else
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
size_t index_temp = index;
|
||||
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
// threads are not covering the whole input tensor.
|
||||
grad_input.zero_();
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
const int num_threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
constexpr bool use_input = true;
|
||||
#else
|
||||
constexpr bool use_input = false;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half, at::ScalarType::BFloat16,
|
||||
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
|
||||
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
|
||||
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
|
||||
input_height,
|
||||
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
|
||||
|
||||
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
|
||||
num_threads,
|
||||
|
||||
@ -466,7 +466,11 @@ struct ReduceJitOp {
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef USE_ROCM
|
||||
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
||||
#else
|
||||
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
arg_t other = reducer::warp_shfl_down(value[i], offset);
|
||||
|
||||
@ -160,8 +160,12 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
|
||||
}
|
||||
|
||||
static bool mkldnn_conv_enabled_fpmath_mode_tf32(){
|
||||
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
|
||||
cpuinfo_has_x86_amx_fp16();
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
|
||||
cpuinfo_has_x86_amx_fp16();
|
||||
#else
|
||||
return false; //TF32 not supported on power system
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {
|
||||
|
||||
@ -74,8 +74,12 @@ static bool use_mkldnn_bf32_linear() {
|
||||
}
|
||||
|
||||
static bool use_mkldnn_tf32_linear() {
|
||||
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
|
||||
cpuinfo_has_x86_amx_fp16();
|
||||
#else
|
||||
return false; // TF32 not supported on power system
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor mkldnn_linear(
|
||||
|
||||
@ -114,8 +114,13 @@ static bool use_mkldnn_bf32_matmul() {
|
||||
return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16;
|
||||
}
|
||||
|
||||
|
||||
static bool use_mkldnn_tf32_matmul() {
|
||||
return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
|
||||
#else
|
||||
return false; // TF32 not supported on power system
|
||||
#endif
|
||||
}
|
||||
|
||||
// returns an ideep::tensor
|
||||
|
||||
@ -712,7 +712,7 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) {
|
||||
} else if (scalar.isBoolean()) {
|
||||
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool));
|
||||
} else if (scalar.isComplex()) {
|
||||
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexDouble));
|
||||
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexFloat));
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(scalar.isIntegral(false));
|
||||
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong));
|
||||
|
||||
@ -54,6 +54,10 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
|
||||
using namespace mps;
|
||||
using CachedGraph = MPSBinaryCachedGraph;
|
||||
|
||||
if (self.numel() == 0 & other.numel() == 0) {
|
||||
return zeros({}, self.options());
|
||||
}
|
||||
|
||||
dot_check(self, other);
|
||||
|
||||
auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
|
||||
|
||||
@ -907,6 +907,8 @@ Tensor& index_fill_mps_(Tensor& self, int64_t dim, const Tensor& index, const Te
|
||||
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
|
||||
"index_fill_(): Expected dtype int32 or int64 for index");
|
||||
TORCH_CHECK(dim == 0 || dim < self.dim(), "index_fill_(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
TORCH_CHECK(self.is_complex() || !source.is_complex(),
|
||||
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
|
||||
// MPS.scatter crashes if used with complex dtypes
|
||||
TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "index_fill_(): Complex types are yet not supported");
|
||||
|
||||
|
||||
@ -7183,6 +7183,12 @@
|
||||
CUDA: _scaled_grouped_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_grouped_mm_cuda_v2
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
@ -7378,7 +7384,7 @@
|
||||
- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: sparse_mask
|
||||
SparseCPU, SparseCUDA, SparseMPS: sparse_mask
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed
|
||||
autogen: sparse_mask.out
|
||||
|
||||
|
||||
@ -184,15 +184,23 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
|
||||
0 & \text{ else }
|
||||
\end{cases}
|
||||
*/
|
||||
float scale_val = scale[0].item<float>();
|
||||
float inv_scale_val = 1.0f / scale_val;
|
||||
int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false);
|
||||
|
||||
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size");
|
||||
bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
|
||||
|
||||
at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
|
||||
at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
|
||||
at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
|
||||
at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;
|
||||
|
||||
float scale_val = scale_[0].item<float>();
|
||||
float inv_scale_val = 1.0f / scale_val;
|
||||
int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point_, quant_min, quant_max, false);
|
||||
|
||||
TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(X_.numel() == dY_.numel(), "`X` and `dY` are not the same size");
|
||||
TORCH_CHECK(
|
||||
quant_min <= 0 && quant_max >= 0,
|
||||
"`quant_min` should be less than or \
|
||||
@ -200,28 +208,28 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
|
||||
TORCH_CHECK(
|
||||
zero_point_val >= quant_min && zero_point_val <= quant_max,
|
||||
"`zero_point` must be between `quant_min` and `quant_max`.");
|
||||
if (X.numel() <= 0) {
|
||||
if (X_.numel() <= 0) {
|
||||
return std::make_tuple(X, scale, zero_point);
|
||||
}
|
||||
|
||||
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
|
||||
auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
|
||||
auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
|
||||
|
||||
auto iter = TensorIteratorConfig()
|
||||
.add_output(dX)
|
||||
.add_output(dScale_vec)
|
||||
.add_output(dZeroPoint_vec)
|
||||
.add_input(X)
|
||||
.add_input(dY)
|
||||
.add_input(X_)
|
||||
.add_input(dY_)
|
||||
.build();
|
||||
|
||||
fake_quant_grad_learnable_tensor_stub(
|
||||
X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);
|
||||
X_.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);
|
||||
|
||||
// The total sums over the scale and zero point gradient vectors are what will be returned in the end.
|
||||
auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device());
|
||||
auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device());
|
||||
auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device());
|
||||
auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device());
|
||||
|
||||
return std::make_tuple(dX, dScale, dZeroPoint);
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/native/SparseTensorUtils.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/sparse/SparseStubs.h>
|
||||
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -13,6 +15,8 @@
|
||||
#include <ATen/ops/mul_native.h>
|
||||
#include <ATen/ops/empty_native.h>
|
||||
#include <ATen/ops/zeros_native.h>
|
||||
#include <ATen/ops/ones_like.h>
|
||||
#include <ATen/ops/argsort.h>
|
||||
#include <ATen/ops/result_type.h>
|
||||
#include <ATen/ops/copy_sparse_to_sparse.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
@ -436,4 +440,137 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
|
||||
return out;
|
||||
}
|
||||
|
||||
using OptTensor = std::optional<Tensor>;
|
||||
|
||||
|
||||
static void sparse_mask_apply_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& src_in,
|
||||
const Tensor& mask_in,
|
||||
bool accumulate_matches,
|
||||
bool require_same_sizes,
|
||||
bool coalesce_mask) {
|
||||
TORCH_CHECK(src_in.is_sparse() && mask_in.is_sparse(),
|
||||
"sparse_mask: expected both inputs to be sparse COO");
|
||||
TORCH_CHECK(src_in.is_mps() && mask_in.is_mps(),
|
||||
"sparse_mask: expected tensors to be on MPS device");
|
||||
TORCH_CHECK(src_in.sparse_dim() == mask_in.sparse_dim(),
|
||||
"sparse_mask: sparse_dim mismatch: ", src_in.sparse_dim(), " vs ", mask_in.sparse_dim());
|
||||
if (require_same_sizes) {
|
||||
TORCH_CHECK(src_in.sizes().equals(mask_in.sizes()),
|
||||
"sparse_mask: sizes must match exactly (no broadcasting)");
|
||||
}
|
||||
auto src = src_in.coalesce();
|
||||
auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
|
||||
|
||||
const int64_t src_nnz = src._nnz();
|
||||
const int64_t mask_nnz = mask._nnz();
|
||||
const int64_t sd = src.sparse_dim();
|
||||
result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
|
||||
|
||||
auto commonDtype = at::result_type(src, mask);
|
||||
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
|
||||
"Can't convert result type ", commonDtype, " to output ", result.scalar_type());
|
||||
|
||||
if (mask_nnz == 0) {
|
||||
alias_into_sparse(
|
||||
result,
|
||||
mask._indices().narrow(1, 0, 0),
|
||||
at::empty({0}, result.options().dtype(result.scalar_type())));
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
return;
|
||||
}
|
||||
|
||||
TORCH_CHECK(sd > 0 || (src_nnz <= 1 && mask_nnz <= 1),
|
||||
"sparse_mask: invalid sparse_dim or nnz");
|
||||
|
||||
if (sd == 0) {
|
||||
auto out_indices = mask._indices().narrow(1, 0, 1);
|
||||
auto out_values = src_nnz
|
||||
? src._values().narrow(0, 0, 1).to(commonDtype)
|
||||
: at::zeros({1}, at::device(result.device()).dtype(commonDtype));
|
||||
alias_into_sparse(result, out_indices, out_values);
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
return;
|
||||
}
|
||||
|
||||
if (src_nnz == 0) {
|
||||
auto out_indices = mask._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype);
|
||||
auto out_val_sizes = src_values.sizes().vec();
|
||||
out_val_sizes[0] = mask_nnz;
|
||||
auto out_values = at::zeros(out_val_sizes, src_values.options());
|
||||
alias_into_sparse(result, out_indices, out_values);
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_indices = mask._indices().contiguous();
|
||||
auto src_indices = src._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype).contiguous();
|
||||
|
||||
auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
|
||||
auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
const bool A_is_src = (src_nnz <= mask_nnz);
|
||||
const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
|
||||
const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
|
||||
auto A_keys = A_is_src ? src_keys : mask_keys;
|
||||
auto B_keys = A_is_src ? mask_keys : src_keys;
|
||||
|
||||
const auto device = result.device();
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
|
||||
auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
|
||||
auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), A_is_src);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
|
||||
|
||||
auto out_val_sizes = src_values.sizes().vec();
|
||||
out_val_sizes[0] = mask_nnz;
|
||||
auto out_values = at::zeros(out_val_sizes, src_values.options());
|
||||
|
||||
if (M > 0) {
|
||||
auto src_match = outA_idx.narrow(0, 0, M);
|
||||
auto mask_match = outB_idx.narrow(0, 0, M);
|
||||
|
||||
auto src_rows = src_values.index_select(0, src_match);
|
||||
if (accumulate_matches) {
|
||||
out_values.index_add_(0, mask_match, src_rows);
|
||||
} else {
|
||||
out_values.index_copy_(0, mask_match, src_rows);
|
||||
}
|
||||
}
|
||||
|
||||
alias_into_sparse(result, mask_indices, out_values);
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_intersection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
const Tensor& rhs,
|
||||
const OptTensor& = std::nullopt) {
|
||||
sparse_mask_apply_out_mps_kernel(
|
||||
result,
|
||||
/*src_in=*/lhs,
|
||||
/*mask_in=*/rhs,
|
||||
/*accumulate_matches=*/false,
|
||||
/*require_same_sizes=*/false,
|
||||
/*coalesce_mask=*/false);
|
||||
}
|
||||
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
|
||||
} // namespace at::native
|
||||
@ -3,6 +3,9 @@
|
||||
using namespace metal;
|
||||
|
||||
|
||||
template <typename T> struct MulAccum { using type = float; };
|
||||
template <> struct MulAccum<float2> { using type = float2; };
|
||||
|
||||
template <typename T>
|
||||
kernel void dense_sparse_mul_kernel(
|
||||
device const T* dense [[buffer(0)]],
|
||||
@ -29,8 +32,9 @@ kernel void dense_sparse_mul_kernel(
|
||||
ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
|
||||
ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
|
||||
|
||||
const auto a = static_cast<float>(values[val_idx]);
|
||||
const auto b = static_cast<float>(dense[dense_idx]);
|
||||
using accum_t = typename MulAccum<T>::type;
|
||||
const accum_t a = static_cast<accum_t>(values[val_idx]);
|
||||
const accum_t b = static_cast<accum_t>(dense[dense_idx]);
|
||||
out_values[val_idx] = static_cast<T>(a * b);
|
||||
}
|
||||
|
||||
@ -130,6 +134,8 @@ kernel void fused_gather_mul_kernel(
|
||||
INSTANTIATE_DENSE_SPARSE_MUL(float);
|
||||
INSTANTIATE_DENSE_SPARSE_MUL(half);
|
||||
INSTANTIATE_DENSE_SPARSE_MUL(bfloat);
|
||||
INSTANTIATE_DENSE_SPARSE_MUL(long);
|
||||
INSTANTIATE_DENSE_SPARSE_MUL(float2);
|
||||
|
||||
#define INSTANTIATE_FUSED_GATHER_MUL(DTYPE) \
|
||||
template [[host_name("fused_gather_mul_kernel_" #DTYPE)]] kernel void \
|
||||
|
||||
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
||||
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
||||
|
@ -13,20 +13,22 @@ constexpr size_t kRoundUpPowerOfTwoEnd = 64 * 1024ul * kMB; // 64GB
|
||||
|
||||
AcceleratorAllocatorConfig& AcceleratorAllocatorConfig::instance() {
|
||||
static AcceleratorAllocatorConfig instance;
|
||||
#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env, deprecated) \
|
||||
auto env##_name = c10::utils::get_env(#env); \
|
||||
if (env##_name.has_value()) { \
|
||||
if (deprecated) { \
|
||||
TORCH_WARN_ONCE(#env " is deprecated, use PYTORCH_ALLOC_CONF instead"); \
|
||||
} \
|
||||
instance.parseArgs(env##_name.value()); \
|
||||
return true; \
|
||||
#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env) \
|
||||
auto env##_name = c10::utils::get_env(#env); \
|
||||
if (env##_name.has_value()) { \
|
||||
instance.parseArgs(env##_name.value()); \
|
||||
return true; \
|
||||
}
|
||||
static bool env_flag [[maybe_unused]] = []() {
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF, false)
|
||||
// Keep this for backwards compatibility
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF, /*deprecated=*/true)
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF, /*deprecated=*/true)
|
||||
// Parse allocator configuration from environment variables.
|
||||
// The first two entries are kept for backward compatibility with legacy
|
||||
// CUDA and HIP environment variable names. The new unified variable
|
||||
// (PYTORCH_ALLOC_CONF) should be used going forward.
|
||||
// Note: keep the parsing order and logic stable to avoid potential
|
||||
// performance regressions in internal tests.
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF)
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF)
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF)
|
||||
return false;
|
||||
}();
|
||||
#undef C10_ALLOCATOR_CONFIG_PARSE_ENV
|
||||
@ -127,8 +129,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
|
||||
std::fill(
|
||||
std::next(
|
||||
roundup_power2_divisions_.begin(),
|
||||
static_cast<std::vector<size_t>::difference_type>(
|
||||
last_index + 1)),
|
||||
static_cast<std::vector<size_t>::difference_type>(last_index)),
|
||||
roundup_power2_divisions_.end(),
|
||||
value);
|
||||
} else {
|
||||
|
||||
@ -28,101 +28,8 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// [dtype Macros note] For the macros below:
|
||||
//
|
||||
// For users: If you want to macro some code for all non-QInt scalar types
|
||||
// (i.e. types with complete information, you probably want one of the
|
||||
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
|
||||
// designed to behave similarly to the Dispatch macros with the same name.
|
||||
//
|
||||
// For adding a new dtype: In the beginning, we had an idea that there was a
|
||||
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
|
||||
// iterate over them. But over the years we added weird types which couldn't
|
||||
// be handled uniformly everywhere and so in the end we ended up with some
|
||||
// mish-mosh of some helper macros, but mostly use sites making a call about
|
||||
// what dtypes they can or can't support. So if you want to add a new dtype,
|
||||
// the preferred resolution is to find a dtype similar to what you want,
|
||||
// grep for it and edit all the sites you find this way. If you need to add
|
||||
// a completely new kind of dtype, you're going to have to laboriously audit
|
||||
// all of the sites everywhere to figure out how it should work. Consulting
|
||||
// some old PRs where we added new dtypes (check history of this file) can
|
||||
// help give you an idea where to start.
|
||||
|
||||
// If you want to support ComplexHalf for real, add ComplexHalf
|
||||
// into this macro (and change the name). But beware: convert()
|
||||
// doesn't work for all the conversions you need...
|
||||
//
|
||||
// TODO: To add unsigned int types here, we must define accumulate type.
|
||||
// But uint8 currently accumulates into int64, so we would have to make
|
||||
// an inconsistent choice for the larger types. Difficult.
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(at::Half, Half) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn)
|
||||
|
||||
// This macro controls many of our C++ APIs, including constructors
|
||||
// for Scalar as well as the data() and item() accessors on Tensor
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(at::Half, Half) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::complex<c10::Half>, ComplexHalf) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
namespace impl {
|
||||
|
||||
// These are used to map ScalarTypes to C++ types.
|
||||
|
||||
template <c10::ScalarType N>
|
||||
struct ScalarTypeToCPPType;
|
||||
|
||||
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
|
||||
template <> \
|
||||
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
|
||||
using type = cpp_type; \
|
||||
\
|
||||
/* This is a workaround for the CUDA bug which prevents */ \
|
||||
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
|
||||
/* ambiguous reference which can't to be resolved. For some reason it */ \
|
||||
/* can't pick between at::detail and at::cuda::detail. */ \
|
||||
/* For repro example, please see: */ \
|
||||
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
|
||||
/* TODO: remove once the bug is fixed. */ \
|
||||
static type t; \
|
||||
};
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
|
||||
|
||||
#undef SPECIALIZE_ScalarTypeToCPPType
|
||||
|
||||
template <c10::ScalarType N>
|
||||
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
|
||||
|
||||
} // namespace impl
|
||||
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
|
||||
// regarding macros.
|
||||
|
||||
template <typename T>
|
||||
struct CppTypeToScalarType;
|
||||
@ -138,130 +45,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
|
||||
#undef SPECIALIZE_CppTypeToScalarType
|
||||
|
||||
// NB: despite its generic sounding name, the macros that don't take _AND
|
||||
// are mostly only used by tensorexpr
|
||||
#define AT_FORALL_INT_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double)
|
||||
|
||||
// These macros are often controlling how many template instantiations we
|
||||
// create for kernels. It is typically inappropriate to add new dtypes here,
|
||||
// instead, new types should be added to use sites on a case-by-case basis.
|
||||
// We generally are not accepting new dtypes due to binary size concerns.
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE>::t), \
|
||||
SCALARTYPE)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND7( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
SCALARTYPE7, \
|
||||
_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE4>::t), \
|
||||
SCALARTYPE4) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE5>::t), \
|
||||
SCALARTYPE5) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE6>::t), \
|
||||
SCALARTYPE6) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE7>::t), \
|
||||
SCALARTYPE7)
|
||||
|
||||
#define AT_FORALL_QINT_TYPES(_) \
|
||||
_(c10::qint8, QInt8) \
|
||||
_(c10::quint8, QUInt8) \
|
||||
_(c10::qint32, QInt32) \
|
||||
_(c10::quint4x2, QUInt4x2) \
|
||||
_(c10::quint2x4, QUInt2x4)
|
||||
|
||||
#define AT_FORALL_FLOAT8_TYPES(_) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
#define AT_FORALL_COMPLEX_TYPES(_) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble)
|
||||
|
||||
#define DEFINE_CONSTANT(_, name) \
|
||||
constexpr ScalarType k##name = ScalarType::name;
|
||||
|
||||
@ -269,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
|
||||
#undef DEFINE_CONSTANT
|
||||
|
||||
inline const char* toString(ScalarType t) {
|
||||
#define DEFINE_CASE(_, name) \
|
||||
case ScalarType::name: \
|
||||
return #name;
|
||||
|
||||
switch (t) {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
|
||||
default:
|
||||
return "UNKNOWN_SCALAR";
|
||||
}
|
||||
#undef DEFINE_CASE
|
||||
}
|
||||
|
||||
inline size_t elementSize(ScalarType t) {
|
||||
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
|
||||
case ScalarType::name: \
|
||||
@ -525,12 +295,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
|
||||
|
||||
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
at::ScalarType scalar_type) {
|
||||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
// Returns a pair of strings representing the names for each dtype.
|
||||
// The returned pair is (name, legacy_name_if_applicable)
|
||||
C10_API std::pair<std::string, std::string> getDtypeNames(
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#include <c10/cuda/CUDAAllocatorConfig.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/util/llvmMathExtras.h>
|
||||
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
@ -8,386 +7,119 @@
|
||||
|
||||
namespace c10::cuda::CUDACachingAllocator {
|
||||
|
||||
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
|
||||
|
||||
CUDAAllocatorConfig::CUDAAllocatorConfig()
|
||||
: m_max_split_size(std::numeric_limits<size_t>::max()),
|
||||
m_max_non_split_rounding_size(kLargeBuffer),
|
||||
m_garbage_collection_threshold(0),
|
||||
m_pinned_num_register_threads(1),
|
||||
m_pinned_reserve_segment_size_mb(0),
|
||||
m_expandable_segments(false),
|
||||
#if CUDA_VERSION >= 12030
|
||||
m_expandable_segments_handle_type(
|
||||
Expandable_Segments_Handle_Type::UNSPECIFIED),
|
||||
#else
|
||||
m_expandable_segments_handle_type(
|
||||
Expandable_Segments_Handle_Type::POSIX_FD),
|
||||
#endif
|
||||
m_release_lock_on_cudamalloc(false),
|
||||
m_pinned_use_cuda_host_register(false),
|
||||
m_graph_capture_record_stream_reuse(false),
|
||||
m_pinned_use_background_threads(false) {
|
||||
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
|
||||
size_t log_size = (63 - llvm::countLeadingZeros(size));
|
||||
|
||||
// Our intervals start at 1MB and end at 64GB
|
||||
const size_t interval_start =
|
||||
63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
|
||||
const size_t interval_end =
|
||||
63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
|
||||
TORCH_CHECK(
|
||||
(interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
|
||||
"kRoundUpPowerOfTwoIntervals mismatch");
|
||||
|
||||
int index = static_cast<int>(log_size) - static_cast<int>(interval_start);
|
||||
|
||||
index = std::max(0, index);
|
||||
index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
|
||||
return instance().m_roundup_power2_divisions[index];
|
||||
}
|
||||
|
||||
void CUDAAllocatorConfig::lexArgs(
|
||||
const std::string& env,
|
||||
std::vector<std::string>& config) {
|
||||
std::vector<char> buf;
|
||||
|
||||
for (char ch : env) {
|
||||
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
|
||||
if (!buf.empty()) {
|
||||
config.emplace_back(buf.begin(), buf.end());
|
||||
buf.clear();
|
||||
}
|
||||
config.emplace_back(1, ch);
|
||||
} else if (ch != ' ') {
|
||||
buf.emplace_back(ch);
|
||||
}
|
||||
}
|
||||
if (!buf.empty()) {
|
||||
config.emplace_back(buf.begin(), buf.end());
|
||||
}
|
||||
}
|
||||
|
||||
void CUDAAllocatorConfig::consumeToken(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i,
|
||||
const char c) {
|
||||
TORCH_CHECK(
|
||||
i < config.size() && config[i] == std::string(1, c),
|
||||
"Error parsing CachingAllocator settings, expected ",
|
||||
c,
|
||||
"");
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseMaxSplitSize(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
constexpr int mb = 1024 * 1024;
|
||||
if (++i < config.size()) {
|
||||
size_t val1 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
val1 > kLargeBuffer / mb,
|
||||
"CachingAllocator option max_split_size_mb too small, must be > ",
|
||||
kLargeBuffer / mb,
|
||||
"");
|
||||
val1 = std::max(val1, kLargeBuffer / mb);
|
||||
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
|
||||
m_max_split_size = val1 * 1024 * 1024;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
constexpr int mb = 1024 * 1024;
|
||||
if (++i < config.size()) {
|
||||
size_t val1 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
val1 > kLargeBuffer / mb,
|
||||
"CachingAllocator option max_non_split_rounding_mb too small, must be > ",
|
||||
kLargeBuffer / mb,
|
||||
"");
|
||||
val1 = std::max(val1, kLargeBuffer / mb);
|
||||
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
|
||||
m_max_non_split_rounding_size = val1 * 1024 * 1024;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
double val1 = stod(config[i]);
|
||||
TORCH_CHECK(
|
||||
val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
|
||||
TORCH_CHECK(
|
||||
val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
|
||||
m_garbage_collection_threshold = val1;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting garbage_collection_threshold value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
bool first_value = true;
|
||||
|
||||
if (++i < config.size()) {
|
||||
if (std::string_view(config[i]) == "[") {
|
||||
size_t last_index = 0;
|
||||
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
|
||||
while (++i < config.size() && std::string_view(config[i]) != "]") {
|
||||
const std::string& val1 = config[i];
|
||||
size_t val2 = 0;
|
||||
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
val2 = stoi(config[i]);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error parsing roundup_power2_divisions value", "");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
val2 == 0 || llvm::isPowerOf2_64(val2),
|
||||
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ",
|
||||
"");
|
||||
|
||||
if (std::string_view(val1) == ">") {
|
||||
std::fill(
|
||||
std::next(
|
||||
m_roundup_power2_divisions.begin(),
|
||||
static_cast<std::vector<unsigned long>::difference_type>(
|
||||
last_index)),
|
||||
m_roundup_power2_divisions.end(),
|
||||
val2);
|
||||
} else {
|
||||
size_t val1_long = stoul(val1);
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(val1_long),
|
||||
"For roundups, the intervals have to be power of 2 ",
|
||||
"");
|
||||
|
||||
size_t index = 63 - llvm::countLeadingZeros(val1_long);
|
||||
index = std::max((size_t)0, index);
|
||||
index = std::min(index, m_roundup_power2_divisions.size() - 1);
|
||||
|
||||
if (first_value) {
|
||||
std::fill(
|
||||
m_roundup_power2_divisions.begin(),
|
||||
std::next(
|
||||
m_roundup_power2_divisions.begin(),
|
||||
static_cast<std::vector<unsigned long>::difference_type>(
|
||||
index)),
|
||||
val2);
|
||||
first_value = false;
|
||||
}
|
||||
if (index < m_roundup_power2_divisions.size()) {
|
||||
m_roundup_power2_divisions[index] = val2;
|
||||
}
|
||||
last_index = index;
|
||||
}
|
||||
|
||||
if (std::string_view(config[i + 1]) != "]") {
|
||||
consumeToken(config, ++i, ',');
|
||||
}
|
||||
}
|
||||
} else { // Keep this for backwards compatibility
|
||||
size_t val1 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(val1),
|
||||
"For roundups, the divisions has to be power of 2 ",
|
||||
"");
|
||||
std::fill(
|
||||
m_roundup_power2_divisions.begin(),
|
||||
m_roundup_power2_divisions.end(),
|
||||
val1);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseAllocatorConfig(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i,
|
||||
bool& used_cudaMallocAsync) {
|
||||
// For ease of maintenance and understanding, the CUDA and ROCm
|
||||
// implementations of this function are separated. This avoids having many
|
||||
// #ifdef's throughout.
|
||||
#ifdef USE_ROCM
|
||||
// Ease burden on ROCm users by allowing either cuda or hip tokens.
|
||||
// cuda token is broken up to prevent hipify matching it.
|
||||
#define PYTORCH_TOKEN1 \
|
||||
"cud" \
|
||||
"aMallocAsync"
|
||||
#define PYTORCH_TOKEN2 "hipMallocAsync"
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
|
||||
(config[i] == PYTORCH_TOKEN2)),
|
||||
"Unknown allocator backend, "
|
||||
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
|
||||
used_cudaMallocAsync =
|
||||
(config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
config[i] == get()->name() ||
|
||||
(config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
|
||||
"Allocator backend parsed at runtime != "
|
||||
"allocator backend parsed at load time, ",
|
||||
config[i],
|
||||
" != ",
|
||||
get()->name());
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error parsing backend value", "");
|
||||
}
|
||||
return i;
|
||||
#undef PYTORCH_TOKEN1
|
||||
#undef PYTORCH_TOKEN2
|
||||
tokenizer.checkToken(++i, ":");
|
||||
i++; // Move to the value after the colon
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK(
|
||||
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
|
||||
(tokenizer[i] == PYTORCH_TOKEN2)),
|
||||
"Unknown allocator backend, "
|
||||
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
|
||||
used_cudaMallocAsync =
|
||||
(tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
tokenizer[i] == get()->name() ||
|
||||
(tokenizer[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
|
||||
"Allocator backend parsed at runtime != "
|
||||
"allocator backend parsed at load time, ",
|
||||
tokenizer[i],
|
||||
" != ",
|
||||
get()->name());
|
||||
#else // USE_ROCM
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
((config[i] == "native") || (config[i] == "cudaMallocAsync")),
|
||||
"Unknown allocator backend, "
|
||||
"options are native and cudaMallocAsync");
|
||||
used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
|
||||
if (used_cudaMallocAsync) {
|
||||
TORCH_CHECK(
|
||||
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1)),
|
||||
"Unknown allocator backend, "
|
||||
"options are native and " PYTORCH_TOKEN1);
|
||||
used_cudaMallocAsync = (tokenizer[i] == PYTORCH_TOKEN1);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
tokenizer[i] == get()->name(),
|
||||
"Allocator backend parsed at runtime != "
|
||||
"allocator backend parsed at load time, ",
|
||||
tokenizer[i],
|
||||
" != ",
|
||||
get()->name());
|
||||
if (used_cudaMallocAsync) {
|
||||
#if CUDA_VERSION >= 11040
|
||||
int version = 0;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||
TORCH_CHECK(
|
||||
version >= 11040,
|
||||
"backend:cudaMallocAsync requires CUDA runtime "
|
||||
"11.4 or newer, but cudaDriverGetVersion returned ",
|
||||
version);
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"backend:cudaMallocAsync requires PyTorch to be built with "
|
||||
"CUDA 11.4 or newer, but CUDA_VERSION is ",
|
||||
CUDA_VERSION);
|
||||
#endif
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
config[i] == get()->name(),
|
||||
"Allocator backend parsed at runtime != "
|
||||
"allocator backend parsed at load time");
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error parsing backend value", "");
|
||||
int version = 0;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||
TORCH_CHECK(
|
||||
version >= 11040,
|
||||
"backend:cudaMallocAsync requires CUDA runtime "
|
||||
"11.4 or newer, but cudaDriverGetVersion returned ",
|
||||
version);
|
||||
#else // CUDA_VERSION >= 11040
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"backend:cudaMallocAsync requires PyTorch to be built with "
|
||||
"CUDA 11.4 or newer, but CUDA_VERSION is ",
|
||||
CUDA_VERSION);
|
||||
#endif // CUDA_VERSION >= 11040
|
||||
}
|
||||
return i;
|
||||
#endif // USE_ROCM
|
||||
return i;
|
||||
}
|
||||
|
||||
void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
|
||||
// If empty, set the default values
|
||||
m_max_split_size = std::numeric_limits<size_t>::max();
|
||||
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
|
||||
m_garbage_collection_threshold = 0;
|
||||
void CUDAAllocatorConfig::parseArgs(const std::string& env) {
|
||||
bool used_cudaMallocAsync = false;
|
||||
bool used_native_specific_option = false;
|
||||
|
||||
if (!env.has_value()) {
|
||||
return;
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
|
||||
m_last_allocator_settings = env.value();
|
||||
}
|
||||
|
||||
std::vector<std::string> config;
|
||||
lexArgs(env.value(), config);
|
||||
|
||||
for (size_t i = 0; i < config.size(); i++) {
|
||||
std::string_view config_item_view(config[i]);
|
||||
if (config_item_view == "max_split_size_mb") {
|
||||
i = parseMaxSplitSize(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "max_non_split_rounding_mb") {
|
||||
i = parseMaxNonSplitRoundingSize(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "garbage_collection_threshold") {
|
||||
i = parseGarbageCollectionThreshold(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "roundup_power2_divisions") {
|
||||
i = parseRoundUpPower2Divisions(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "backend") {
|
||||
i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
|
||||
} else if (config_item_view == "expandable_segments") {
|
||||
used_native_specific_option = true;
|
||||
consumeToken(config, ++i, ':');
|
||||
++i;
|
||||
TORCH_CHECK(
|
||||
i < config.size() &&
|
||||
(std::string_view(config[i]) == "True" ||
|
||||
std::string_view(config[i]) == "False"),
|
||||
"Expected a single True/False argument for expandable_segments");
|
||||
config_item_view = config[i];
|
||||
m_expandable_segments = (config_item_view == "True");
|
||||
c10::CachingAllocator::ConfigTokenizer tokenizer(env);
|
||||
for (size_t i = 0; i < tokenizer.size(); i++) {
|
||||
const auto& key = tokenizer[i];
|
||||
if (key == "backend") {
|
||||
i = parseAllocatorConfig(tokenizer, i, used_cudaMallocAsync);
|
||||
} else if (
|
||||
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
|
||||
// use, accept both. We must break up the string to prevent hipify here.
|
||||
config_item_view == "release_lock_on_hipmalloc" ||
|
||||
config_item_view ==
|
||||
key == "release_lock_on_hipmalloc" ||
|
||||
key ==
|
||||
"release_lock_on_c"
|
||||
"udamalloc") {
|
||||
used_native_specific_option = true;
|
||||
consumeToken(config, ++i, ':');
|
||||
++i;
|
||||
TORCH_CHECK(
|
||||
i < config.size() &&
|
||||
(std::string_view(config[i]) == "True" ||
|
||||
std::string_view(config[i]) == "False"),
|
||||
"Expected a single True/False argument for release_lock_on_cudamalloc");
|
||||
config_item_view = config[i];
|
||||
m_release_lock_on_cudamalloc = (config_item_view == "True");
|
||||
tokenizer.checkToken(++i, ":");
|
||||
m_release_lock_on_cudamalloc = tokenizer.toBool(++i);
|
||||
} else if (
|
||||
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
|
||||
// use, accept both. We must break up the string to prevent hipify here.
|
||||
config_item_view == "pinned_use_hip_host_register" ||
|
||||
config_item_view ==
|
||||
key == "pinned_use_hip_host_register" ||
|
||||
key ==
|
||||
"pinned_use_c"
|
||||
"uda_host_register") {
|
||||
i = parsePinnedUseCudaHostRegister(config, i);
|
||||
i = parsePinnedUseCudaHostRegister(tokenizer, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "pinned_num_register_threads") {
|
||||
i = parsePinnedNumRegisterThreads(config, i);
|
||||
} else if (key == "pinned_num_register_threads") {
|
||||
i = parsePinnedNumRegisterThreads(tokenizer, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "pinned_reserve_segment_size_mb") {
|
||||
i = parsePinnedReserveSegmentSize(config, i);
|
||||
} else if (key == "pinned_reserve_segment_size_mb") {
|
||||
i = parsePinnedReserveSegmentSize(tokenizer, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "pinned_use_background_threads") {
|
||||
i = parsePinnedUseBackgroundThreads(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "graph_capture_record_stream_reuse") {
|
||||
i = parseGraphCaptureRecordStreamReuse(config, i);
|
||||
} else if (key == "graph_capture_record_stream_reuse") {
|
||||
i = parseGraphCaptureRecordStreamReuse(tokenizer, i);
|
||||
used_native_specific_option = true;
|
||||
} else {
|
||||
const auto& keys =
|
||||
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
|
||||
TORCH_CHECK(
|
||||
false, "Unrecognized CachingAllocator option: ", config_item_view);
|
||||
keys.find(key) != keys.end(),
|
||||
"Unrecognized key '",
|
||||
key,
|
||||
"' in CUDA allocator config.");
|
||||
// Skip the key and its value
|
||||
i = tokenizer.skipKey(i);
|
||||
}
|
||||
|
||||
if (i + 1 < config.size()) {
|
||||
consumeToken(config, ++i, ',');
|
||||
if (i + 1 < tokenizer.size()) {
|
||||
tokenizer.checkToken(++i, ",");
|
||||
}
|
||||
}
|
||||
|
||||
@ -399,97 +131,51 @@ void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
(config[i] == "True" || config[i] == "False"),
|
||||
"Expected a single True/False argument for pinned_use_cuda_host_register");
|
||||
m_pinned_use_cuda_host_register = (config[i] == "True");
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting pinned_use_cuda_host_register value", "");
|
||||
}
|
||||
tokenizer.checkToken(++i, ":");
|
||||
m_pinned_use_cuda_host_register = tokenizer.toBool(++i);
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
(config[i] == "True" || config[i] == "False"),
|
||||
"Expected a single True/False argument for graph_capture_record_stream_reuse");
|
||||
m_graph_capture_record_stream_reuse = (config[i] == "True");
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting graph_capture_record_stream_reuse value", "");
|
||||
}
|
||||
|
||||
tokenizer.checkToken(++i, ":");
|
||||
m_graph_capture_record_stream_reuse = tokenizer.toBool(++i);
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
size_t val2 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(val2),
|
||||
"Number of register threads has to be power of 2 ",
|
||||
"");
|
||||
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
|
||||
TORCH_CHECK(
|
||||
val2 <= maxThreads,
|
||||
"Number of register threads should be less than or equal to " +
|
||||
std::to_string(maxThreads),
|
||||
"");
|
||||
m_pinned_num_register_threads = val2;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting pinned_num_register_threads value", "");
|
||||
}
|
||||
tokenizer.checkToken(++i, ":");
|
||||
size_t val2 = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(val2),
|
||||
"Number of register threads has to be power of 2, got ",
|
||||
val2);
|
||||
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
|
||||
TORCH_CHECK(
|
||||
val2 <= maxThreads,
|
||||
"Number of register threads should be less than or equal to ",
|
||||
maxThreads,
|
||||
", got ",
|
||||
val2);
|
||||
m_pinned_num_register_threads = val2;
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parsePinnedReserveSegmentSize(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
size_t val2 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
val2 > 0, "Pinned reserve segment size has to be greater than 0 ", "");
|
||||
m_pinned_reserve_segment_size_mb = val2;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting pinned_reserve_segment_size_mb value", "");
|
||||
}
|
||||
tokenizer.checkToken(++i, ":");
|
||||
size_t val2 = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK(val2 > 0, "Pinned reserve segment size has to be greater than 0");
|
||||
m_pinned_reserve_segment_size_mb = val2;
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
(config[i] == "True" || config[i] == "False"),
|
||||
"Expected a single True/False argument for pinned_use_background_threads");
|
||||
m_pinned_use_background_threads = (config[i] == "True");
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting pinned_use_background_threads value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
// General caching allocator utilities
|
||||
void setAllocatorSettings(const std::string& env) {
|
||||
CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
|
||||
}
|
||||
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig)
|
||||
|
||||
} // namespace c10::cuda::CUDACachingAllocator
|
||||
|
||||
@ -1,16 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/AllocatorConfig.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/cuda/CUDAMacros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/env.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace c10::cuda::CUDACachingAllocator {
|
||||
|
||||
enum class Expandable_Segments_Handle_Type : int {
|
||||
@ -23,20 +18,23 @@ enum class Expandable_Segments_Handle_Type : int {
|
||||
class C10_CUDA_API CUDAAllocatorConfig {
|
||||
public:
|
||||
static size_t max_split_size() {
|
||||
return instance().m_max_split_size;
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
|
||||
}
|
||||
static double garbage_collection_threshold() {
|
||||
return instance().m_garbage_collection_threshold;
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
garbage_collection_threshold();
|
||||
}
|
||||
|
||||
static bool expandable_segments() {
|
||||
bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
use_expandable_segments();
|
||||
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
|
||||
if (instance().m_expandable_segments) {
|
||||
if (enabled) {
|
||||
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
|
||||
}
|
||||
return false;
|
||||
#else
|
||||
return instance().m_expandable_segments;
|
||||
return enabled;
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -67,7 +65,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
}
|
||||
|
||||
static bool pinned_use_background_threads() {
|
||||
return instance().m_pinned_use_background_threads;
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
pinned_use_background_threads();
|
||||
}
|
||||
|
||||
static size_t pinned_reserve_segment_size_mb() {
|
||||
@ -81,24 +80,23 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
return 128;
|
||||
}
|
||||
|
||||
// This is used to round-up allocation size to nearest power of 2 divisions.
|
||||
// More description below in function roundup_power2_next_division
|
||||
// As an example, if we want 4 divisions between 2's power, this can be done
|
||||
// using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
|
||||
static size_t roundup_power2_divisions(size_t size);
|
||||
static size_t roundup_power2_divisions(size_t size) {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
roundup_power2_divisions(size);
|
||||
}
|
||||
|
||||
static std::vector<size_t> roundup_power2_divisions() {
|
||||
return instance().m_roundup_power2_divisions;
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
roundup_power2_divisions();
|
||||
}
|
||||
|
||||
static size_t max_non_split_rounding_size() {
|
||||
return instance().m_max_non_split_rounding_size;
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
max_non_split_rounding_size();
|
||||
}
|
||||
|
||||
static std::string last_allocator_settings() {
|
||||
std::lock_guard<std::mutex> lock(
|
||||
instance().m_last_allocator_settings_mutex);
|
||||
return instance().m_last_allocator_settings;
|
||||
return c10::CachingAllocator::getAllocatorSettings();
|
||||
}
|
||||
|
||||
static CUDAAllocatorConfig& instance() {
|
||||
@ -111,70 +109,75 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
|
||||
}
|
||||
#endif
|
||||
inst->parseArgs(env);
|
||||
// Note: keep the parsing order and logic stable to avoid potential
|
||||
// performance regressions in internal tests.
|
||||
if (!env.has_value()) {
|
||||
env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
|
||||
}
|
||||
if (env.has_value()) {
|
||||
inst->parseArgs(env.value());
|
||||
}
|
||||
return inst;
|
||||
})();
|
||||
return *s_instance;
|
||||
}
|
||||
|
||||
void parseArgs(const std::optional<std::string>& env);
|
||||
// Use `Construct On First Use Idiom` to avoid `Static Initialization Order`
|
||||
// issue.
|
||||
static const std::unordered_set<std::string>& getKeys() {
|
||||
static std::unordered_set<std::string> keys{
|
||||
"backend",
|
||||
// keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
|
||||
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
|
||||
"release_lock_on_cud"
|
||||
"amalloc",
|
||||
"pinned_use_cud"
|
||||
"a_host_register",
|
||||
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
|
||||
"release_lock_on_hipmalloc",
|
||||
"pinned_use_hip_host_register",
|
||||
"graph_capture_record_stream_reuse",
|
||||
"pinned_reserve_segment_size_mb",
|
||||
"pinned_num_register_threads"};
|
||||
return keys;
|
||||
}
|
||||
|
||||
void parseArgs(const std::string& env);
|
||||
|
||||
private:
|
||||
CUDAAllocatorConfig();
|
||||
CUDAAllocatorConfig() = default;
|
||||
|
||||
static void lexArgs(const std::string& env, std::vector<std::string>& config);
|
||||
static void consumeToken(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i,
|
||||
const char c);
|
||||
size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
|
||||
size_t parseMaxNonSplitRoundingSize(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parseGarbageCollectionThreshold(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parseRoundUpPower2Divisions(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parseAllocatorConfig(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i,
|
||||
bool& used_cudaMallocAsync);
|
||||
size_t parsePinnedUseCudaHostRegister(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
size_t parsePinnedNumRegisterThreads(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
size_t parsePinnedReserveSegmentSize(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parsePinnedUseBackgroundThreads(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
size_t parseGraphCaptureRecordStreamReuse(
|
||||
const std::vector<std::string>& config,
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
|
||||
std::atomic<size_t> m_max_split_size;
|
||||
std::atomic<size_t> m_max_non_split_rounding_size;
|
||||
std::vector<size_t> m_roundup_power2_divisions;
|
||||
std::atomic<double> m_garbage_collection_threshold;
|
||||
std::atomic<size_t> m_pinned_num_register_threads;
|
||||
std::atomic<size_t> m_pinned_reserve_segment_size_mb;
|
||||
std::atomic<bool> m_expandable_segments;
|
||||
std::atomic<Expandable_Segments_Handle_Type>
|
||||
m_expandable_segments_handle_type;
|
||||
std::atomic<bool> m_release_lock_on_cudamalloc;
|
||||
std::atomic<bool> m_pinned_use_cuda_host_register;
|
||||
std::atomic<bool> m_graph_capture_record_stream_reuse;
|
||||
std::atomic<bool> m_pinned_use_background_threads;
|
||||
std::string m_last_allocator_settings;
|
||||
std::mutex m_last_allocator_settings_mutex;
|
||||
std::atomic<size_t> m_pinned_num_register_threads{1};
|
||||
std::atomic<size_t> m_pinned_reserve_segment_size_mb{0};
|
||||
std::atomic<Expandable_Segments_Handle_Type> m_expandable_segments_handle_type
|
||||
#if CUDA_VERSION >= 12030
|
||||
{Expandable_Segments_Handle_Type::UNSPECIFIED};
|
||||
#else
|
||||
{Expandable_Segments_Handle_Type::POSIX_FD};
|
||||
#endif
|
||||
std::atomic<bool> m_release_lock_on_cudamalloc{false};
|
||||
std::atomic<bool> m_pinned_use_cuda_host_register{false};
|
||||
std::atomic<bool> m_graph_capture_record_stream_reuse{false};
|
||||
};
|
||||
|
||||
// General caching allocator utilities
|
||||
C10_CUDA_API void setAllocatorSettings(const std::string& env);
|
||||
// Keep this for backwards compatibility
|
||||
using c10::CachingAllocator::setAllocatorSettings;
|
||||
|
||||
} // namespace c10::cuda::CUDACachingAllocator
|
||||
|
||||
@ -64,10 +64,6 @@ namespace cuda::CUDACachingAllocator {
|
||||
using namespace c10::CachingAllocator;
|
||||
using namespace c10::CachingDeviceAllocator;
|
||||
|
||||
// Included here as this is externally used in CUDAAllocatorConfig
|
||||
const size_t kLargeBuffer =
|
||||
20971520; // "large" allocations may be packed in 20 MiB blocks
|
||||
|
||||
namespace Native {
|
||||
|
||||
//
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/AllocatorConfig.h>
|
||||
#include <c10/core/CachingDeviceAllocator.h>
|
||||
#include <c10/cuda/CUDAGraphsC10Utils.h>
|
||||
#include <c10/cuda/CUDAMacros.h>
|
||||
@ -49,10 +50,9 @@ namespace c10::cuda::CUDACachingAllocator {
|
||||
|
||||
// Preserved only for BC reasons
|
||||
// NOLINTNEXTLINE(misc-unused-using-decls)
|
||||
using c10::CachingAllocator::kLargeBuffer;
|
||||
using c10::CachingDeviceAllocator::DeviceStats;
|
||||
|
||||
extern const size_t kLargeBuffer;
|
||||
|
||||
typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();
|
||||
|
||||
// Struct containing info of an allocation block (i.e. a fractional part of a
|
||||
|
||||
@ -67,8 +67,8 @@ TEST(AllocatorConfigTest, allocator_config_test) {
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 2);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 2);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4);
|
||||
// EXPECT_EQ(
|
||||
// AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 1);
|
||||
EXPECT_EQ(
|
||||
@ -101,8 +101,8 @@ TEST(AllocatorConfigTest, allocator_config_test) {
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 1);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 0);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8);
|
||||
// EXPECT_EQ(
|
||||
// AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 2);
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ struct default_constructible
|
||||
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>*)
|
||||
constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>* /*unused*/)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@ -76,7 +76,7 @@ class type : public modifier<M, type<T, Tag, M...>>...
|
||||
{
|
||||
public:
|
||||
template <typename TT = T, typename = std::enable_if_t<std::is_trivially_constructible<TT>{}>>
|
||||
explicit type(uninitialized_t)
|
||||
explicit type(uninitialized_t /*unused*/)
|
||||
noexcept
|
||||
{
|
||||
}
|
||||
@ -138,7 +138,7 @@ private:
|
||||
|
||||
namespace impl {
|
||||
template <typename T, typename Tag, typename ... Ms>
|
||||
constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>*) { return true;}
|
||||
constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>* /*unused*/) { return true;}
|
||||
constexpr bool is_strong_type_func(...) { return false;}
|
||||
template <typename T, typename Tag, typename ... Ms>
|
||||
constexpr T underlying_type(strong::type<T, Tag, Ms...>*);
|
||||
|
||||
@ -20,8 +20,6 @@ constexpr size_t kMinBlockSize = 512;
|
||||
constexpr size_t kSmallSize = 1048576;
|
||||
// "small" allocations are packed in 2 MiB blocks
|
||||
constexpr size_t kSmallBuffer = 2097152;
|
||||
// "large" allocations may be packed in 20 MiB blocks
|
||||
constexpr size_t kLargeBuffer = 20971520;
|
||||
// allocations between 1 and 10 MiB may use kLargeBuffer
|
||||
constexpr size_t kMinLargeAlloc = 10485760;
|
||||
// round up large allocations to 2 MiB
|
||||
@ -435,6 +433,18 @@ class DeviceCachingAllocator {
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device);
|
||||
auto device_total = device_prop.global_mem_size;
|
||||
// Estimate the available device memory when the SYCL runtime does not
|
||||
// support the corresponding aspect (ext_intel_free_memory).
|
||||
size_t device_free = device_prop.global_mem_size -
|
||||
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
|
||||
.current;
|
||||
auto& raw_device = c10::xpu::get_raw_device(device);
|
||||
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
|
||||
// affected devices.
|
||||
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
|
||||
device_free =
|
||||
raw_device.get_info<sycl::ext::intel::info::device::free_memory>();
|
||||
}
|
||||
auto allocated_bytes =
|
||||
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
|
||||
.current;
|
||||
@ -457,7 +467,9 @@ class DeviceCachingAllocator {
|
||||
static_cast<int>(device),
|
||||
" has a total capacity of ",
|
||||
format_size(device_total),
|
||||
". Of the allocated memory ",
|
||||
" of which ",
|
||||
format_size(device_free),
|
||||
" is free. Of the allocated memory ",
|
||||
format_size(allocated_bytes),
|
||||
" is allocated by PyTorch, and ",
|
||||
format_size(reserved_bytes - allocated_bytes),
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/AllocatorConfig.h>
|
||||
#include <c10/core/CachingDeviceAllocator.h>
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
|
||||
|
||||
3
cmake/External/aotriton.cmake
vendored
3
cmake/External/aotriton.cmake
vendored
@ -244,7 +244,8 @@ if(NOT __AOTRITON_INCLUDED)
|
||||
else()
|
||||
set(__AOTRITON_SYSTEM_ROCM "${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}")
|
||||
list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_SYSTEM_ROCM}" __AOTRITON_RUNTIME_INDEX)
|
||||
if(${__AOTRITON_RUNTIME_INDEX} LESS 0)
|
||||
# Always build aotriton runtime from source on Windows due to lack of pre-built binaries
|
||||
if(${__AOTRITON_RUNTIME_INDEX} LESS 0 OR WIN32)
|
||||
message(STATUS "Cannot find AOTriton runtime for ROCM ${__AOTRITON_SYSTEM_ROCM}. \
|
||||
Build runtime from source")
|
||||
aotriton_build_from_source(ON aotriton_runtime)
|
||||
|
||||
@ -68,14 +68,6 @@
|
||||
.. autofunction:: get_validators
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: write_file_on_exit
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: write_file
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: read_file
|
||||
```
|
||||
@ -95,3 +87,7 @@
|
||||
```{eval-rst}
|
||||
.. autofunction:: get_rotating_buffer_size
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: set_numerical_check_tolerances
|
||||
```
|
||||
@ -23,6 +23,7 @@ Submodules
|
||||
flex_attention
|
||||
bias
|
||||
experimental
|
||||
varlen
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
@ -30,3 +31,4 @@ Submodules
|
||||
nn.attention.flex_attention
|
||||
nn.attention.bias
|
||||
nn.attention.experimental
|
||||
nn.attention.varlen
|
||||
|
||||
17
docs/source/nn.attention.varlen.md
Normal file
17
docs/source/nn.attention.varlen.md
Normal file
@ -0,0 +1,17 @@
|
||||
```{eval-rst}
|
||||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
```
|
||||
|
||||
# torch.nn.attention.varlen
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.nn.attention.varlen
|
||||
.. currentmodule:: torch.nn.attention.varlen
|
||||
```
|
||||
```{eval-rst}
|
||||
.. autofunction:: varlen_attn
|
||||
```
|
||||
```{eval-rst}
|
||||
.. autoclass:: AuxRequest
|
||||
```
|
||||
@ -228,3 +228,4 @@ Low-Precision functions
|
||||
ScalingType
|
||||
SwizzleType
|
||||
scaled_mm
|
||||
scaled_grouped_mm
|
||||
|
||||
@ -1,14 +1,12 @@
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.compiler.config
|
||||
|
||||
```
|
||||
|
||||
# torch.compiler.config
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.compiler.config
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autodata:: torch.compiler.config.job_id
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
@ -816,6 +816,10 @@ Operator Tags
|
||||
.. py:module:: torch.types
|
||||
.. py:module:: torch.version
|
||||
|
||||
.. Compiler configuration module - documented in torch.compiler.config.md
|
||||
.. py:module:: torch.compiler.config
|
||||
:noindex:
|
||||
|
||||
.. Hidden aliases (e.g. torch.functional.broadcast_tensors()). We want `torch.broadcast_tensors()` to
|
||||
be visible only.
|
||||
.. toctree::
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
# A Pyrefly configuration for PyTorch
|
||||
# Based on https://github.com/pytorch/pytorch/blob/main/mypy.ini
|
||||
python-version = "3.12"
|
||||
|
||||
project-includes = [
|
||||
"torch",
|
||||
"caffe2",
|
||||
@ -36,6 +38,7 @@ project-excludes = [
|
||||
"torch/nn/modules/rnn.py", # only remove when parsing errors are fixed
|
||||
"torch/_inductor/codecache.py",
|
||||
"torch/distributed/elastic/metrics/__init__.py",
|
||||
"torch/_inductor/fx_passes/bucketing.py",
|
||||
# ====
|
||||
"benchmarks/instruction_counts/main.py",
|
||||
"benchmarks/instruction_counts/definitions/setup.py",
|
||||
|
||||
@ -10,6 +10,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
|
||||
)
|
||||
|
||||
76
test/cpp/aoti_abi_check/test_scalartype.cpp
Normal file
76
test/cpp/aoti_abi_check/test_scalartype.cpp
Normal file
@ -0,0 +1,76 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
TEST(TestScalarType, ScalarTypeToCPPTypeT) {
|
||||
using torch::headeronly::ScalarType;
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT;
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
EXPECT_EQ(typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE));
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
{ \
|
||||
EXPECT_EQ( \
|
||||
typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE)); \
|
||||
count++; \
|
||||
}
|
||||
|
||||
#define TEST_FORALL(M, EXPECTEDCOUNT, ...) \
|
||||
TEST(TestScalarType, M) { \
|
||||
using torch::headeronly::ScalarType; \
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
|
||||
int8_t count = 0; \
|
||||
M(__VA_ARGS__ DEFINE_CHECK); \
|
||||
EXPECT_EQ(count, EXPECTEDCOUNT); \
|
||||
}
|
||||
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ, 14)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, 18)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, 46)
|
||||
TEST_FORALL(AT_FORALL_INT_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES, 7)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND, 8, Bool, )
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND2, 9, Bool, Half, )
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND3, 10, Bool, Half, ComplexFloat, )
|
||||
TEST_FORALL(
|
||||
AT_FORALL_SCALAR_TYPES_AND7,
|
||||
14,
|
||||
Bool,
|
||||
Half,
|
||||
ComplexHalf,
|
||||
ComplexFloat,
|
||||
ComplexDouble,
|
||||
UInt16,
|
||||
UInt32, )
|
||||
TEST_FORALL(AT_FORALL_QINT_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_FLOAT8_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)
|
||||
|
||||
#undef DEFINE_CHECK
|
||||
#undef TEST_FORALL
|
||||
|
||||
TEST(TestScalarType, toString) {
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
TEST(TestScalarType, operator_left_shift) {
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
#define DEFINE_CHECK(_, name) \
|
||||
{ \
|
||||
std::stringstream ss; \
|
||||
ss << ScalarType::name; \
|
||||
EXPECT_EQ(ss.str(), #name); \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
@ -559,7 +559,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
FAIL = 138
|
||||
pc = start_processes(
|
||||
name="echo",
|
||||
entrypoint=bin("echo1.py"),
|
||||
entrypoint=bin("echo4.py"),
|
||||
args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
|
||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||
logs_specs=DefaultLogsSpecs(
|
||||
|
||||
@ -9,7 +9,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -24,6 +23,5 @@ if __name__ == "__main__":
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
time.sleep(1000)
|
||||
print(f"{args.msg} stdout from {rank}")
|
||||
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
|
||||
|
||||
29
test/distributed/elastic/multiprocessing/bin/echo4.py
Executable file
29
test/distributed/elastic/multiprocessing/bin/echo4.py
Executable file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="test binary, exits with exitcode")
|
||||
parser.add_argument("--exitcode", type=int, default=0)
|
||||
parser.add_argument("msg", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
rank = int(os.environ["RANK"])
|
||||
exitcode = args.exitcode
|
||||
if exitcode != 0:
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
time.sleep(1000)
|
||||
print(f"{args.msg} stdout from {rank}")
|
||||
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
|
||||
@ -536,6 +536,23 @@ class TestScheduleLowering(TestCase):
|
||||
"compute": ["0F0", "0F1", " ", "0B0", "0B1"],
|
||||
"comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"],
|
||||
},
|
||||
{
|
||||
"compute": ["0F0", "0F1", "1F0", "1F1", "1B0", "1B1", "0B0", "0B1"],
|
||||
"comms": [
|
||||
"0UNSHARD",
|
||||
"1UNSHARD",
|
||||
"0F0",
|
||||
"0F1",
|
||||
"1F0",
|
||||
"1F1",
|
||||
"1B0",
|
||||
"1B1",
|
||||
"1RESHARD",
|
||||
"0B0",
|
||||
"0B1",
|
||||
"0RESHARD",
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_unshard_reshard(self, test_info):
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
import functools
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
from typing import Callable, ClassVar, Optional, Union
|
||||
from typing import Any, Callable, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -32,6 +31,7 @@ from torch.distributed.tensor.experimental._load_balancer import (
|
||||
_HeadTailLoadBalancer,
|
||||
_LoadBalancer,
|
||||
_PerDocumentHeadTailLoadBalancer,
|
||||
_PTRRLoadBalancer,
|
||||
)
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
@ -39,6 +39,7 @@ from torch.nn.attention.flex_attention import (
|
||||
_mask_mod_signature,
|
||||
AuxOutput,
|
||||
AuxRequest,
|
||||
BlockMask,
|
||||
create_block_mask,
|
||||
flex_attention,
|
||||
)
|
||||
@ -391,9 +392,7 @@ def generate_random_lengths_in_chunks(
|
||||
return [num_chunks * chunk_size for num_chunks in num_chunks_per_document]
|
||||
|
||||
|
||||
def length_to_offsets(
|
||||
lengths: list[list[int]], device: Union[str, torch.device]
|
||||
) -> Tensor:
|
||||
def length_to_offsets(lengths: list[list[int]], device: str | torch.device) -> Tensor:
|
||||
"""Converts a list of lengths to a list of offsets.
|
||||
|
||||
Args:
|
||||
@ -475,8 +474,9 @@ class CPFlexAttentionTest(DTensorTestBase):
|
||||
*,
|
||||
qkv_size: int,
|
||||
B: int = 1,
|
||||
mask_func: _mask_mod_signature = causal_mask,
|
||||
lb: Optional[_LoadBalancer] = None,
|
||||
block_mask,
|
||||
lb_type: str,
|
||||
document_lengths: Optional[list[list[int]]] = None,
|
||||
) -> None:
|
||||
torch.use_deterministic_algorithms(True)
|
||||
torch.cuda.manual_seed(1234)
|
||||
@ -486,6 +486,14 @@ class CPFlexAttentionTest(DTensorTestBase):
|
||||
dim = 32
|
||||
nheads = 8
|
||||
seq_dim = 2
|
||||
lb = self._get_load_balancer(
|
||||
lb_type,
|
||||
{
|
||||
"seq_length": qkv_size,
|
||||
"document_lengths": document_lengths,
|
||||
"block_mask": block_mask,
|
||||
},
|
||||
)
|
||||
|
||||
qkv = [
|
||||
torch.rand(
|
||||
@ -497,15 +505,6 @@ class CPFlexAttentionTest(DTensorTestBase):
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
block_mask = compiled_create_block_mask(
|
||||
mask_func,
|
||||
B=B,
|
||||
H=1,
|
||||
Q_LEN=qkv_size,
|
||||
KV_LEN=qkv_size,
|
||||
device=self.device_type,
|
||||
)
|
||||
|
||||
expect_out, expect_aux = compiled_flex_attention(
|
||||
*qkv, block_mask=block_mask, return_aux=AuxRequest(lse=True)
|
||||
)
|
||||
@ -547,6 +546,8 @@ class CPFlexAttentionTest(DTensorTestBase):
|
||||
# backward run
|
||||
cp_out.sum().backward()
|
||||
|
||||
atol = 2e-06
|
||||
rtol = 1e-05
|
||||
# unshard the output
|
||||
cp_out, cp_lse = context_parallel_unshard(
|
||||
device_mesh,
|
||||
@ -554,8 +555,8 @@ class CPFlexAttentionTest(DTensorTestBase):
|
||||
seq_dims=[seq_dim] * 2,
|
||||
load_balancer=lb,
|
||||
)
|
||||
torch.testing.assert_close(cp_out, expect_out)
|
||||
torch.testing.assert_close(cp_lse, expect_aux.lse)
|
||||
torch.testing.assert_close(cp_out, expect_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(cp_lse, expect_aux.lse, atol=atol, rtol=rtol)
|
||||
|
||||
# unshard the gradient
|
||||
cp_qkv_grad = context_parallel_unshard(
|
||||
@ -567,7 +568,38 @@ class CPFlexAttentionTest(DTensorTestBase):
|
||||
|
||||
qkv_grad = [t.grad for t in qkv]
|
||||
for grad, cp_grad in zip(qkv_grad, cp_qkv_grad):
|
||||
torch.testing.assert_close(grad, cp_grad)
|
||||
torch.testing.assert_close(grad, cp_grad, atol=atol, rtol=rtol)
|
||||
|
||||
def _get_load_balancer(
|
||||
self, lb_type: str, kwargs: dict[str, Any]
|
||||
) -> Optional[_LoadBalancer]:
|
||||
seq_length = kwargs["seq_length"]
|
||||
document_lengths = kwargs["document_lengths"]
|
||||
block_mask = kwargs["block_mask"]
|
||||
|
||||
# generate load balancer
|
||||
if lb_type == "None":
|
||||
load_balancer = None # no load-balance
|
||||
elif lb_type == "_HeadTailLoadBalancer":
|
||||
assert isinstance(seq_length, int)
|
||||
load_balancer = _HeadTailLoadBalancer(
|
||||
seq_length, self.world_size, torch.device(self.device_type)
|
||||
)
|
||||
elif lb_type == "_PerDocumentHeadTailLoadBalancer":
|
||||
assert isinstance(document_lengths, list)
|
||||
load_balancer = _PerDocumentHeadTailLoadBalancer(
|
||||
document_lengths, self.world_size, torch.device(self.device_type)
|
||||
)
|
||||
elif lb_type == "_PTRRLoadBalancer":
|
||||
assert isinstance(block_mask, BlockMask)
|
||||
load_balancer = _PTRRLoadBalancer(
|
||||
block_mask,
|
||||
self.world_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"load_balancer type {lb_type} is not supported!")
|
||||
|
||||
return load_balancer
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_comms
|
||||
@ -575,33 +607,65 @@ class CPFlexAttentionTest(DTensorTestBase):
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
|
||||
)
|
||||
def test_cp_flex_attention_causal_mask(self) -> None:
|
||||
restore_enable_load_balance = _cp_options.enable_load_balance
|
||||
seq_length_list = [256 * self.world_size, 2048]
|
||||
load_balance_type_list = [
|
||||
"None",
|
||||
"_HeadTailLoadBalancer",
|
||||
"_PTRRLoadBalancer",
|
||||
]
|
||||
|
||||
for enable_load_balance in [
|
||||
False, # test w/o load-balancing
|
||||
True, # test w/ the default load-balancing
|
||||
]:
|
||||
_cp_options.enable_load_balance = enable_load_balance
|
||||
self.run_subtests(
|
||||
{
|
||||
"qkv_size": [
|
||||
(256 if enable_load_balance else 128) * self.world_size,
|
||||
2048,
|
||||
],
|
||||
},
|
||||
self._test_cp_flex_attention,
|
||||
# NOTE: Each (seq_len, load_balance_type) tuple introduces 2
|
||||
# create_block_mask compilations: 1 for single-rank flex_attention and 1 for
|
||||
# CP flex_attention. In order to avoid the "exceeds_recompile_limit" error,
|
||||
# we need to increase the cache_size_limit to 2 * num_of_sub_test_runs which
|
||||
# will be the total number of compilations in our test case.
|
||||
torch._dynamo.config.cache_size_limit = (len(seq_length_list) + 1) * (
|
||||
1 + len(load_balance_type_list)
|
||||
)
|
||||
|
||||
for qkv_size, lb_type in itertools.product(
|
||||
seq_length_list, load_balance_type_list
|
||||
):
|
||||
block_mask = compiled_create_block_mask(
|
||||
causal_mask,
|
||||
B=1,
|
||||
H=1,
|
||||
Q_LEN=qkv_size,
|
||||
KV_LEN=qkv_size,
|
||||
device=self.device_type,
|
||||
)
|
||||
self._test_cp_flex_attention(
|
||||
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
|
||||
)
|
||||
|
||||
_cp_options.enable_load_balance = restore_enable_load_balance
|
||||
|
||||
# NOTE: Context Parallel should not be used for small attentions (block_size < 128)
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError, "Q_LEN 128 is not divisible by CP mesh world size"
|
||||
):
|
||||
self.run_subtests(
|
||||
{"qkv_size": [64 * self.world_size]},
|
||||
self._test_cp_flex_attention,
|
||||
)
|
||||
qkv_size = 64 * self.world_size
|
||||
block_mask = compiled_create_block_mask(
|
||||
causal_mask,
|
||||
B=1,
|
||||
H=1,
|
||||
Q_LEN=qkv_size,
|
||||
KV_LEN=qkv_size,
|
||||
device=self.device_type,
|
||||
)
|
||||
|
||||
for lb_type in ["None", "_HeadTailLoadBalancer"]:
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
f"Q_LEN {qkv_size} is not divisible",
|
||||
):
|
||||
self._test_cp_flex_attention(
|
||||
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
|
||||
)
|
||||
|
||||
for lb_type in ["_PTRRLoadBalancer"]:
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"must be divisible by group_size",
|
||||
):
|
||||
self._test_cp_flex_attention(
|
||||
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
|
||||
)
|
||||
|
||||
# TODO: merge with the above test
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@ -610,77 +674,71 @@ class CPFlexAttentionTest(DTensorTestBase):
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
|
||||
)
|
||||
def test_cp_flex_attention_document_mask(self) -> None:
|
||||
restore_enable_load_balance = _cp_options.enable_load_balance
|
||||
|
||||
random.seed(10)
|
||||
|
||||
# parameters for testing
|
||||
doc_count = 28
|
||||
enable_load_balance_list = [True, False]
|
||||
batch_size_list = [2, 4, 8]
|
||||
max_seq_len_list = [
|
||||
256 * self.world_size,
|
||||
2048,
|
||||
# 128 * self.world_size # NOTE: Mismatched elements: 8 / 131072 (0.0%),
|
||||
]
|
||||
load_balance_type = [
|
||||
"None",
|
||||
"_HeadTailLoadBalancer",
|
||||
"_PerDocumentHeadTailLoadBalancer",
|
||||
"_PTRRLoadBalancer",
|
||||
]
|
||||
|
||||
# NOTE: Each (enable_load_balance, batch_size, seq_len) tuple introduces 2
|
||||
# NOTE: Each (batch_size, seq_len, load_balance_type) tuple introduces 2
|
||||
# create_block_mask compilations: 1 for single-rank flex_attention and 1 for
|
||||
# CP flex_attention. In order to avoid the "exceeds_recompile_limit" error,
|
||||
# we need to increase the cache_size_limit to 12 which is the total number
|
||||
# of compilations in our test case.
|
||||
# we need to increase the cache_size_limit to 2 * num_of_sub_test_runs which
|
||||
# will be the total number of compilations in our test case.
|
||||
torch._dynamo.config.cache_size_limit = (
|
||||
2
|
||||
* len(enable_load_balance_list)
|
||||
* len(batch_size_list)
|
||||
* len(max_seq_len_list)
|
||||
2 * len(batch_size_list) * len(max_seq_len_list) * len(load_balance_type)
|
||||
)
|
||||
|
||||
# TODO: change this for-loop to run_subtests
|
||||
# Use a for-loop instead of run_subtests because we need to intialize the mask
|
||||
# for each subtest. This can be baked into self._test_cp_flex_attention as
|
||||
# a str argument denoting mask type.
|
||||
for enable_load_balance, batch_size, max_seq_len in itertools.product(
|
||||
enable_load_balance_list, batch_size_list, max_seq_len_list
|
||||
for batch_size, max_seq_len, lb_type in itertools.product(
|
||||
batch_size_list,
|
||||
max_seq_len_list,
|
||||
load_balance_type,
|
||||
):
|
||||
_cp_options.enable_load_balance = enable_load_balance
|
||||
|
||||
# initialize document mask
|
||||
lengths = [
|
||||
(
|
||||
generate_random_lengths_in_chunks(
|
||||
max_seq_len, doc_count, chunk_size=2 * self.world_size
|
||||
)
|
||||
if enable_load_balance
|
||||
if lb_type == "_PerDocumentHeadTailLoadBalancer"
|
||||
else generate_random_lengths(max_seq_len, doc_count)
|
||||
)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
offsets = length_to_offsets(lengths, self.device_type)
|
||||
document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
|
||||
|
||||
# generate load balancer
|
||||
load_balancer = (
|
||||
_PerDocumentHeadTailLoadBalancer(
|
||||
lengths, self.world_size, torch.device(self.device_type)
|
||||
)
|
||||
if enable_load_balance
|
||||
else None
|
||||
block_mask = compiled_create_block_mask(
|
||||
document_causal_mask,
|
||||
B=batch_size,
|
||||
H=1,
|
||||
Q_LEN=max_seq_len,
|
||||
KV_LEN=max_seq_len,
|
||||
device=self.device_type,
|
||||
)
|
||||
|
||||
# construct testing function
|
||||
test_func = functools.partial(
|
||||
self._test_cp_flex_attention,
|
||||
self._test_cp_flex_attention(
|
||||
qkv_size=max_seq_len,
|
||||
B=batch_size,
|
||||
lb=load_balancer,
|
||||
mask_func=document_causal_mask,
|
||||
lb_type=lb_type,
|
||||
block_mask=block_mask,
|
||||
document_lengths=lengths,
|
||||
)
|
||||
|
||||
test_func()
|
||||
|
||||
_cp_options.enable_load_balance = restore_enable_load_balance
|
||||
|
||||
|
||||
class TestCPCustomOps(DTensorTestBase):
|
||||
@property
|
||||
|
||||
@ -266,6 +266,41 @@ def forward(self, b_parametrizations_buffer_original0, x):
|
||||
compiled_out = compiled_fn(mesh)
|
||||
self.assertEqual(opt_fn, compiled_out)
|
||||
|
||||
def test_get_local_rank_compile(self):
|
||||
mesh = init_device_mesh(
|
||||
self.device_type, (self.world_size,), mesh_dim_names=("dp",)
|
||||
)
|
||||
|
||||
def fn_with_str_arg(x):
|
||||
local_rank = x.device_mesh.get_local_rank("dp")
|
||||
return x * local_rank
|
||||
|
||||
x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
|
||||
ref = fn_with_str_arg(x)
|
||||
|
||||
opt_fn = torch.compile(fn_with_str_arg, backend="aot_eager", fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
def fn_with_int_arg(x):
|
||||
local_rank = x.device_mesh.get_local_rank(0)
|
||||
return x * local_rank
|
||||
|
||||
ref2 = fn_with_int_arg(x)
|
||||
opt_fn2 = torch.compile(fn_with_int_arg, backend="aot_eager", fullgraph=True)
|
||||
res2 = opt_fn2(x)
|
||||
self.assertEqual(res2, ref2)
|
||||
|
||||
def fn_without_arg(x):
|
||||
# will fail if device_mesh.ndim > 1
|
||||
local_rank = x.device_mesh.get_local_rank()
|
||||
return x + local_rank
|
||||
|
||||
ref3 = fn_without_arg(x)
|
||||
opt_fn3 = torch.compile(fn_without_arg, backend="aot_eager", fullgraph=True)
|
||||
res3 = opt_fn3(x)
|
||||
self.assertEqual(res3, ref3)
|
||||
|
||||
def test_fakify_dtensor(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
||||
|
||||
@ -239,9 +239,7 @@ class DTensorExportTest(TestCase):
|
||||
"view_9",
|
||||
"t_15",
|
||||
"detach",
|
||||
"detach_1",
|
||||
"detach_6",
|
||||
"detach_7",
|
||||
"detach_3",
|
||||
"threshold_backward_1",
|
||||
"t_16",
|
||||
"mm_6",
|
||||
@ -259,10 +257,8 @@ class DTensorExportTest(TestCase):
|
||||
"sum_1",
|
||||
"view_7",
|
||||
"t_7",
|
||||
"detach_1",
|
||||
"detach_2",
|
||||
"detach_3",
|
||||
"detach_4",
|
||||
"detach_5",
|
||||
"threshold_backward",
|
||||
"mm_2",
|
||||
"t_9",
|
||||
|
||||
@ -20,6 +20,7 @@ from torch.distributed.tensor._collective_utils import shard_dim_alltoall
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
from torch.distributed.tensor._redistribute import redistribute_local_tensor
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
from torch.distributed.tensor.placement_types import _StridedShard
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -1145,6 +1146,22 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
sharded_dt, mesh, tgt_placement, shard_order=None
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_shard_order_same_data_as_strided_shard(self):
|
||||
device_mesh = init_device_mesh(self.device_type, (4, 2))
|
||||
x = torch.randn(8, 4, device=self.device_type)
|
||||
# specify right-to-left order use _StridedShard
|
||||
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
|
||||
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
|
||||
# specify right-to-left order use ordered shard
|
||||
x_ordered_dt = self.distribute_tensor(
|
||||
x,
|
||||
device_mesh,
|
||||
placements=[Shard(0), Shard(0)],
|
||||
shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),),
|
||||
)
|
||||
self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -70,6 +70,8 @@ def get_patches():
|
||||
"force_disable_caches": True,
|
||||
# Messes up existing test strings
|
||||
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
|
||||
# interferes with testing, / custom estimation
|
||||
"test_configs.assume_bucketing_reduces_latency": False,
|
||||
}
|
||||
|
||||
|
||||
@ -364,6 +366,8 @@ def get_bucket_patches(compute_multiplier=1.0):
|
||||
"force_disable_caches": True,
|
||||
# messes up test strings
|
||||
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
|
||||
# interferes with testing, / custom estimation
|
||||
"test_configs.assume_bucketing_reduces_latency": False,
|
||||
}
|
||||
|
||||
|
||||
@ -579,7 +583,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches(2.0))
|
||||
def test_bucketing_split_for_overlap_blocking(self):
|
||||
def test_bucketing_split_for_overlap_blocking_no_deps(self):
|
||||
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
|
||||
|
||||
def func(a, b, c, d, *, ranks):
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
import os
|
||||
import unittest
|
||||
from datetime import timedelta
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -40,6 +41,13 @@ from torch.utils._typing_utils import not_none
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
device_count = torch.accelerator.device_count()
|
||||
|
||||
try:
|
||||
import torch._C._distributed_c10d.ProcessGroupNCCL
|
||||
|
||||
_NCCL_AVAILABLE = True
|
||||
except ImportError:
|
||||
_NCCL_AVAILABLE = False
|
||||
|
||||
|
||||
def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
|
||||
os.environ["MASTER_ADDR"] = addr
|
||||
@ -962,6 +970,85 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
||||
# check flattened mesh dependency
|
||||
self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_4d)
|
||||
|
||||
@with_comms
|
||||
def test_unflatten_mesh_2d(self):
|
||||
mesh_shape = (4, 2)
|
||||
mesh_dim_names = ("dp", "tp")
|
||||
mesh_2d = init_device_mesh(
|
||||
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
unflatten_mesh = mesh_2d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
|
||||
self.assertEqual(
|
||||
unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"]
|
||||
)
|
||||
self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh)
|
||||
self.assertEqual(mesh_2d["tp"].get_group(), unflatten_mesh["tp"].get_group())
|
||||
|
||||
# Not supporting slicing out unflatten dim name from root mesh.
|
||||
with self.assertRaises(KeyError):
|
||||
self.assertEqual(mesh_2d["dp_shard"].mesh, unflatten_mesh["dp_shard"].mesh)
|
||||
|
||||
@with_comms
|
||||
def test_unflatten_mesh_3d(self):
|
||||
# Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism(EP).
|
||||
global_mesh = init_device_mesh(
|
||||
self.device_type,
|
||||
(8,),
|
||||
mesh_dim_names=("world",),
|
||||
)
|
||||
non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
|
||||
ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
|
||||
self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
|
||||
self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
|
||||
mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp"))
|
||||
unflatten_mesh = mesh_3d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
|
||||
self.assertEqual(
|
||||
unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "cp", "tp"]
|
||||
)
|
||||
self.assertEqual(mesh_3d["tp"].mesh, unflatten_mesh["tp"].mesh)
|
||||
self.assertEqual(mesh_3d["tp"].get_group(), unflatten_mesh["tp"].get_group())
|
||||
self.assertEqual(mesh_3d["cp"].mesh, unflatten_mesh["cp"].mesh)
|
||||
self.assertEqual(mesh_3d["cp"].get_group(), unflatten_mesh["cp"].get_group())
|
||||
|
||||
# Test unflatten with backend override set.
|
||||
if not _NCCL_AVAILABLE:
|
||||
return
|
||||
opts = dist.ProcessGroupNCCL.Options()
|
||||
opts._timeout = timedelta(seconds=30)
|
||||
mesh_2d = global_mesh._unflatten(
|
||||
0,
|
||||
(1, 8),
|
||||
("pp", "spmd"),
|
||||
backend_override={"pp": "fake", "spmd": ("nccl", opts)},
|
||||
)
|
||||
opts = dist.ProcessGroupNCCL.Options()
|
||||
opts._timeout = timedelta(seconds=60)
|
||||
mesh_4d = mesh_2d._unflatten(
|
||||
1,
|
||||
(2, 2, 2),
|
||||
("dp", "cp", "tp"),
|
||||
backend_override={"dp": "nccl", "cp": "nccl", "tp": ("nccl", opts)},
|
||||
)
|
||||
self.assertEqual(mesh_4d["pp"].get_group()._get_backend_name(), "custom")
|
||||
spmd_pg = mesh_2d["spmd"].get_group()
|
||||
self.assertEqual(spmd_pg._get_backend_name(), "nccl")
|
||||
w = spmd_pg.allreduce(torch.rand(10).cuda(self.rank))
|
||||
self.assertTrue(
|
||||
spmd_pg._get_backend(
|
||||
torch.device(f"cuda:{self.rank}")
|
||||
)._verify_work_timeout(w, timedelta(seconds=30))
|
||||
)
|
||||
w.wait()
|
||||
tp_pg = mesh_4d["tp"].get_group()
|
||||
self.assertEqual(tp_pg._get_backend_name(), "nccl")
|
||||
w = tp_pg.allreduce(torch.rand(10).cuda(self.rank))
|
||||
self.assertTrue(
|
||||
tp_pg._get_backend(torch.device(f"cuda:{self.rank}"))._verify_work_timeout(
|
||||
w, timedelta(seconds=60)
|
||||
)
|
||||
)
|
||||
w.wait()
|
||||
|
||||
@with_comms
|
||||
def test_reconstruct_mesh_with_flatten_dim(self):
|
||||
mesh_3d = init_device_mesh(
|
||||
@ -1577,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
|
||||
def test_remap_to_tensor(self):
|
||||
"""Test the remap_to_tensor method for various scenarios."""
|
||||
# Test 1: Consecutive ranks, full world - should return logical groups directly
|
||||
original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
|
||||
original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
|
||||
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
|
||||
result1 = layout1.remap_to_tensor(original_mesh)
|
||||
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
|
||||
self.assertEqual(result1, expected1)
|
||||
|
||||
# Test 2: Non-consecutive ranks - should map to actual ranks
|
||||
original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
|
||||
original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
|
||||
layout2 = _Layout((2, 2), (2, 1))
|
||||
result2 = layout2.remap_to_tensor(original_mesh)
|
||||
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
|
||||
@ -1605,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
|
||||
self.assertEqual(result5, expected5)
|
||||
|
||||
# Test 6: Tensor Cute representation of a 2D mesh
|
||||
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
|
||||
original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
|
||||
layout6 = _Layout((2, 2), (1, 2)) # column-major style
|
||||
result6 = layout6.remap_to_tensor(original_mesh)
|
||||
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
|
||||
|
||||
@ -273,12 +273,7 @@ class TestFakePG(TestCase):
|
||||
kwargs = {}
|
||||
return func(*args, **kwargs)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"FakeProcessGroup cannot be constructed directly\. "
|
||||
r"Use torch\.distributed\.init_process_group\(backend='fake'\) instead to ensure "
|
||||
r"proper dispatch system integration\.",
|
||||
):
|
||||
with self.assertRaisesRegex(TypeError, r"No constructor defined"):
|
||||
fake_pg = FakeProcessGroup(rank=0, world_size=3)
|
||||
|
||||
with SimpleTensorMode():
|
||||
|
||||
@ -1743,6 +1743,124 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
correct = f(*inputs, **self.get_world_trs())
|
||||
assert same(out, correct), f"{out} va {correct}"
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@unittest.skipIf(not SM80OrLater, "bfloat16")
|
||||
@parametrize("bucket_mode", ["all"])
|
||||
def test_all_reduce_bucket(self, bucket_mode):
|
||||
def func(x, w, ar_0, ar_1, tag, ranks, group_size):
|
||||
y = torch.mm(x, w)
|
||||
|
||||
group_name = (
|
||||
torch.distributed.distributed_c10d._get_default_group().group_name
|
||||
)
|
||||
ar_0_out = torch.ops._c10d_functional.all_reduce.default(
|
||||
ar_0, "sum", group_name
|
||||
)
|
||||
ar_1_out = torch.ops._c10d_functional.all_reduce.default(
|
||||
ar_1, "sum", group_name
|
||||
)
|
||||
|
||||
ar_0_w = torch.ops.c10d_functional.wait_tensor(ar_0_out)
|
||||
ar_1_w = torch.ops.c10d_functional.wait_tensor(ar_1_out)
|
||||
|
||||
return y, ar_0_w, ar_1_w
|
||||
|
||||
f = func
|
||||
|
||||
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
|
||||
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
ar_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
ar_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)
|
||||
inputs = [x, w, ar_0, ar_1]
|
||||
f(*inputs, **self.get_world_trs())
|
||||
|
||||
def _pass(g):
|
||||
from torch._inductor.fx_passes.bucketing import bucket_all_reduce
|
||||
|
||||
bucket_all_reduce(g.owning_module, lambda _: 2000)
|
||||
|
||||
torch._inductor.config.post_grad_custom_post_pass = _pass
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(f)
|
||||
compiled(*inputs, **self.get_world_trs())
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
# NOTE: The first return value should be the output of the first wait_tensor.
|
||||
# We want to make sure no unnecessary copy is made.
|
||||
(
|
||||
FileCheck()
|
||||
.check_count(
|
||||
"torch.ops._c10d_functional.all_reduce_.default(",
|
||||
count=1,
|
||||
exactly=True,
|
||||
)
|
||||
.run(code)
|
||||
)
|
||||
out = compiled(*inputs, **self.get_world_trs())
|
||||
correct = f(*inputs, **self.get_world_trs())
|
||||
assert same(out, correct), f"{out} va {correct}"
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@unittest.skipIf(not SM80OrLater, "bfloat16")
|
||||
@parametrize("bucket_mode", ["all_custom_ops_multidtype"])
|
||||
def test_all_gather_bucket_multidtype(self, bucket_mode):
|
||||
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
|
||||
# do some unrelated matmuls
|
||||
y = torch.mm(x, w)
|
||||
|
||||
group_name = (
|
||||
torch.distributed.distributed_c10d._get_default_group().group_name
|
||||
)
|
||||
|
||||
ag_0_w = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_0, group_size, group_name
|
||||
)
|
||||
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_w)
|
||||
ag_0_out = ag_0_out * 2
|
||||
|
||||
ag_1_w = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_1, group_size, group_name
|
||||
)
|
||||
|
||||
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_w)
|
||||
|
||||
return y, ag_0_out, ag_1_out
|
||||
|
||||
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
|
||||
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.bfloat16)
|
||||
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
inputs = [x, w, ag_0, ag_1]
|
||||
correct = func(*inputs, **self.get_world_trs())
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"bucket_all_gathers_fx": bucket_mode,
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
(
|
||||
FileCheck()
|
||||
.check_count(
|
||||
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
|
||||
count=1,
|
||||
exactly=True,
|
||||
)
|
||||
.run(code)
|
||||
)
|
||||
out = compiled(*inputs, **self.get_world_trs())
|
||||
_, y_ag0, y_ag1 = out
|
||||
assert y_ag0.dtype == ag_0.dtype
|
||||
assert y_ag1.dtype == ag_1.dtype
|
||||
|
||||
assert same(out, correct), f"{out} va {correct}"
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@unittest.skipIf(not SM80OrLater, "bfloat16")
|
||||
@parametrize("bucket_mode", ["all", "all_custom_ops"])
|
||||
|
||||
@ -294,7 +294,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize("gather_dim", [0, 1])
|
||||
@parametrize("gather_dim", [0, 1, 2])
|
||||
def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
|
||||
self._init_process()
|
||||
|
||||
@ -306,7 +306,10 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
rank = self.rank
|
||||
|
||||
torch.manual_seed(42 + rank)
|
||||
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
|
||||
A_shard_shape = [BATCH, M, K]
|
||||
A_shard_shape[gather_dim] //= self.world_size
|
||||
|
||||
A_shard = torch.rand(A_shard_shape, device="cuda")
|
||||
Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]
|
||||
|
||||
ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
|
||||
@ -523,7 +526,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
BATCH = 8
|
||||
M = 64
|
||||
N = 16
|
||||
K = 32
|
||||
K = 1024
|
||||
group = dist.group.WORLD
|
||||
rank = self.rank
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py
|
||||
index e599b02c17d..750d7a84fb4 100644
|
||||
index e599b02c17d..057b6ec01b9 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_baseexception.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_baseexception.py
|
||||
@@ -1,10 +1,64 @@
|
||||
@ -78,7 +78,27 @@ index e599b02c17d..750d7a84fb4 100644
|
||||
self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set)
|
||||
|
||||
interface_tests = ("length", "args", "str", "repr")
|
||||
@@ -142,7 +193,7 @@ class ExceptionClassTests(unittest.TestCase):
|
||||
@@ -122,12 +173,13 @@ class ExceptionClassTests(unittest.TestCase):
|
||||
# in PyObject_SetAttr.
|
||||
import gc
|
||||
d = {}
|
||||
- class HashThisKeyWillClearTheDict(str):
|
||||
- def __hash__(self) -> int:
|
||||
- d.clear()
|
||||
- return super().__hash__()
|
||||
- class Value(str):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class HashThisKeyWillClearTheDict(str):
|
||||
+ def __hash__(self) -> int:
|
||||
+ d.clear()
|
||||
+ return super().__hash__()
|
||||
+ class Value(str):
|
||||
+ pass
|
||||
exc = Exception()
|
||||
|
||||
d[HashThisKeyWillClearTheDict()] = Value() # refcount of Value() is 1 now
|
||||
@@ -142,7 +194,7 @@ class ExceptionClassTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
|
||||
|
||||
@ -87,7 +107,31 @@ index e599b02c17d..750d7a84fb4 100644
|
||||
|
||||
"""Test usage of exceptions"""
|
||||
|
||||
@@ -208,5 +259,5 @@ class UsageTests(unittest.TestCase):
|
||||
@@ -182,8 +234,9 @@ class UsageTests(unittest.TestCase):
|
||||
# BaseException; the ability was not possible until BaseException's
|
||||
# introduction so no need to support new-style objects that do not
|
||||
# inherit from it.
|
||||
- class NewStyleClass(object):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class NewStyleClass(object):
|
||||
+ pass
|
||||
self.raise_fails(NewStyleClass)
|
||||
self.raise_fails(NewStyleClass())
|
||||
|
||||
@@ -194,8 +247,9 @@ class UsageTests(unittest.TestCase):
|
||||
def test_catch_non_BaseException(self):
|
||||
# Trying to catch an object that does not inherit from BaseException
|
||||
# is not allowed.
|
||||
- class NonBaseException(object):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class NonBaseException(object):
|
||||
+ pass
|
||||
self.catch_fails(NonBaseException)
|
||||
self.catch_fails(NonBaseException())
|
||||
|
||||
@@ -208,5 +262,5 @@ class UsageTests(unittest.TestCase):
|
||||
self.catch_fails("spam")
|
||||
|
||||
|
||||
|
||||
@ -173,12 +173,13 @@ class ExceptionClassTests(__TestCase):
|
||||
# in PyObject_SetAttr.
|
||||
import gc
|
||||
d = {}
|
||||
class HashThisKeyWillClearTheDict(str):
|
||||
def __hash__(self) -> int:
|
||||
d.clear()
|
||||
return super().__hash__()
|
||||
class Value(str):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class HashThisKeyWillClearTheDict(str):
|
||||
def __hash__(self) -> int:
|
||||
d.clear()
|
||||
return super().__hash__()
|
||||
class Value(str):
|
||||
pass
|
||||
exc = Exception()
|
||||
|
||||
d[HashThisKeyWillClearTheDict()] = Value() # refcount of Value() is 1 now
|
||||
@ -233,8 +234,9 @@ class UsageTests(__TestCase):
|
||||
# BaseException; the ability was not possible until BaseException's
|
||||
# introduction so no need to support new-style objects that do not
|
||||
# inherit from it.
|
||||
class NewStyleClass(object):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class NewStyleClass(object):
|
||||
pass
|
||||
self.raise_fails(NewStyleClass)
|
||||
self.raise_fails(NewStyleClass())
|
||||
|
||||
@ -245,8 +247,9 @@ class UsageTests(__TestCase):
|
||||
def test_catch_non_BaseException(self):
|
||||
# Trying to catch an object that does not inherit from BaseException
|
||||
# is not allowed.
|
||||
class NonBaseException(object):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class NonBaseException(object):
|
||||
pass
|
||||
self.catch_fails(NonBaseException)
|
||||
self.catch_fails(NonBaseException())
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py
|
||||
index c91f6662948..0ded70db3c7 100644
|
||||
index c91f6662948..3a62dec411c 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_exceptions.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_exceptions.py
|
||||
@@ -1,3 +1,59 @@
|
||||
@ -71,7 +71,305 @@ index c91f6662948..0ded70db3c7 100644
|
||||
|
||||
def raise_catch(self, exc, excname):
|
||||
with self.subTest(exc=exc, excname=excname):
|
||||
@@ -1844,7 +1900,7 @@ class ExceptionTests(unittest.TestCase):
|
||||
@@ -343,12 +399,13 @@ class ExceptionTests(unittest.TestCase):
|
||||
# test that setting an exception at the C level works even if the
|
||||
# exception object can't be constructed.
|
||||
|
||||
- class BadException(Exception):
|
||||
- def __init__(self_):
|
||||
- raise RuntimeError("can't instantiate BadException")
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadException(Exception):
|
||||
+ def __init__(self_):
|
||||
+ raise RuntimeError("can't instantiate BadException")
|
||||
|
||||
- class InvalidException:
|
||||
- pass
|
||||
+ class InvalidException:
|
||||
+ pass
|
||||
|
||||
@unittest.skipIf(_testcapi is None, "requires _testcapi")
|
||||
def test_capi1():
|
||||
@@ -636,8 +693,9 @@ class ExceptionTests(unittest.TestCase):
|
||||
self.assertIsInstance(e, IndexError)
|
||||
self.assertEqual(e.__traceback__, tb)
|
||||
|
||||
- class MyException(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyException(Exception):
|
||||
+ pass
|
||||
|
||||
e = MyException().with_traceback(tb)
|
||||
self.assertIsInstance(e, MyException)
|
||||
@@ -696,8 +754,9 @@ class ExceptionTests(unittest.TestCase):
|
||||
self.assertIsNone(e.__context__)
|
||||
self.assertIsNone(e.__cause__)
|
||||
|
||||
- class MyException(OSError):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyException(OSError):
|
||||
+ pass
|
||||
|
||||
e = MyException()
|
||||
self.assertIsNone(e.__context__)
|
||||
@@ -726,10 +785,11 @@ class ExceptionTests(unittest.TestCase):
|
||||
# but user-defined subclasses can if they want
|
||||
self.assertRaises(TypeError, BaseException, a=1)
|
||||
|
||||
- class DerivedException(BaseException):
|
||||
- def __init__(self, fancy_arg):
|
||||
- BaseException.__init__(self)
|
||||
- self.fancy_arg = fancy_arg
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class DerivedException(BaseException):
|
||||
+ def __init__(self, fancy_arg):
|
||||
+ BaseException.__init__(self)
|
||||
+ self.fancy_arg = fancy_arg
|
||||
|
||||
x = DerivedException(fancy_arg=42)
|
||||
self.assertEqual(x.fancy_arg, 42)
|
||||
@@ -779,11 +839,12 @@ class ExceptionTests(unittest.TestCase):
|
||||
# Make sure exception state is cleaned up as soon as the except
|
||||
# block is left. See #2507
|
||||
|
||||
- class MyException(Exception):
|
||||
- def __init__(self, obj):
|
||||
- self.obj = obj
|
||||
- class MyObj:
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyException(Exception):
|
||||
+ def __init__(self, obj):
|
||||
+ self.obj = obj
|
||||
+ class MyObj:
|
||||
+ pass
|
||||
|
||||
def inner_raising_func():
|
||||
# Create some references in exception value and traceback
|
||||
@@ -881,11 +942,12 @@ class ExceptionTests(unittest.TestCase):
|
||||
self.assertIsNone(obj)
|
||||
|
||||
# Inside an exception-silencing "with" block
|
||||
- class Context:
|
||||
- def __enter__(self):
|
||||
- return self
|
||||
- def __exit__ (self, exc_type, exc_value, exc_tb):
|
||||
- return True
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Context:
|
||||
+ def __enter__(self):
|
||||
+ return self
|
||||
+ def __exit__ (self, exc_type, exc_value, exc_tb):
|
||||
+ return True
|
||||
obj = MyObj()
|
||||
wr = weakref.ref(obj)
|
||||
with Context():
|
||||
@@ -1027,11 +1089,12 @@ class ExceptionTests(unittest.TestCase):
|
||||
def _check_generator_cleanup_exc_state(self, testfunc):
|
||||
# Issue #12791: exception state is cleaned up as soon as a generator
|
||||
# is closed (reference cycles are broken).
|
||||
- class MyException(Exception):
|
||||
- def __init__(self, obj):
|
||||
- self.obj = obj
|
||||
- class MyObj:
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyException(Exception):
|
||||
+ def __init__(self, obj):
|
||||
+ self.obj = obj
|
||||
+ class MyObj:
|
||||
+ pass
|
||||
|
||||
def raising_gen():
|
||||
try:
|
||||
@@ -1090,10 +1153,11 @@ class ExceptionTests(unittest.TestCase):
|
||||
def test_3114(self):
|
||||
# Bug #3114: in its destructor, MyObject retrieves a pointer to
|
||||
# obsolete and/or deallocated objects.
|
||||
- class MyObject:
|
||||
- def __del__(self):
|
||||
- nonlocal e
|
||||
- e = sys.exception()
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyObject:
|
||||
+ def __del__(self):
|
||||
+ nonlocal e
|
||||
+ e = sys.exception()
|
||||
e = ()
|
||||
try:
|
||||
raise Exception(MyObject())
|
||||
@@ -1103,12 +1167,13 @@ class ExceptionTests(unittest.TestCase):
|
||||
self.assertIsNone(e)
|
||||
|
||||
def test_raise_does_not_create_context_chain_cycle(self):
|
||||
- class A(Exception):
|
||||
- pass
|
||||
- class B(Exception):
|
||||
- pass
|
||||
- class C(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A(Exception):
|
||||
+ pass
|
||||
+ class B(Exception):
|
||||
+ pass
|
||||
+ class C(Exception):
|
||||
+ pass
|
||||
|
||||
# Create a context chain:
|
||||
# C -> B -> A
|
||||
@@ -1164,12 +1229,13 @@ class ExceptionTests(unittest.TestCase):
|
||||
def test_no_hang_on_context_chain_cycle2(self):
|
||||
# See issue 25782. Cycle at head of context chain.
|
||||
|
||||
- class A(Exception):
|
||||
- pass
|
||||
- class B(Exception):
|
||||
- pass
|
||||
- class C(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A(Exception):
|
||||
+ pass
|
||||
+ class B(Exception):
|
||||
+ pass
|
||||
+ class C(Exception):
|
||||
+ pass
|
||||
|
||||
# Context cycle:
|
||||
# +-----------+
|
||||
@@ -1200,16 +1266,17 @@ class ExceptionTests(unittest.TestCase):
|
||||
def test_no_hang_on_context_chain_cycle3(self):
|
||||
# See issue 25782. Longer context chain with cycle.
|
||||
|
||||
- class A(Exception):
|
||||
- pass
|
||||
- class B(Exception):
|
||||
- pass
|
||||
- class C(Exception):
|
||||
- pass
|
||||
- class D(Exception):
|
||||
- pass
|
||||
- class E(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A(Exception):
|
||||
+ pass
|
||||
+ class B(Exception):
|
||||
+ pass
|
||||
+ class C(Exception):
|
||||
+ pass
|
||||
+ class D(Exception):
|
||||
+ pass
|
||||
+ class E(Exception):
|
||||
+ pass
|
||||
|
||||
# Context cycle:
|
||||
# +-----------+
|
||||
@@ -1364,11 +1431,12 @@ class ExceptionTests(unittest.TestCase):
|
||||
def test_badisinstance(self):
|
||||
# Bug #2542: if issubclass(e, MyException) raises an exception,
|
||||
# it should be ignored
|
||||
- class Meta(type):
|
||||
- def __subclasscheck__(cls, subclass):
|
||||
- raise ValueError()
|
||||
- class MyException(Exception, metaclass=Meta):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Meta(type):
|
||||
+ def __subclasscheck__(cls, subclass):
|
||||
+ raise ValueError()
|
||||
+ class MyException(Exception, metaclass=Meta):
|
||||
+ pass
|
||||
|
||||
with captured_stderr() as stderr:
|
||||
try:
|
||||
@@ -1602,8 +1670,9 @@ class ExceptionTests(unittest.TestCase):
|
||||
self.assertTrue(issubclass(error3, error2))
|
||||
|
||||
# test with explicit base tuple
|
||||
- class C(object):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ pass
|
||||
error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4,
|
||||
(error3, C))
|
||||
self.assertTrue(issubclass(error4, error3))
|
||||
@@ -1623,8 +1692,9 @@ class ExceptionTests(unittest.TestCase):
|
||||
# Issue #5437: preallocated MemoryError instances should not keep
|
||||
# traceback objects alive.
|
||||
from _testcapi import raise_memoryerror
|
||||
- class C:
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ pass
|
||||
wr = None
|
||||
def inner():
|
||||
nonlocal wr
|
||||
@@ -1644,8 +1714,9 @@ class ExceptionTests(unittest.TestCase):
|
||||
@no_tracing
|
||||
def test_recursion_error_cleanup(self):
|
||||
# Same test as above, but with "recursion exceeded" errors
|
||||
- class C:
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ pass
|
||||
wr = None
|
||||
def inner():
|
||||
nonlocal wr
|
||||
@@ -1670,11 +1741,12 @@ class ExceptionTests(unittest.TestCase):
|
||||
|
||||
def test_unraisable(self):
|
||||
# Issue #22836: PyErr_WriteUnraisable() should give sensible reports
|
||||
- class BrokenDel:
|
||||
- def __del__(self):
|
||||
- exc = ValueError("del is broken")
|
||||
- # The following line is included in the traceback report:
|
||||
- raise exc
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BrokenDel:
|
||||
+ def __del__(self):
|
||||
+ exc = ValueError("del is broken")
|
||||
+ # The following line is included in the traceback report:
|
||||
+ raise exc
|
||||
|
||||
obj = BrokenDel()
|
||||
with support.catch_unraisable_exception() as cm:
|
||||
@@ -1728,11 +1800,12 @@ class ExceptionTests(unittest.TestCase):
|
||||
|
||||
def test_yield_in_nested_try_excepts(self):
|
||||
#Issue #25612
|
||||
- class MainError(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MainError(Exception):
|
||||
+ pass
|
||||
|
||||
- class SubError(Exception):
|
||||
- pass
|
||||
+ class SubError(Exception):
|
||||
+ pass
|
||||
|
||||
def main():
|
||||
try:
|
||||
@@ -1807,8 +1880,9 @@ class ExceptionTests(unittest.TestCase):
|
||||
# subclass object. Finally, it checks that creating a new MemoryError
|
||||
# succeeds, proving that the freelist is not corrupted.
|
||||
|
||||
- class TestException(MemoryError):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class TestException(MemoryError):
|
||||
+ pass
|
||||
|
||||
try:
|
||||
raise MemoryError
|
||||
@@ -1844,7 +1918,7 @@ class ExceptionTests(unittest.TestCase):
|
||||
self.assertIn(b'MemoryError', err)
|
||||
|
||||
|
||||
@ -80,7 +378,18 @@ index c91f6662948..0ded70db3c7 100644
|
||||
def test_name_error_has_name(self):
|
||||
try:
|
||||
bluch
|
||||
@@ -1894,7 +1950,7 @@ class NameErrorTests(unittest.TestCase):
|
||||
@@ -1886,15 +1960,16 @@ class NameErrorTests(unittest.TestCase):
|
||||
|
||||
def test_gh_111654(self):
|
||||
def f():
|
||||
- class TestClass:
|
||||
- TestClass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class TestClass:
|
||||
+ TestClass
|
||||
|
||||
self.assertRaises(NameError, f)
|
||||
|
||||
# Note: name suggestion tests live in `test_traceback`.
|
||||
|
||||
|
||||
@ -89,7 +398,33 @@ index c91f6662948..0ded70db3c7 100644
|
||||
def test_attributes(self):
|
||||
# Setting 'attr' should not be a problem.
|
||||
exc = AttributeError('Ouch!')
|
||||
@@ -1937,7 +1993,7 @@ class AttributeErrorTests(unittest.TestCase):
|
||||
@@ -1907,8 +1982,9 @@ class AttributeErrorTests(unittest.TestCase):
|
||||
self.assertIs(exc.obj, sentinel)
|
||||
|
||||
def test_getattr_has_name_and_obj(self):
|
||||
- class A:
|
||||
- blech = None
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ blech = None
|
||||
|
||||
obj = A()
|
||||
try:
|
||||
@@ -1923,9 +1999,10 @@ class AttributeErrorTests(unittest.TestCase):
|
||||
self.assertEqual(obj, exc.obj)
|
||||
|
||||
def test_getattr_has_name_and_obj_for_method(self):
|
||||
- class A:
|
||||
- def blech(self):
|
||||
- return
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ def blech(self):
|
||||
+ return
|
||||
|
||||
obj = A()
|
||||
try:
|
||||
@@ -1937,7 +2014,7 @@ class AttributeErrorTests(unittest.TestCase):
|
||||
# Note: name suggestion tests live in `test_traceback`.
|
||||
|
||||
|
||||
@ -98,7 +433,7 @@ index c91f6662948..0ded70db3c7 100644
|
||||
|
||||
def test_attributes(self):
|
||||
# Setting 'name' and 'path' should not be a problem.
|
||||
@@ -2024,7 +2080,7 @@ def run_script(source):
|
||||
@@ -2024,7 +2101,7 @@ def run_script(source):
|
||||
_rc, _out, err = script_helper.assert_python_failure('-Wd', '-X', 'utf8', TESTFN)
|
||||
return err.decode('utf-8').splitlines()
|
||||
|
||||
@ -107,7 +442,7 @@ index c91f6662948..0ded70db3c7 100644
|
||||
def tearDown(self):
|
||||
unlink(TESTFN)
|
||||
|
||||
@@ -2159,7 +2215,7 @@ class AssertionErrorTests(unittest.TestCase):
|
||||
@@ -2159,7 +2236,7 @@ class AssertionErrorTests(unittest.TestCase):
|
||||
|
||||
|
||||
@support.force_not_colorized_test_class
|
||||
@ -116,7 +451,19 @@ index c91f6662948..0ded70db3c7 100644
|
||||
maxDiff = None
|
||||
|
||||
@force_not_colorized
|
||||
@@ -2290,6 +2346,7 @@ class SyntaxErrorTests(unittest.TestCase):
|
||||
@@ -2254,8 +2331,9 @@ class SyntaxErrorTests(unittest.TestCase):
|
||||
the_exception = exc
|
||||
|
||||
def test_subclass(self):
|
||||
- class MySyntaxError(SyntaxError):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MySyntaxError(SyntaxError):
|
||||
+ pass
|
||||
|
||||
try:
|
||||
raise MySyntaxError("bad bad", ("bad.py", 1, 2, "abcdefg", 1, 7))
|
||||
@@ -2290,6 +2368,7 @@ class SyntaxErrorTests(unittest.TestCase):
|
||||
err = run_script(b"\x89")
|
||||
self.assertIn("SyntaxError: Non-UTF-8 code starting with '\\x89' in file", err[-1])
|
||||
|
||||
@ -124,7 +471,7 @@ index c91f6662948..0ded70db3c7 100644
|
||||
def test_string_source(self):
|
||||
def try_compile(source):
|
||||
with self.assertRaises(SyntaxError) as cm:
|
||||
@@ -2405,7 +2462,7 @@ class SyntaxErrorTests(unittest.TestCase):
|
||||
@@ -2405,7 +2484,7 @@ class SyntaxErrorTests(unittest.TestCase):
|
||||
self.assertRaises(TypeError, SyntaxError, "bad bad", args)
|
||||
|
||||
|
||||
@ -133,7 +480,7 @@ index c91f6662948..0ded70db3c7 100644
|
||||
def test_except_star_invalid_exception_type(self):
|
||||
with self.assertRaises(TypeError):
|
||||
try:
|
||||
@@ -2420,7 +2477,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase):
|
||||
@@ -2420,7 +2499,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
@ -142,7 +489,42 @@ index c91f6662948..0ded70db3c7 100644
|
||||
|
||||
def lineno_after_raise(self, f, *expected):
|
||||
try:
|
||||
@@ -2529,5 +2586,5 @@ class PEP626Tests(unittest.TestCase):
|
||||
@@ -2499,11 +2578,12 @@ class PEP626Tests(unittest.TestCase):
|
||||
self.lineno_after_raise(in_finally_except, 4)
|
||||
|
||||
def test_lineno_after_with(self):
|
||||
- class Noop:
|
||||
- def __enter__(self):
|
||||
- return self
|
||||
- def __exit__(self, *args):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Noop:
|
||||
+ def __enter__(self):
|
||||
+ return self
|
||||
+ def __exit__(self, *args):
|
||||
+ pass
|
||||
def after_with():
|
||||
with Noop():
|
||||
1/0
|
||||
@@ -2518,16 +2598,17 @@ class PEP626Tests(unittest.TestCase):
|
||||
self.lineno_after_raise(f, None)
|
||||
|
||||
def test_lineno_after_raise_in_with_exit(self):
|
||||
- class ExitFails:
|
||||
- def __enter__(self):
|
||||
- return self
|
||||
- def __exit__(self, *args):
|
||||
- raise ValueError
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ExitFails:
|
||||
+ def __enter__(self):
|
||||
+ return self
|
||||
+ def __exit__(self, *args):
|
||||
+ raise ValueError
|
||||
|
||||
def after_with():
|
||||
with ExitFails():
|
||||
1/0
|
||||
self.lineno_after_raise(after_with, 1, 1)
|
||||
|
||||
|
||||
@ -399,12 +399,13 @@ class ExceptionTests(__TestCase):
|
||||
# test that setting an exception at the C level works even if the
|
||||
# exception object can't be constructed.
|
||||
|
||||
class BadException(Exception):
|
||||
def __init__(self_):
|
||||
raise RuntimeError("can't instantiate BadException")
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadException(Exception):
|
||||
def __init__(self_):
|
||||
raise RuntimeError("can't instantiate BadException")
|
||||
|
||||
class InvalidException:
|
||||
pass
|
||||
class InvalidException:
|
||||
pass
|
||||
|
||||
@unittest.skipIf(_testcapi is None, "requires _testcapi")
|
||||
def test_capi1():
|
||||
@ -692,8 +693,9 @@ class ExceptionTests(__TestCase):
|
||||
self.assertIsInstance(e, IndexError)
|
||||
self.assertEqual(e.__traceback__, tb)
|
||||
|
||||
class MyException(Exception):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyException(Exception):
|
||||
pass
|
||||
|
||||
e = MyException().with_traceback(tb)
|
||||
self.assertIsInstance(e, MyException)
|
||||
@ -752,8 +754,9 @@ class ExceptionTests(__TestCase):
|
||||
self.assertIsNone(e.__context__)
|
||||
self.assertIsNone(e.__cause__)
|
||||
|
||||
class MyException(OSError):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyException(OSError):
|
||||
pass
|
||||
|
||||
e = MyException()
|
||||
self.assertIsNone(e.__context__)
|
||||
@ -782,10 +785,11 @@ class ExceptionTests(__TestCase):
|
||||
# but user-defined subclasses can if they want
|
||||
self.assertRaises(TypeError, BaseException, a=1)
|
||||
|
||||
class DerivedException(BaseException):
|
||||
def __init__(self, fancy_arg):
|
||||
BaseException.__init__(self)
|
||||
self.fancy_arg = fancy_arg
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class DerivedException(BaseException):
|
||||
def __init__(self, fancy_arg):
|
||||
BaseException.__init__(self)
|
||||
self.fancy_arg = fancy_arg
|
||||
|
||||
x = DerivedException(fancy_arg=42)
|
||||
self.assertEqual(x.fancy_arg, 42)
|
||||
@ -835,11 +839,12 @@ class ExceptionTests(__TestCase):
|
||||
# Make sure exception state is cleaned up as soon as the except
|
||||
# block is left. See #2507
|
||||
|
||||
class MyException(Exception):
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
class MyObj:
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyException(Exception):
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
class MyObj:
|
||||
pass
|
||||
|
||||
def inner_raising_func():
|
||||
# Create some references in exception value and traceback
|
||||
@ -937,11 +942,12 @@ class ExceptionTests(__TestCase):
|
||||
self.assertIsNone(obj)
|
||||
|
||||
# Inside an exception-silencing "with" block
|
||||
class Context:
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__ (self, exc_type, exc_value, exc_tb):
|
||||
return True
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Context:
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__ (self, exc_type, exc_value, exc_tb):
|
||||
return True
|
||||
obj = MyObj()
|
||||
wr = weakref.ref(obj)
|
||||
with Context():
|
||||
@ -1083,11 +1089,12 @@ class ExceptionTests(__TestCase):
|
||||
def _check_generator_cleanup_exc_state(self, testfunc):
|
||||
# Issue #12791: exception state is cleaned up as soon as a generator
|
||||
# is closed (reference cycles are broken).
|
||||
class MyException(Exception):
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
class MyObj:
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyException(Exception):
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
class MyObj:
|
||||
pass
|
||||
|
||||
def raising_gen():
|
||||
try:
|
||||
@ -1146,10 +1153,11 @@ class ExceptionTests(__TestCase):
|
||||
def test_3114(self):
|
||||
# Bug #3114: in its destructor, MyObject retrieves a pointer to
|
||||
# obsolete and/or deallocated objects.
|
||||
class MyObject:
|
||||
def __del__(self):
|
||||
nonlocal e
|
||||
e = sys.exception()
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyObject:
|
||||
def __del__(self):
|
||||
nonlocal e
|
||||
e = sys.exception()
|
||||
e = ()
|
||||
try:
|
||||
raise Exception(MyObject())
|
||||
@ -1159,12 +1167,13 @@ class ExceptionTests(__TestCase):
|
||||
self.assertIsNone(e)
|
||||
|
||||
def test_raise_does_not_create_context_chain_cycle(self):
|
||||
class A(Exception):
|
||||
pass
|
||||
class B(Exception):
|
||||
pass
|
||||
class C(Exception):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A(Exception):
|
||||
pass
|
||||
class B(Exception):
|
||||
pass
|
||||
class C(Exception):
|
||||
pass
|
||||
|
||||
# Create a context chain:
|
||||
# C -> B -> A
|
||||
@ -1220,12 +1229,13 @@ class ExceptionTests(__TestCase):
|
||||
def test_no_hang_on_context_chain_cycle2(self):
|
||||
# See issue 25782. Cycle at head of context chain.
|
||||
|
||||
class A(Exception):
|
||||
pass
|
||||
class B(Exception):
|
||||
pass
|
||||
class C(Exception):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A(Exception):
|
||||
pass
|
||||
class B(Exception):
|
||||
pass
|
||||
class C(Exception):
|
||||
pass
|
||||
|
||||
# Context cycle:
|
||||
# +-----------+
|
||||
@ -1256,16 +1266,17 @@ class ExceptionTests(__TestCase):
|
||||
def test_no_hang_on_context_chain_cycle3(self):
|
||||
# See issue 25782. Longer context chain with cycle.
|
||||
|
||||
class A(Exception):
|
||||
pass
|
||||
class B(Exception):
|
||||
pass
|
||||
class C(Exception):
|
||||
pass
|
||||
class D(Exception):
|
||||
pass
|
||||
class E(Exception):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A(Exception):
|
||||
pass
|
||||
class B(Exception):
|
||||
pass
|
||||
class C(Exception):
|
||||
pass
|
||||
class D(Exception):
|
||||
pass
|
||||
class E(Exception):
|
||||
pass
|
||||
|
||||
# Context cycle:
|
||||
# +-----------+
|
||||
@ -1420,11 +1431,12 @@ class ExceptionTests(__TestCase):
|
||||
def test_badisinstance(self):
|
||||
# Bug #2542: if issubclass(e, MyException) raises an exception,
|
||||
# it should be ignored
|
||||
class Meta(type):
|
||||
def __subclasscheck__(cls, subclass):
|
||||
raise ValueError()
|
||||
class MyException(Exception, metaclass=Meta):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Meta(type):
|
||||
def __subclasscheck__(cls, subclass):
|
||||
raise ValueError()
|
||||
class MyException(Exception, metaclass=Meta):
|
||||
pass
|
||||
|
||||
with captured_stderr() as stderr:
|
||||
try:
|
||||
@ -1658,8 +1670,9 @@ class ExceptionTests(__TestCase):
|
||||
self.assertTrue(issubclass(error3, error2))
|
||||
|
||||
# test with explicit base tuple
|
||||
class C(object):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
pass
|
||||
error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4,
|
||||
(error3, C))
|
||||
self.assertTrue(issubclass(error4, error3))
|
||||
@ -1679,8 +1692,9 @@ class ExceptionTests(__TestCase):
|
||||
# Issue #5437: preallocated MemoryError instances should not keep
|
||||
# traceback objects alive.
|
||||
from _testcapi import raise_memoryerror
|
||||
class C:
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
pass
|
||||
wr = None
|
||||
def inner():
|
||||
nonlocal wr
|
||||
@ -1700,8 +1714,9 @@ class ExceptionTests(__TestCase):
|
||||
@no_tracing
|
||||
def test_recursion_error_cleanup(self):
|
||||
# Same test as above, but with "recursion exceeded" errors
|
||||
class C:
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
pass
|
||||
wr = None
|
||||
def inner():
|
||||
nonlocal wr
|
||||
@ -1726,11 +1741,12 @@ class ExceptionTests(__TestCase):
|
||||
|
||||
def test_unraisable(self):
|
||||
# Issue #22836: PyErr_WriteUnraisable() should give sensible reports
|
||||
class BrokenDel:
|
||||
def __del__(self):
|
||||
exc = ValueError("del is broken")
|
||||
# The following line is included in the traceback report:
|
||||
raise exc
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BrokenDel:
|
||||
def __del__(self):
|
||||
exc = ValueError("del is broken")
|
||||
# The following line is included in the traceback report:
|
||||
raise exc
|
||||
|
||||
obj = BrokenDel()
|
||||
with support.catch_unraisable_exception() as cm:
|
||||
@ -1784,11 +1800,12 @@ class ExceptionTests(__TestCase):
|
||||
|
||||
def test_yield_in_nested_try_excepts(self):
|
||||
#Issue #25612
|
||||
class MainError(Exception):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MainError(Exception):
|
||||
pass
|
||||
|
||||
class SubError(Exception):
|
||||
pass
|
||||
class SubError(Exception):
|
||||
pass
|
||||
|
||||
def main():
|
||||
try:
|
||||
@ -1863,8 +1880,9 @@ class ExceptionTests(__TestCase):
|
||||
# subclass object. Finally, it checks that creating a new MemoryError
|
||||
# succeeds, proving that the freelist is not corrupted.
|
||||
|
||||
class TestException(MemoryError):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class TestException(MemoryError):
|
||||
pass
|
||||
|
||||
try:
|
||||
raise MemoryError
|
||||
@ -1942,8 +1960,9 @@ class NameErrorTests(__TestCase):
|
||||
|
||||
def test_gh_111654(self):
|
||||
def f():
|
||||
class TestClass:
|
||||
TestClass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class TestClass:
|
||||
TestClass
|
||||
|
||||
self.assertRaises(NameError, f)
|
||||
|
||||
@ -1963,8 +1982,9 @@ class AttributeErrorTests(__TestCase):
|
||||
self.assertIs(exc.obj, sentinel)
|
||||
|
||||
def test_getattr_has_name_and_obj(self):
|
||||
class A:
|
||||
blech = None
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
blech = None
|
||||
|
||||
obj = A()
|
||||
try:
|
||||
@ -1979,9 +1999,10 @@ class AttributeErrorTests(__TestCase):
|
||||
self.assertEqual(obj, exc.obj)
|
||||
|
||||
def test_getattr_has_name_and_obj_for_method(self):
|
||||
class A:
|
||||
def blech(self):
|
||||
return
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
def blech(self):
|
||||
return
|
||||
|
||||
obj = A()
|
||||
try:
|
||||
@ -2310,8 +2331,9 @@ class SyntaxErrorTests(__TestCase):
|
||||
the_exception = exc
|
||||
|
||||
def test_subclass(self):
|
||||
class MySyntaxError(SyntaxError):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MySyntaxError(SyntaxError):
|
||||
pass
|
||||
|
||||
try:
|
||||
raise MySyntaxError("bad bad", ("bad.py", 1, 2, "abcdefg", 1, 7))
|
||||
@ -2556,11 +2578,12 @@ class PEP626Tests(__TestCase):
|
||||
self.lineno_after_raise(in_finally_except, 4)
|
||||
|
||||
def test_lineno_after_with(self):
|
||||
class Noop:
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Noop:
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
def after_with():
|
||||
with Noop():
|
||||
1/0
|
||||
@ -2575,11 +2598,12 @@ class PEP626Tests(__TestCase):
|
||||
self.lineno_after_raise(f, None)
|
||||
|
||||
def test_lineno_after_raise_in_with_exit(self):
|
||||
class ExitFails:
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
raise ValueError
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ExitFails:
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
raise ValueError
|
||||
|
||||
def after_with():
|
||||
with ExitFails():
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_raise.py b/test/dynamo/cpython/3_13/test_raise.py
|
||||
index 6d26a61bee4..042d1ae3d7c 100644
|
||||
index 6d26a61bee4..ce748433d28 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_raise.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_raise.py
|
||||
@@ -1,3 +1,58 @@
|
||||
@ -70,7 +70,35 @@ index 6d26a61bee4..042d1ae3d7c 100644
|
||||
def test_invalid_reraise(self):
|
||||
try:
|
||||
raise
|
||||
@@ -148,7 +203,7 @@ class TestRaise(unittest.TestCase):
|
||||
@@ -120,9 +175,10 @@ class TestRaise(unittest.TestCase):
|
||||
self.assertRaises(StopIteration, lambda: next(g))
|
||||
|
||||
def test_erroneous_exception(self):
|
||||
- class MyException(Exception):
|
||||
- def __init__(self):
|
||||
- raise RuntimeError()
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyException(Exception):
|
||||
+ def __init__(self):
|
||||
+ raise RuntimeError()
|
||||
|
||||
try:
|
||||
raise MyException
|
||||
@@ -133,9 +189,10 @@ class TestRaise(unittest.TestCase):
|
||||
|
||||
def test_new_returns_invalid_instance(self):
|
||||
# See issue #11627.
|
||||
- class MyException(Exception):
|
||||
- def __new__(cls, *args):
|
||||
- return object()
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyException(Exception):
|
||||
+ def __new__(cls, *args):
|
||||
+ return object()
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
raise MyException
|
||||
@@ -148,7 +205,7 @@ class TestRaise(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
@ -79,7 +107,37 @@ index 6d26a61bee4..042d1ae3d7c 100644
|
||||
|
||||
def testCauseSyntax(self):
|
||||
try:
|
||||
@@ -221,7 +276,7 @@ class TestCause(unittest.TestCase):
|
||||
@@ -186,10 +243,11 @@ class TestCause(unittest.TestCase):
|
||||
self.fail("No exception raised")
|
||||
|
||||
def test_class_cause_nonexception_result(self):
|
||||
- class ConstructsNone(BaseException):
|
||||
- @classmethod
|
||||
- def __new__(*args, **kwargs):
|
||||
- return None
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ConstructsNone(BaseException):
|
||||
+ @classmethod
|
||||
+ def __new__(*args, **kwargs):
|
||||
+ return None
|
||||
try:
|
||||
raise IndexError from ConstructsNone
|
||||
except TypeError as e:
|
||||
@@ -209,9 +267,10 @@ class TestCause(unittest.TestCase):
|
||||
self.fail("No exception raised")
|
||||
|
||||
def test_erroneous_cause(self):
|
||||
- class MyException(Exception):
|
||||
- def __init__(self):
|
||||
- raise RuntimeError()
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyException(Exception):
|
||||
+ def __init__(self):
|
||||
+ raise RuntimeError()
|
||||
|
||||
try:
|
||||
raise IndexError from MyException
|
||||
@@ -221,7 +280,7 @@ class TestCause(unittest.TestCase):
|
||||
self.fail("No exception raised")
|
||||
|
||||
|
||||
@ -88,7 +146,7 @@ index 6d26a61bee4..042d1ae3d7c 100644
|
||||
|
||||
def test_sets_traceback(self):
|
||||
try:
|
||||
@@ -242,7 +297,7 @@ class TestTraceback(unittest.TestCase):
|
||||
@@ -242,7 +301,7 @@ class TestTraceback(unittest.TestCase):
|
||||
self.fail("No exception raised")
|
||||
|
||||
|
||||
@ -97,7 +155,7 @@ index 6d26a61bee4..042d1ae3d7c 100644
|
||||
|
||||
def raiser(self):
|
||||
raise ValueError
|
||||
@@ -308,7 +363,7 @@ class TestTracebackType(unittest.TestCase):
|
||||
@@ -308,7 +367,7 @@ class TestTracebackType(unittest.TestCase):
|
||||
types.TracebackType(other_tb, frame, 1, "nuh-uh")
|
||||
|
||||
|
||||
@ -106,7 +164,45 @@ index 6d26a61bee4..042d1ae3d7c 100644
|
||||
def test_instance_context_instance_raise(self):
|
||||
context = IndexError()
|
||||
try:
|
||||
@@ -498,7 +553,7 @@ class TestContext(unittest.TestCase):
|
||||
@@ -392,11 +451,12 @@ class TestContext(unittest.TestCase):
|
||||
self.fail("No exception raised")
|
||||
|
||||
def test_context_manager(self):
|
||||
- class ContextManager:
|
||||
- def __enter__(self):
|
||||
- pass
|
||||
- def __exit__(self, t, v, tb):
|
||||
- xyzzy
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ContextManager:
|
||||
+ def __enter__(self):
|
||||
+ pass
|
||||
+ def __exit__(self, t, v, tb):
|
||||
+ xyzzy
|
||||
try:
|
||||
with ContextManager():
|
||||
1/0
|
||||
@@ -471,12 +531,13 @@ class TestContext(unittest.TestCase):
|
||||
import gc
|
||||
# A re-raised exception in a __del__ caused the __context__
|
||||
# to be cleared
|
||||
- class C:
|
||||
- def __del__(self):
|
||||
- try:
|
||||
- 1/0
|
||||
- except:
|
||||
- raise
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ def __del__(self):
|
||||
+ try:
|
||||
+ 1/0
|
||||
+ except:
|
||||
+ raise
|
||||
|
||||
def f():
|
||||
x = C()
|
||||
@@ -498,7 +559,7 @@ class TestContext(unittest.TestCase):
|
||||
self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type)
|
||||
|
||||
|
||||
@ -115,7 +211,7 @@ index 6d26a61bee4..042d1ae3d7c 100644
|
||||
def test_tuples(self):
|
||||
try:
|
||||
raise (IndexError, KeyError) # This should be a tuple!
|
||||
@@ -517,4 +572,4 @@ class TestRemovedFunctionality(unittest.TestCase):
|
||||
@@ -517,4 +578,4 @@ class TestRemovedFunctionality(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -175,9 +175,10 @@ class TestRaise(__TestCase):
|
||||
self.assertRaises(StopIteration, lambda: next(g))
|
||||
|
||||
def test_erroneous_exception(self):
|
||||
class MyException(Exception):
|
||||
def __init__(self):
|
||||
raise RuntimeError()
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyException(Exception):
|
||||
def __init__(self):
|
||||
raise RuntimeError()
|
||||
|
||||
try:
|
||||
raise MyException
|
||||
@ -188,9 +189,10 @@ class TestRaise(__TestCase):
|
||||
|
||||
def test_new_returns_invalid_instance(self):
|
||||
# See issue #11627.
|
||||
class MyException(Exception):
|
||||
def __new__(cls, *args):
|
||||
return object()
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyException(Exception):
|
||||
def __new__(cls, *args):
|
||||
return object()
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
raise MyException
|
||||
@ -241,10 +243,11 @@ class TestCause(__TestCase):
|
||||
self.fail("No exception raised")
|
||||
|
||||
def test_class_cause_nonexception_result(self):
|
||||
class ConstructsNone(BaseException):
|
||||
@classmethod
|
||||
def __new__(*args, **kwargs):
|
||||
return None
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ConstructsNone(BaseException):
|
||||
@classmethod
|
||||
def __new__(*args, **kwargs):
|
||||
return None
|
||||
try:
|
||||
raise IndexError from ConstructsNone
|
||||
except TypeError as e:
|
||||
@ -264,9 +267,10 @@ class TestCause(__TestCase):
|
||||
self.fail("No exception raised")
|
||||
|
||||
def test_erroneous_cause(self):
|
||||
class MyException(Exception):
|
||||
def __init__(self):
|
||||
raise RuntimeError()
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyException(Exception):
|
||||
def __init__(self):
|
||||
raise RuntimeError()
|
||||
|
||||
try:
|
||||
raise IndexError from MyException
|
||||
@ -447,11 +451,12 @@ class TestContext(__TestCase):
|
||||
self.fail("No exception raised")
|
||||
|
||||
def test_context_manager(self):
|
||||
class ContextManager:
|
||||
def __enter__(self):
|
||||
pass
|
||||
def __exit__(self, t, v, tb):
|
||||
xyzzy
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ContextManager:
|
||||
def __enter__(self):
|
||||
pass
|
||||
def __exit__(self, t, v, tb):
|
||||
xyzzy
|
||||
try:
|
||||
with ContextManager():
|
||||
1/0
|
||||
@ -526,12 +531,13 @@ class TestContext(__TestCase):
|
||||
import gc
|
||||
# A re-raised exception in a __del__ caused the __context__
|
||||
# to be cleared
|
||||
class C:
|
||||
def __del__(self):
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
raise
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
def __del__(self):
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
raise
|
||||
|
||||
def f():
|
||||
x = C()
|
||||
|
||||
@ -838,6 +838,55 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
|
||||
self, device
|
||||
):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
]
|
||||
return create_selective_checkpoint_contexts(
|
||||
_get_custom_policy(no_recompute_list=no_recompute_list)
|
||||
)
|
||||
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
|
||||
|
||||
def fn(x, y):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
gn,
|
||||
x,
|
||||
y,
|
||||
use_reentrant=False,
|
||||
context_fn=selective_checkpointing_context_fn,
|
||||
)
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True, device=device)
|
||||
y = torch.randn(4, 4, requires_grad=True, device=device)
|
||||
|
||||
fw_compiler = functools.partial(
|
||||
count_ops,
|
||||
freq=1,
|
||||
op=torch.ops.aten.sigmoid.default,
|
||||
)
|
||||
bw_compiler = functools.partial(
|
||||
count_ops,
|
||||
# Main check here is just that sigmoid is properly recomputed
|
||||
# (we will see a sigmoid() and sigmoid_backward() in the bw graph)
|
||||
freq=1,
|
||||
op=torch.ops.aten.sigmoid.default,
|
||||
)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
disable_functionalization=True,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_triton_kernel(self, device):
|
||||
|
||||
@ -921,7 +921,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
1|aten._native_batch_norm_legit_functional.default|batch_norm|
|
||||
2|aten.relu.default|relu|
|
||||
2|aten.detach.default|relu|
|
||||
2|aten.detach.default|relu|
|
||||
3|aten.add.Tensor|add|
|
||||
4|aten.view.default|flatten|
|
||||
5|aten.view.default|linear|
|
||||
@ -948,7 +947,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
5|aten.view.default||linear
|
||||
4|aten.view.default||flatten
|
||||
2|aten.detach.default||relu
|
||||
2|aten.detach.default||relu
|
||||
2|aten.threshold_backward.default||relu
|
||||
1|aten.native_batch_norm_backward.default||batch_norm
|
||||
0|aten.convolution_backward.default||conv2d
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
@ -203,6 +204,39 @@ class TestAOTCompile(torch._inductor.test_case.TestCase):
|
||||
actual = compiled_fn(*example_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_decorated_function_with_functools_wrap_aot(self):
|
||||
def check_inputs(fn):
|
||||
@functools.wraps(fn)
|
||||
def _fn(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert arg.shape[0] > 1
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _fn
|
||||
|
||||
@check_inputs
|
||||
def foo(x, y):
|
||||
a = x + x
|
||||
b = y + y
|
||||
c = a + b
|
||||
return c
|
||||
|
||||
example_inputs = (torch.ones(3), torch.ones(3))
|
||||
expected = foo(*example_inputs)
|
||||
|
||||
def backend(gm, example_inputs):
|
||||
return CustomCompiledFunction(gm, example_inputs)
|
||||
|
||||
with torch.compiler.set_stance("fail_on_recompile"):
|
||||
compiled_fn = torch.compile(
|
||||
foo,
|
||||
fullgraph=True,
|
||||
backend=backend,
|
||||
).aot_compile((example_inputs, {}))
|
||||
actual = compiled_fn(*example_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_aot_compile_disable_guard_check(self):
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
# ruff: noqa: TRY002
|
||||
|
||||
import enum
|
||||
import itertools
|
||||
import operator
|
||||
import types
|
||||
@ -56,6 +57,30 @@ class DictTests(torch._dynamo.test_case.TestCase):
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_dict_contains_enum(self):
|
||||
class TensorDim(str, enum.Enum):
|
||||
DDP = "ddp"
|
||||
FSDP = "fsdp"
|
||||
CP = "cp"
|
||||
TP = "tp"
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
val = x.sin()
|
||||
if TensorDim.DDP in {"ddp"}:
|
||||
val += x.cos()
|
||||
if "ddp" in {TensorDim.DDP}:
|
||||
val += x.cos()
|
||||
return val
|
||||
|
||||
inp = torch.randn(4, 4)
|
||||
mod = Foo()
|
||||
opt_f = torch.compile(mod)
|
||||
self.assertEqual(mod(inp), opt_f(inp))
|
||||
|
||||
def test_dict_subclass_local_with_non_dict_method(self):
|
||||
# Checks that add_1 method is inlined
|
||||
class MethodDict(dict):
|
||||
|
||||
@ -113,7 +113,7 @@ sort with non-constant keys
|
||||
Explanation: Cannot perform sort with non-constant key. First non-constant key type: <class 'torch.Tensor'>. Most notably, we cannot sort with Tensor or SymInt keys, but we can sort ints.
|
||||
Hint: Use something else as the key.
|
||||
|
||||
Developer debug context: TensorVariable()
|
||||
Developer debug context: LazyVariableTracker(realized: TensorVariable())
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0207.html
|
||||
|
||||
@ -216,7 +216,7 @@ Unsupported context manager
|
||||
Hint: If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then it may be the case that it was created outside the compiled region, which Dynamo does not support. Supported context managers can cross graph break boundaries only if they are local non-closure variables, or are intermediate values.
|
||||
Hint: File an issue to PyTorch. Simple context managers can potentially be supported, but note that context managers can't be supported in general
|
||||
|
||||
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on ConstantVariable(int: 3)
|
||||
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on LazyVariableTracker(realized: ConstantVariable(int: 3))
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0142.html
|
||||
|
||||
@ -543,7 +543,7 @@ Dynamic slicing with Tensor arguments
|
||||
Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
|
||||
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
|
||||
|
||||
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: TensorVariable(), step: ConstantVariable(NoneType: None)
|
||||
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
|
||||
|
||||
@ -869,6 +869,51 @@ from user code:
|
||||
if x.sum() > 0:""",
|
||||
)
|
||||
|
||||
# Test that the bytecode source attribution is correct with VariableTracker
|
||||
@make_logging_test(trace_bytecode=True)
|
||||
def test_variable_tracker_source_attribution(self, records):
|
||||
def inner(x):
|
||||
return x + 1
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
x = inner(x)
|
||||
return inner(x)
|
||||
|
||||
fn(torch.ones(3))
|
||||
|
||||
def find_trace_bytecode_lines(long_string):
|
||||
# Split the string into lines
|
||||
lines = long_string.split("\n")
|
||||
# More comprehensive pattern to capture LazyVariableTracker info
|
||||
pattern = r"LazyVariableTracker\([^)]*\)"
|
||||
# Find all lines containing the pattern
|
||||
result = [line for line in lines if re.search(pattern, line)]
|
||||
return result
|
||||
|
||||
# Get all log messages, not just the last one
|
||||
all_messages = []
|
||||
for record in records:
|
||||
msg = munge_exc(record.getMessage(), skip=0)
|
||||
|
||||
all_messages.append(msg)
|
||||
|
||||
# Combine all messages to search through
|
||||
combined_msg = "\n".join(all_messages)
|
||||
all_lines = find_trace_bytecode_lines(combined_msg)
|
||||
|
||||
# For now, just check that we found some lines with LazyVariableTracker
|
||||
self.assertGreater(
|
||||
len(all_lines), 0, "Should find at least one LazyVariableTracker line"
|
||||
)
|
||||
|
||||
self.assertIn(
|
||||
"LazyVariableTracker(unrealized: <class 'function'>)", all_lines[0]
|
||||
)
|
||||
self.assertIn(
|
||||
"LazyVariableTracker(realized: UserFunctionVariable())", all_lines[3]
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_data_dependent_branching_gb(self, records):
|
||||
def fn(x):
|
||||
@ -1141,17 +1186,17 @@ NOTE: the most recent `torch.compile` tracing attempt might not be where you app
|
||||
Most recent bytecode instructions traced (max 20):
|
||||
TRACE RESUME 0 []
|
||||
TRACE LOAD_FAST 'x' []
|
||||
TRACE LOAD_CONST 1 [LazyVariableTracker()]
|
||||
TRACE BINARY_OP 0 [LazyVariableTracker(), ConstantVariable(int: 1)]
|
||||
TRACE LOAD_CONST 1 [LazyVariableTracker(unrealized: <class 'torch.Tensor'>)]
|
||||
TRACE BINARY_OP 0 [LazyVariableTracker(unrealized: <class 'torch.Tensor'>), ConstantVariable(int: 1)]
|
||||
TRACE STORE_FAST 'y' [TensorVariable()]
|
||||
TRACE LOAD_FAST 'x' []
|
||||
TRACE LOAD_FAST 'y' [TensorVariable()]
|
||||
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
|
||||
TRACE STORE_FAST 'z' [TensorVariable()]
|
||||
TRACE LOAD_GLOBAL 'torch' []
|
||||
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker()]
|
||||
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker()]
|
||||
TRACE CALL 0 [NullVariable, LazyVariableTracker()]""",
|
||||
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker(unrealized: <class 'module'>)]
|
||||
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker(unrealized: <class 'module'>)]
|
||||
TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: <class 'function'>)]""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user