Compare commits

..

1 Commit

SHA1: 3e3fd97b6a
Message: [dynamo] Use strings instead of modules for fqn info tracking
Date: 2025-09-30 11:52:49 -07:00
57 changed files with 1299 additions and 2088 deletions

View File

@@ -1 +1 @@
78a47f87ce259a48f0391fa9ae15add05ea7432b
0307428d65acf5cf1a73a70a7722e076bbb83f22

View File

@@ -127,6 +127,53 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
),
]
ROCM_SMOKE_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="manywheel",
build_variant="rocm",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
arches=["6.4"],
python_versions=["3.10"],
),
ciflow_config=CIFlowConfig(
labels={
LABEL_CIFLOW_BINARIES,
LABEL_CIFLOW_BINARIES_WHEEL,
LABEL_CIFLOW_ROCM,
},
isolated_workflow=True,
),
branches="main",
),
]
LINUX_BINARY_SMOKE_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="manywheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
arches=["13.0"],
python_versions=["3.12"],
),
branches="main",
),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="libtorch",
build_variant=generate_binary_build_matrix.RELEASE,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.LINUX,
generate_binary_build_matrix.RELEASE,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
branches="main",
),
]
WINDOWS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.WINDOWS,
@@ -212,6 +259,39 @@ WINDOWS_BINARY_BUILD_WORKFLOWS = [
),
]
WINDOWS_BINARY_SMOKE_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.WINDOWS,
package_type="libtorch",
build_variant=generate_binary_build_matrix.RELEASE,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.WINDOWS,
generate_binary_build_matrix.RELEASE,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
branches="main",
ciflow_config=CIFlowConfig(
isolated_workflow=True,
),
),
BinaryBuildWorkflow(
os=OperatingSystem.WINDOWS,
package_type="libtorch",
build_variant=generate_binary_build_matrix.DEBUG,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.WINDOWS,
generate_binary_build_matrix.DEBUG,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
branches="main",
ciflow_config=CIFlowConfig(
isolated_workflow=True,
),
),
]
MACOS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.MACOS_ARM64,
@@ -292,10 +372,23 @@ def main() -> None:
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
S390X_BINARY_BUILD_WORKFLOWS,
),
(
# Give rocm its own workflow file
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
ROCM_SMOKE_WORKFLOWS,
),
(
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
LINUX_BINARY_SMOKE_WORKFLOWS,
),
(
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
WINDOWS_BINARY_BUILD_WORKFLOWS,
),
(
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
WINDOWS_BINARY_SMOKE_WORKFLOWS,
),
(
jinja_env.get_template("macos_binary_build_workflow.yml.j2"),
MACOS_BINARY_BUILD_WORKFLOWS,

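For orientation: each tuple registered in main() above pairs a Jinja template with a list of BinaryBuildWorkflow entries, and the generated workflow YAML files shown below are the rendered result of such pairs. A minimal sketch of how (template, workflows) pairs like these could be consumed; the write_workflow_files helper and the build_environment-based file name are assumptions for illustration, not necessarily the script's real API:

    def write_workflow_files(template_and_workflows):
        # For every registered pair, render one workflow file per entry.
        for template, workflows in template_and_workflows:
            for workflow in workflows:
                rendered = template.render(workflow=workflow)  # jinja2 Template.render
                # Assumed naming scheme: the build environment doubles as the file name.
                path = f".github/workflows/generated-{workflow.build_environment}.yml"
                with open(path, "w") as f:
                    f.write(rendered)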
View File

@@ -0,0 +1,87 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-libtorch-release
on:
push:
branches:
- main
tags:
- 'ciflow/trunk/*'
workflow_dispatch:
permissions:
id-token: write
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BINARY_ENV_FILE: /tmp/env
BUILD_ENVIRONMENT: linux-binary-libtorch-release
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
concurrency:
group: linux-binary-libtorch-release-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: cpu
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: libtorch-cpu-shared-with-deps-release
build_environment: linux-binary-libtorch-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-cpu-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cpu-shared-with-deps-release-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: cpu
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
build_name: libtorch-cpu-shared-with-deps-release
build_environment: linux-binary-libtorch-release
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -0,0 +1,88 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-manywheel
on:
push:
branches:
- main
tags:
- 'ciflow/trunk/*'
workflow_dispatch:
permissions:
id-token: write
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BINARY_ENV_FILE: /tmp/env
BUILD_ENVIRONMENT: linux-binary-manywheel
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
concurrency:
group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
manywheel-py3_12-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu130
GPU_ARCH_VERSION: "13.0"
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda13.0
DESIRED_PYTHON: "3.12"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_12-cuda13_0
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_12-cuda13_0-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_12-cuda13_0-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu130
GPU_ARCH_VERSION: "13.0"
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda13.0
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cuda13_0
build_environment: linux-binary-manywheel
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -0,0 +1,136 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-manywheel-rocm
on:
push:
branches:
- main
tags:
- 'ciflow/binaries/*'
- 'ciflow/binaries_wheel/*'
- 'ciflow/rocm/*'
workflow_dispatch:
permissions:
id-token: write
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BINARY_ENV_FILE: /tmp/env
BUILD_ENVIRONMENT: linux-binary-manywheel-rocm
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
concurrency:
group: linux-binary-manywheel-rocm-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
manywheel-py3_10-rocm6_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.10"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
timeout-minutes: 300
build_name: manywheel-py3_10-rocm6_4
build_environment: linux-binary-manywheel-rocm
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_10-rocm6_4-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_10-rocm6_4-build
- get-label-type
runs-on: linux.rocm.gpu.mi250
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.10"
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: manywheel-py3_10-rocm6_4
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
docker-image-name: manylinux2_28-builder
custom-tag-prefix: rocm6.4
docker-build-dir: .ci/docker
working-directory: pytorch
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
env:
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm

View File

@@ -0,0 +1,261 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/windows_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: windows-binary-libtorch-debug
on:
push:
branches:
- main
workflow_dispatch:
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BUILD_ENVIRONMENT: windows-binary-libtorch-debug
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 1
OS: windows
concurrency:
group: windows-binary-libtorch-debug-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
libtorch-cpu-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v4.4.0
if: always()
with:
name: libtorch-cpu-shared-with-deps-debug
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cpu-shared-with-deps-debug-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cpu-shared-with-deps-debug-build
- get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-cpu-shared-with-deps-debug
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1

View File

@@ -0,0 +1,261 @@
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/windows_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: windows-binary-libtorch-release
on:
push:
branches:
- main
workflow_dispatch:
env:
# Needed for conda builds
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
AWS_DEFAULT_REGION: us-east-1
BUILD_ENVIRONMENT: windows-binary-libtorch-release
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 1
OS: windows
concurrency:
group: windows-binary-libtorch-release-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v4.4.0
if: always()
with:
name: libtorch-cpu-shared-with-deps-release
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cpu-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cpu-shared-with-deps-release-build
- get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cpu
GPU_ARCH_TYPE: cpu
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-cpu-shared-with-deps-release
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1

View File

@@ -1831,37 +1831,6 @@ std::optional<c10::ScalarType> out_dtype) {
return out;
}
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
const auto batch1_sizes = batch1.sizes();
const auto batch2_sizes = batch2.sizes();
int64_t bs = batch1_sizes[0];
int64_t contraction_size = batch1_sizes[2];
int64_t res_rows = batch1_sizes[1];
int64_t res_cols = batch2_sizes[2];
std::vector<int64_t> output_size {bs, res_rows, res_cols};
TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size,
"Expected size for first two dimensions of batch2 tensor to be: [",
bs, ", ", contraction_size, "] but got: [", batch2_sizes[0], ", ", batch2_sizes[1], "].");
TORCH_CHECK(batch1.scalar_type() == batch2.scalar_type(), "batch1 and batch2 must have the same dtype");
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
if (!is_bmm && self_baddbmm.has_value()) {
const auto& self = self_baddbmm.value();
TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output");
}
}
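The out_dtype rule enforced here (both in the helper above and in the TORCH_CHECKs added below) reduces to: the requested output dtype must either equal the input dtype, or be float32 when the inputs are float16/bfloat16. A small Python restatement of that rule, using illustrative dtype strings rather than PyTorch API:

    # Sketch of the accepted (input dtype, out_dtype) combinations.
    def out_dtype_allowed(input_dtype: str, out_dtype: str) -> bool:
        return out_dtype == input_dtype or (
            out_dtype == "float32" and input_dtype in ("float16", "bfloat16")
        )

    # e.g. out_dtype_allowed("bfloat16", "float32") -> True
    #      out_dtype_allowed("float32", "float16")  -> False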
Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
IntArrayRef batch1_sizes = batch1.sizes();
IntArrayRef batch2_sizes = batch2.sizes();
@@ -1871,7 +1840,12 @@ Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::Sca
}
Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true);
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
Scalar beta(0.0);
Scalar alpha(1.0);
{
@@ -1890,7 +1864,12 @@ Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tenso
}
Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self);
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
{
NoNamesGuard guard;
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
@@ -1905,12 +1884,6 @@ Tensor _mm_dtype_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarTy
}
Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarType out_dtype, Tensor &out) {
TORCH_CHECK(self.dim() == 2, "self must be a matrix, got ", self.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
self.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
self.sizes()[0], "x", self.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "input dtypes must be the same");
TORCH_CHECK(out_dtype == self.scalar_type() ||
@@ -1930,14 +1903,6 @@ Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& m
}
Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == self.scalar_type() ||
(out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),

View File

@@ -101,9 +101,6 @@ __device__ inline bool isinf_device(float v) {
__device__ inline bool isinf_device(c10::BFloat16 v) {
return ::isinf(static_cast<float>(v));
}
__device__ inline bool isinf_device(at::Half v) {
return ::isinf(static_cast<float>(v));
}
// CUDA kernel to compute Moving Average Min/Max of the tensor.
// It uses the running_min and running_max along with averaging const, c.
@@ -163,8 +160,8 @@ void _calculate_moving_average(
std::tie(x_min, x_max) = at::aminmax(x, 1);
int num_threads = std::min(size, (int64_t)512);
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
@@ -184,8 +181,8 @@ void _calculate_moving_average(
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
std::tie(x_min, x_max) = at::aminmax(x);
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
@@ -224,8 +221,8 @@ void _calc_moving_avg_qparams_helper(
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>();
if (per_row_fq) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
int num_threads = std::min(size, (int64_t)512);
@@ -247,8 +244,8 @@ void _calc_moving_avg_qparams_helper(
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(

View File

@@ -1,16 +1,100 @@
#pragma once
// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
// If you modify me, keep me synchronized with that file.
#include <c10/macros/Export.h>
// If you modified DeviceType in caffe2/proto/caffe2.proto, please also sync
// your changes into torch/headeronly/core/DeviceType.h.
#include <torch/headeronly/core/DeviceType.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <string>
namespace c10 {
// These contain all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends except PrivateUse2 and PrivateUse3
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
_(CPU, extra) \
_(CUDA, extra) \
_(HIP, extra) \
_(XLA, extra) \
_(MPS, extra) \
_(IPU, extra) \
_(XPU, extra) \
_(HPU, extra) \
_(VE, extra) \
_(Lazy, extra) \
_(Meta, extra) \
_(MTIA, extra) \
_(PrivateUse1, extra)
enum class DeviceType : int8_t {
CPU = 0,
CUDA = 1, // CUDA.
MKLDNN = 2, // Reserved for explicit MKLDNN
OPENGL = 3, // OpenGL
OPENCL = 4, // OpenCL
IDEEP = 5, // IDEEP.
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
MAIA = 8, // ONNX Runtime / Microsoft
XLA = 9, // XLA / TPU
Vulkan = 10, // Vulkan
Metal = 11, // Metal
XPU = 12, // XPU
MPS = 13, // MPS
Meta = 14, // Meta (tensors with no data)
HPU = 15, // HPU / HABANA
VE = 16, // SX-Aurora / NEC
Lazy = 17, // Lazy Tensors
IPU = 18, // Graphcore IPU
MTIA = 19, // Meta training and inference devices
PrivateUse1 = 20, // PrivateUse1 device
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in DeviceType.cpp
// - Change the number below
COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMAIA = DeviceType::MAIA;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMTIA = DeviceType::MTIA;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
static_assert(
COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
"for this constant to reflect the actual number of DeviceTypes we support "
"in PyTorch; it's important that this number is not too large as we "
"use this to allocate stack arrays in some places in our code. If you "
"are indeed just adding the 20th device type, feel free to change "
"the check to 32; but if you are adding some sort of extensible device "
"types registration, please be aware that you are affecting code that "
"this number is small. Try auditing uses of this constant.");
C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);
C10_API bool isValidDeviceType(DeviceType d);
@@ -24,6 +108,15 @@ C10_API bool is_privateuse1_backend_registered();
} // namespace c10
namespace std {
template <>
struct hash<c10::DeviceType> {
std::size_t operator()(c10::DeviceType k) const {
return std::hash<int>()(static_cast<int>(k));
}
};
} // namespace std
namespace torch {
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::DeviceType;

View File

@@ -1183,16 +1183,6 @@ class DeviceCachingAllocator {
// ends.
ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;
// Incremental reverse-traversal state cached per graph.
// We never re-traverse nodes we've already seen
struct GraphReuseContext {
ska::flat_hash_map<cudaStream_t, ska::flat_hash_set<cudaGraphNode_t>>
visited;
};
ska::flat_hash_map<MempoolId_t, CaptureId_t, MempoolIdHash>
mempool_to_capture_id;
ska::flat_hash_map<CaptureId_t, GraphReuseContext> graph_reuse_context;
// outstanding cuda events
ska::flat_hash_map<
cuda::CUDAStream,
@@ -1648,70 +1638,44 @@ class DeviceCachingAllocator {
return block;
}
struct CaptureInfo {
cudaGraph_t graph{};
CaptureId_t capture_id{0};
const cudaGraphNode_t* terminals{nullptr};
size_t num_terminals{0};
cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone};
};
inline CaptureInfo stream_get_capture_info(cudaStream_t stream) {
CaptureInfo info{};
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream,
&info.status,
&info.capture_id,
&info.graph,
&info.terminals,
nullptr,
&info.num_terminals));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream,
&info.status,
&info.capture_id,
&info.graph,
&info.terminals,
&info.num_terminals));
#endif
TORCH_INTERNAL_ASSERT(
info.status != cudaStreamCaptureStatusInvalidated,
"Invalid stream capture status");
return info;
}
// Record "free marker" of the CUDA graph for all streams that
// Insert "free marker" (empty nodes) into the CUDA graph for all streams that
// have used the block, including the allocation stream. These nodes mark the
// last use of the block in the capture graph. Returns a vector of the
// inserted nodes, or an empty vector if any stream is not capturing.
std::vector<cudaGraphNode_t> record_free_markers(Block* block) {
// It is possible to have the same marker recorded multiple times, so we use
// a set to avoid duplicates
ska::flat_hash_set<cudaGraphNode_t> markers;
cudaGraph_t owning_graph = nullptr;
std::vector<cudaGraphNode_t> insert_free_marker(Block* block) {
std::vector<cudaGraphNode_t> empty_nodes;
auto try_record = [&](cudaStream_t s) -> bool {
auto info = stream_get_capture_info(s);
if (info.status == cudaStreamCaptureStatusNone) {
return false; // not capturing on this stream -> must defer
}
auto try_add_empty_node = [&](cudaStream_t stream) -> bool {
cudaStreamCaptureStatus status{};
cudaGraph_t graph{};
const cudaGraphNode_t* deps = nullptr;
size_t num_deps = 0;
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream, &status, nullptr, &graph, &deps, nullptr, &num_deps));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream, &status, nullptr, &graph, &deps, &num_deps));
#endif
if (owning_graph == nullptr) {
owning_graph = info.graph;
}
TORCH_INTERNAL_ASSERT(
info.graph == owning_graph,
"All streams in the same capture should agree on the graph");
status != cudaStreamCaptureStatusInvalidated,
"Invalid stream capture status");
// Use current terminals as the free markers for the stream
for (size_t i = 0; i < info.num_terminals; ++i) {
auto terminal = info.terminals[i];
markers.insert(terminal);
if (status == cudaStreamCaptureStatusNone) {
return false;
}
owning_graph = info.graph; // all streams in the same capture should agree
cudaGraphNode_t node{};
C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps));
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
stream, &node, nullptr, 1, cudaStreamSetCaptureDependencies));
#else
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
stream, &node, 1, cudaStreamSetCaptureDependencies));
#endif
empty_nodes.push_back(node);
return true;
};
@@ -1719,34 +1683,81 @@ class DeviceCachingAllocator {
// An empty vector indicates that the block should be deferred for freeing
// until after capture.
// Allocation stream
if (!try_record(block->stream)) {
// Attempt to add an empty node for the allocation stream.
if (!try_add_empty_node(block->stream)) {
return {};
}
// Any extra streams that used this block
// Attempt to add empty nodes for all streams that have used the block.
for (const auto& s : block->stream_uses) {
if (!try_record(s.stream())) {
if (!try_add_empty_node(s.stream())) {
return {};
}
}
return std::vector<cudaGraphNode_t>(markers.begin(), markers.end());
return empty_nodes;
}
// Returns the set of "reusable" free markers in the current
// Returns the current set of "terminal" nodes in the CUDA graph for a given
// stream. These represent the current endpoints of the stream, and may
// include additional nodes if the graph branches. Any new work captured will
// be attached after one or more of these terminals.
std::vector<cudaGraphNode_t> get_terminals(cudaStream_t stream) {
std::vector<cudaGraphNode_t> result;
cudaStreamCaptureStatus status{};
cudaGraph_t graph{};
const cudaGraphNode_t* dependencies = nullptr;
size_t num_dependencies = 0;
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream,
&status,
nullptr,
&graph,
&dependencies,
nullptr,
&num_dependencies));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream, &status, nullptr, &graph, &dependencies, &num_dependencies));
#endif
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatusActive,
"Invalid stream capture status");
for (size_t i = 0; i < num_dependencies; i++) {
auto node = dependencies[i];
if (node != nullptr) {
result.push_back(node);
}
}
return result;
}
// Returns the set of "reusable" free markers (empty nodes) in the current
// CUDA graph capture. A free marker is considered reusable if it is a
// predecessor of every terminal node.
// This ensures that all future captured work will occur after the free
// marker, making it safe to reuse.
void update_visited(
const CaptureInfo& info,
ska::flat_hash_set<cudaGraphNode_t>& visited) {
// This is the versioned cudaGraphNodeGetDependencies helper function.
auto node_get_dependencies =
[](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void {
ska::flat_hash_set<cudaGraphNode_t> get_reusable_empty_nodes(
cudaStream_t stream) {
auto terminals = get_terminals(stream);
if (terminals.empty()) {
// No terminal nodes found; nothing to free.
return {};
}
auto get_dependencies = [](cudaGraphNode_t node,
cudaGraphNode_t* pDependencies,
size_t* pNumDependencies) -> void {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count));
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(
node, pDependencies, nullptr, pNumDependencies));
#else
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count));
C10_CUDA_CHECK(
cudaGraphNodeGetDependencies(node, pDependencies, pNumDependencies));
#endif
};
@@ -1754,43 +1765,62 @@ class DeviceCachingAllocator {
auto get_parents =
[&](cudaGraphNode_t node) -> std::vector<cudaGraphNode_t> {
size_t count = 0;
node_get_dependencies(node, nullptr, &count);
get_dependencies(node, nullptr, &count);
std::vector<cudaGraphNode_t> out(count);
if (count) {
node_get_dependencies(node, out.data(), &count);
get_dependencies(node, out.data(), &count);
out.resize(count);
}
return out;
};
// For each terminal node, perform a reverse DFS to count, for each free
// marker, how many terminals it can reach (i.e., for how many terminals it
// is a predecessor). A free marker is reusable if it is a predecessor of
// all terminal nodes.
std::deque<cudaGraphNode_t> dfs;
for (size_t i = 0; i < info.num_terminals; ++i) {
dfs.push_back(info.terminals[i]);
// Helper to determine if a node is an empty node (used as a free marker).
auto is_empty_node = [](cudaGraphNode_t n) -> bool {
cudaGraphNodeType type{};
C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type));
return type == cudaGraphNodeTypeEmpty;
};
// For each terminal node, perform a reverse DFS to count, for each empty
// node, how many terminals it can reach (i.e., for how many terminals it is
// a predecessor). An empty node is reusable if it is a predecessor of all
// terminal nodes.
ska::flat_hash_map<cudaGraphNode_t, size_t> num_terminals_reachable;
for (auto terminal : terminals) {
ska::flat_hash_set<cudaGraphNode_t> visited;
ska::flat_hash_set<cudaGraphNode_t> empty_nodes;
std::function<void(cudaGraphNode_t)> reverse_dfs =
[&](cudaGraphNode_t node) {
if (!visited.insert(node).second)
return;
if (is_empty_node(node)) {
num_terminals_reachable[node]++;
empty_nodes.insert(node);
}
auto parents = get_parents(node);
for (auto p : parents) {
reverse_dfs(p);
}
};
reverse_dfs(terminal);
}
while (!dfs.empty()) {
auto v = dfs.back();
dfs.pop_back();
if (visited.count(v)) {
continue;
}
visited.insert(v);
auto parents = get_parents(v);
for (auto p : parents) {
dfs.push_back(p);
ska::flat_hash_set<cudaGraphNode_t> reusable_empty_nodes;
for (auto [node, count] : num_terminals_reachable) {
if (count == terminals.size()) {
reusable_empty_nodes.insert(node);
}
}
return reusable_empty_nodes;
}
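Stripped of the CUDA specifics, the reusability test implemented above is a pure DAG-reachability check: an empty node (free marker) is reusable exactly when every current terminal of the capture graph can reach it by walking parent edges, i.e. the marker is a predecessor of all terminals. A small Python sketch of that check over a generic DAG; parents_of and is_marker are illustrative stand-ins, not allocator APIs:

    def reusable_markers(terminals, parents_of, is_marker):
        # Count, per marker node, how many terminals reach it via reverse DFS.
        reach_count = {}
        for terminal in terminals:
            seen = set()
            stack = [terminal]
            while stack:
                node = stack.pop()
                if node in seen:
                    continue
                seen.add(node)
                if is_marker(node):
                    reach_count[node] = reach_count.get(node, 0) + 1
                stack.extend(parents_of(node))
        # A marker is reusable iff it was reached from every terminal.
        return {m for m, c in reach_count.items() if c == len(terminals)}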
// A block is considered reusable during CUDA graph capture if every free
// marker associated with the block is a predecessor of every
// marker (empty node) associated with the block is a predecessor of every
// terminal node.
//
// This ensures that any new operation added to the graph will be attached
@@ -1799,52 +1829,36 @@ class DeviceCachingAllocator {
// on every stream, so the block's previous lifetime ends before any new
// lifetime begins. This check relies solely on the DAG topology and does not
// require event queries, making it safe to use during capture.
//
// This function iterates over all deferred blocks, determines if their empty
// nodes are reusable according to the above criteria, and frees the block if
// so.
void free_safe_blocks_in_capture(
const std::shared_ptr<GatheredContext>& context,
cudaStream_t stream) {
auto info = stream_get_capture_info(stream);
auto reusable_empty_nodes = get_reusable_empty_nodes(stream);
// If there are no reusable empty nodes (e.g., not currently capturing),
// there is nothing to do.
if (info.status == cudaStreamCaptureStatusNone || info.num_terminals == 0) {
if (reusable_empty_nodes.empty()) {
return;
}
if (graph_reuse_context.find(info.capture_id) ==
graph_reuse_context.end()) {
bool found = false;
for (auto& entry : captures_underway) {
if (entry.second(stream)) {
auto graph_pool = graph_pools.find(entry.first);
TORCH_INTERNAL_ASSERT(
graph_pool != graph_pools.end(),
"Could not find graph pool for capture.");
auto mempool_id = graph_pool->first;
graph_reuse_context[info.capture_id] = GraphReuseContext{};
mempool_to_capture_id[mempool_id] = info.capture_id;
found = true;
break;
}
}
TORCH_INTERNAL_ASSERT(
found, "Could not find memory pool id for capture.");
}
auto& graph_context = graph_reuse_context[info.capture_id];
auto& visited = graph_context.visited[stream];
update_visited(info, visited);
std::vector<Block*> blocks_to_erase;
for (auto& [block, markers] : deferred_blocks) {
// Skip this block if it has no markers, as we defer its freeing until
for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
// Skip this block if it has no empty nodes, as we defer its freeing until
// after graph capture. Also skip if the block was not allocated on the
// current stream; such blocks will be freed when
// free_safe_blocks_in_capture is attempted on that stream.
if (markers.empty() || block->stream != stream) {
if (inserted_empty_nodes.empty() || block->stream != stream) {
continue;
}
bool is_reusable = true;
for (auto m : markers) {
if (!visited.count(m)) {
for (const auto& node : inserted_empty_nodes) {
if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
is_reusable = false;
break;
}
@ -1905,11 +1919,11 @@ class DeviceCachingAllocator {
if (!block->stream_uses.empty()) {
if (C10_UNLIKELY(!captures_underway.empty())) {
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
// record_free_markers returns a vector of free markers,
// insert_free_marker returns a vector of free markers,
// or an empty vector if any associated stream is not currently
// capturing. The empty vector means that we will defer the free until
// capture is finished.
deferred_blocks.emplace(block, record_free_markers(block));
deferred_blocks.emplace(block, insert_free_marker(block));
} else {
// If graph_capture_record_stream_reuse is not enabled, always defer
// the free until capture is finished.
@ -2497,21 +2511,6 @@ class DeviceCachingAllocator {
// Called by CUDAGraph::capture_end
void endAllocateToPool(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse() &&
!graph_reuse_context.empty()) {
auto capture_id = mempool_to_capture_id[mempool_id];
auto graph_context = graph_reuse_context[capture_id];
for (auto& [stream, _] : graph_context.visited) {
TORCH_INTERNAL_ASSERT(
stream_get_capture_info(stream).status ==
cudaStreamCaptureStatusNone,
"This stream should not be capturing when the capture is ended");
}
graph_reuse_context.erase(capture_id);
mempool_to_capture_id.erase(mempool_id);
}
for (auto it = captures_underway.begin(); it != captures_underway.end();
++it) {
if (it->first == mempool_id) {

View File

@ -339,16 +339,13 @@ XLA
~~~
- Jack Cao (`JackCaoG <https://github.com/JackCaoG>`__)
- Han Qi (`qihqi <https://github.com/qihqi>`__)
- Yifei Teng (`tengyifei <https://github.com/tengyifei>`__)
- Siyuan Liu (`lsy323 <https://github.com/lsy323>`__)
- Daniel Sohn (`jysohn23 <https://github.com/jysohn23>`__)
- Zach Cain (`zcain117 <https://github.com/zcain117>`__)
- Brian Hirsh (`bdhirsh <https://github.com/bdhirsh>`__)
- (emeritus) Gregory Chanan (`gchanan <https://github.com/gchanan>`__)
- Gregory Chanan (`gchanan <https://github.com/gchanan>`__)
- (emeritus) Ailing Zhang (`ailzhang <https://github.com/ailzhang>`__)
- (emeritus) Davide Libenzi (`dlibenzi <https://github.com/dlibenzi>`__)
- (emeritus) Alex Suhan (`asuhan <https://github.com/asuhan>`__)
- (emeritus) Daniel Sohn (`jysohn23 <https://github.com/jysohn23>`__)
- (emeritus) Zach Cain (`zcain117 <https://github.com/zcain117>`__)
TorchServe
~~~~~~~~~~

View File

@ -613,7 +613,8 @@ Available options:
CUDA Graph capture by using the graph topology (instead of CUDA events) to determine
when a freed block is safe to reuse. This can reduce peak memory during long captures that free
and reallocate buffers across multiple streams, especially when the capture DAG frequently
reaches joined frontiers. Note: Enabling this option can significantly increase the time spent
capturing the graph.
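For example, a minimal sketch (assuming the config key matches the
``graph_capture_record_stream_reuse`` accessor name used in the allocator code)::

    import os

    # Allocator settings are parsed when the CUDA caching allocator initializes,
    # so set this before the first CUDA allocation (simplest: before importing torch).
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "graph_capture_record_stream_reuse:True"

    import torch  # noqa: E402

    x = torch.empty(1, device="cuda")  # the allocator picks up the setting here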
.. note::

View File

@ -4,7 +4,6 @@ set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
@ -28,7 +27,7 @@ add_executable(test_aoti_abi_check
target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)
# WARNING: DO NOT LINK torch!!!
# The purpose is to check if the used aten/c10 headers are written in a header-only way
# The purpose is to check if the used aten/c10 headers are writtern in a header-only way
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})

View File

@ -1,35 +0,0 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/DeviceType.h>
TEST(TestDeviceType, TestDeviceType) {
using torch::headeronly::DeviceType;
constexpr DeviceType expected_device_types[] = {
torch::headeronly::kCPU,
torch::headeronly::kCUDA,
DeviceType::MKLDNN,
DeviceType::OPENGL,
DeviceType::OPENCL,
DeviceType::IDEEP,
torch::headeronly::kHIP,
torch::headeronly::kFPGA,
torch::headeronly::kMAIA,
torch::headeronly::kXLA,
torch::headeronly::kVulkan,
torch::headeronly::kMetal,
torch::headeronly::kXPU,
torch::headeronly::kMPS,
torch::headeronly::kMeta,
torch::headeronly::kHPU,
torch::headeronly::kVE,
torch::headeronly::kLazy,
torch::headeronly::kIPU,
torch::headeronly::kMTIA,
torch::headeronly::kPrivateUse1,
};
for (int8_t i = 0; i <
static_cast<int8_t>(torch::headeronly::COMPILE_TIME_MAX_DEVICE_TYPES);
i++) {
EXPECT_EQ(static_cast<DeviceType>(i), expected_device_types[i]);
}
}

View File

@ -15,9 +15,6 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint, replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import (
FSDPModule,
@ -61,7 +58,6 @@ from torch.testing._internal.common_fsdp import (
)
from torch.testing._internal.common_utils import run_tests, TEST_XPU, xfailIf
from torch.testing._internal.distributed._tensor.common_dtensor import (
FeedForward,
ModelArgs,
Transformer,
TransformerBlock,
@ -1014,222 +1010,6 @@ class TestFullyShardPrefetch(FSDPTest):
self.assertEqual(events, expected_backward_events)
events.clear()
@skip_if_lt_x_gpu(2)
def test_set_modules_to_backward_prefetch_inside_ac(self):
n_layers = 3
reshard_after_forward = True
# use checkpoint wrapper instead of torch.utils
model_args = ModelArgs(n_layers=n_layers, checkpoint_activations=False)
model = Transformer(model_args)
apply_activation_checkpointing(
model, check_fn=lambda m: isinstance(m, TransformerBlock)
)
apply_activation_checkpointing(
model, check_fn=lambda m: isinstance(m, FeedForward)
)
fully_shard([model.tok_embeddings, model.pos_embeddings])
for layer in model.layers:
# mimic fully_shard(layer.moe.experts)
fully_shard(
layer.feed_forward.w1, reshard_after_forward=reshard_after_forward
)
fully_shard(layer, reshard_after_forward=reshard_after_forward)
fully_shard(
[model.norm, model.output], reshard_after_forward=reshard_after_forward
)
fully_shard(model, reshard_after_forward=reshard_after_forward)
inp = torch.randint(
0,
model_args.vocab_size,
(2, model_args.max_seq_len),
device=device_type.type,
)
def set_backward_prefetch(model: Transformer) -> None:
# tell pyre that model.set_modules_to_backward_prefetch is available
assert isinstance(model, FSDPModule)
assert isinstance(model.output, FSDPModule)
# mimic deepseek MOE
# prefetch layer - 1 and its feedforward before cpu sync during a2a
reversed_transformer_blocks = list(reversed(model.layers))
prev_transformer_blocks = reversed_transformer_blocks[1:] + [None]
if (
model.norm is not None
and model.output is not None
and len(model.layers) > 0
):
assert isinstance(reversed_transformer_blocks[0], FSDPModule)
model.output.set_modules_to_backward_prefetch(
[reversed_transformer_blocks[0]]
)
for transformer_block, prev_transformer_block in zip(
reversed_transformer_blocks, prev_transformer_blocks
):
assert isinstance(transformer_block, FSDPModule)
if prev_transformer_block is not None:
assert isinstance(prev_transformer_block, FSDPModule)
assert hasattr(prev_transformer_block.feed_forward, "w1")
assert isinstance(
prev_transformer_block.feed_forward.w1, FSDPModule
)
transformer_block.set_modules_to_backward_prefetch(
[
prev_transformer_block,
prev_transformer_block.feed_forward.w1,
]
)
elif model.tok_embeddings is not None:
assert isinstance(model.tok_embeddings, FSDPModule)
transformer_block.set_modules_to_backward_prefetch(
[model.tok_embeddings]
)
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
reshard_with_record = self._get_reshard_with_record(
FSDPParamGroup.reshard, events
)
with (
patch_unshard(unshard_with_record),
patch_reshard(reshard_with_record),
):
loss = model(inp)
events.clear()
loss.sum().backward()
expected_backward_events = [
("unshard", "norm, output", TrainingState.PRE_BACKWARD),
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "norm, output", TrainingState.POST_BACKWARD),
# layers.2 prefetch w1
(
"unshard",
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
# layers.2.w1 prefetch layers.1
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
(
"reshard",
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
(
"unshard",
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
(
"reshard",
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
(
"unshard",
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
(
"unshard",
"tok_embeddings, pos_embeddings",
TrainingState.PRE_BACKWARD,
),
(
"reshard",
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
(
"reshard",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
(
"reshard",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
("reshard", "norm, output", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()
set_backward_prefetch(model)
loss = model(inp)
events.clear()
loss.sum().backward()
expected_backward_events = [
("unshard", "norm, output", TrainingState.PRE_BACKWARD),
# root explicit prefetch layers.2
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "norm, output", TrainingState.POST_BACKWARD),
# layers.2 prefetch layers.1 and feed_forward
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
(
"unshard",
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
# AC recompute_fn
(
"unshard",
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.FORWARD,
),
(
"reshard",
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
# layers.1 prefetch layers.0
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
(
"unshard",
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.PRE_BACKWARD,
),
(
"reshard",
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
# layers.0 prefetch embeddings
(
"unshard",
"tok_embeddings, pos_embeddings",
TrainingState.PRE_BACKWARD,
),
(
"reshard",
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
TrainingState.POST_BACKWARD,
),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
(
"reshard",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
(
"reshard",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
("reshard", "norm, output", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()
@skip_if_lt_x_gpu(2)
def test_fully_shard_multi_module_backward_prefetch(self):
n_layers = 5

View File

@ -1,626 +0,0 @@
# Owner(s): ["oncall: distributed"]
import copy
import dataclasses
import functools
from typing import Optional, Union
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._composable.replicate_with_fsdp import replicate
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
_get_gradient_divide_factors,
)
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
requires_nccl_version,
SaveForwardInputsModel,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
FSDPTest,
FSDPTestMultiThread,
get_devtype,
MLP,
patch_reduce_scatter,
reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocmVersionLessThan,
TEST_HPU,
)
device_type = torch.device(get_devtype())
class TestReplicateMixedPrecisionTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(2, torch.get_device_module(device_type).device_count())
def _init_models_and_optims(
self,
reshard_after_forward: Union[bool, int],
param_dtype: Optional[torch.dtype],
reduce_dtype: Optional[torch.dtype],
use_shard_placement_fn,
):
torch.manual_seed(42)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
largest_dim = -1
largest_dim_size = -1
for dim, dim_size in enumerate(param.shape):
if dim_size > largest_dim_size:
largest_dim = dim
largest_dim_size = dim_size
assert largest_dim >= 0, f"{param.shape}"
return Shard(largest_dim)
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype
)
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
replicate_fn = functools.partial(
replicate,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
shard_placement_fn=shard_placement_fn,
)
for mlp in model:
replicate_fn(mlp)
replicate_fn(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
return ref_model, ref_optim, model, optim
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
use_shard_placement_fn_vals = [False]
if self.world_size == 2:
# For world size >2, gradient elements get reduced in different
# orders for the baseline vs. dim-1 sharding, leading to numeric
# differences for bf16 reduction, so only test world size 2.
use_shard_placement_fn_vals.append(True)
return use_shard_placement_fn_vals
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_compute_dtype(self):
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"param_dtype": [torch.bfloat16, torch.float16],
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_compute_dtype,
)
def _test_compute_dtype(
self,
param_dtype: torch.dtype,
reshard_after_forward: Union[bool, int],
use_shard_placement_fn: bool,
):
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=None,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, param_dtype)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
predivide_factor, postdivide_factor, _, _ = _get_gradient_divide_factors(
self.process_group, all_reduce_group=None, reduce_dtype=param_dtype
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
with patch_reduce_scatter(reduce_scatter):
fsdp_loss.backward()
optim.step()
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
ref_loss.backward()
for param in ref_model_bf16.parameters():
# Use reduce-scatter -> all-gather as all-reduce because for
# world size >=4, NCCL all-reduce shows numeric differences
# compared with NCCL reduce-scatter
if predivide_factor is not None and predivide_factor > 1:
param.grad.div_(predivide_factor)
elif predivide_factor is None:
param.grad.div_(self.world_size)
output = torch.zeros_like(torch.chunk(param.grad, self.world_size)[0])
dist.reduce_scatter_tensor(output, param.grad)
dist.all_gather_into_tensor(param.grad, output)
if postdivide_factor is not None and postdivide_factor > 1:
param.grad.div_(postdivide_factor)
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_fp32.grad = param_bf16.grad.to(param_fp32.dtype)
param_bf16.grad = None
ref_optim.step() # fp32 optimizer step
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_bf16.detach().copy_(param_fp32)
self.assertEqual(fsdp_loss, ref_loss)
check_sharded_parity(self, ref_model, model)
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_reduce_dtype(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": [False, True],
},
self._test_reduce_dtype_fp32_reduce,
)
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_reduce_dtype_bf16_reduce,
)
def _test_reduce_dtype_fp32_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
if (
self.world_size > 2
and isinstance(reshard_after_forward, int)
and use_shard_placement_fn
):
return
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, reduce_dtype)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
with patch_reduce_scatter(reduce_scatter):
fsdp_loss.backward()
optim.step()
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
ref_loss.backward()
for param in ref_model_bf16.parameters():
param.grad.data = param.grad.to(torch.float32)
dist.all_reduce(param.grad) # fp32 reduction
param.grad.div_(self.world_size)
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_fp32.grad = param_bf16.grad
param_bf16.grad = None
ref_optim.step() # fp32 optimizer step
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_bf16.detach().copy_(param_fp32)
self.assertEqual(fsdp_loss, ref_loss)
check_sharded_parity(self, ref_model, model)
def _test_reduce_dtype_bf16_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
group = dist.distributed_c10d._get_default_group()
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, reduce_dtype)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
with patch_reduce_scatter(reduce_scatter):
fsdp_loss.backward()
optim.step()
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
ref_loss = ref_model(inp).sum()
ref_loss.backward()
for param in ref_model.parameters():
param_grad = param.grad.to(reduce_dtype)
# Use reduce-scatter -> all-gather to implement all-reduce
# since for world size >2, bf16 all-reduce and reduce-scatter
# have numeric differences
sharded_grad = funcol.reduce_scatter_tensor(
param_grad, scatter_dim=0, reduceOp="avg", group=group
) # bf16 reduction
param.grad = funcol.all_gather_tensor(
sharded_grad, gather_dim=0, group=group
).to(param.dtype) # upcast to fp32
ref_optim.step() # fp32 optimizer step
self.assertEqual(fsdp_loss, ref_loss)
check_sharded_parity(self, ref_model, model)
@skip_if_lt_x_gpu(2)
def test_grad_acc_with_reduce_dtype(self):
"""
Tests that gradient accumulation without reduce-scatter when using
bf16 compute and fp32 reduction accumulates the unsharded gradients in
fp32.
"""
self.run_subtests(
{"reshard_after_forward": [True, False]},
self._test_grad_acc_with_reduce_dtype,
)
def _test_grad_acc_with_reduce_dtype(self, reshard_after_forward: bool):
torch.manual_seed(42)
param_dtype, reduce_dtype = (torch.bfloat16, torch.float32)
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype
)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
# To emulate the mixed precision implementation where forward/backward
# compute use bf16 and optimizer uses fp32, we maintain both an fp32
# and a bf16 copy of the reference model
ref_model = copy.deepcopy(model).to(device_type)
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
replicate(
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, reduce_dtype)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
device = device_type
# Train on the same input to avoid loss explosion
num_microbatches = 4
inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
for iter_idx in range(10):
microbatch_inps = torch.chunk(inp, 4)
for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
model.set_requires_gradient_sync(is_last_microbatch)
model.set_reshard_after_backward(
is_last_microbatch or reshard_after_forward
)
losses: list[torch.Tensor] = []
for _model in (ref_model_compute, model):
losses.append(
_model(microbatch_inps[microbatch_idx].detach()).sum()
)
self.assertEqual(losses[-1].dtype, param_dtype)
with patch_reduce_scatter(reduce_scatter):
losses[-1].backward()
self.assertEqual(losses[0], losses[1])
# Manually accumulate gradients into the base reference model
# from the compute reference model in fp32
for ref_param, ref_param_compute in zip(
ref_model.parameters(), ref_model_compute.parameters()
):
self.assertTrue(ref_param_compute.grad is not None)
self.assertEqual(ref_param.dtype, torch.float32)
if ref_param.grad is not None:
ref_param.grad += ref_param_compute.grad
else:
ref_param.grad = ref_param_compute.grad.to(ref_param.dtype)
ref_param_compute.grad = None
# Manually reduce gradients for the reference model on the last
# microbatch to implement data parallelism
if is_last_microbatch:
for ref_param in ref_model.parameters():
self.assertTrue(ref_param.grad is not None)
dist.all_reduce(ref_param.grad)
ref_param.grad /= self.world_size
check_sharded_parity(self, ref_model, model)
ref_optim.step()
optim.step()
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
# Manually copy parameters from the base reference model to the
# compute reference model to run the optimizer step for the latter
for ref_param, ref_param_compute in zip(
ref_model.parameters(), ref_model_compute.parameters()
):
ref_param_compute.detach().copy_(ref_param)
class TestReplicateMixedPrecisionCasts(FSDPTestMultiThread):
@property
def world_size(self) -> int:
return 2
@skip_if_lt_x_gpu(1)
def test_float16_on_one_submodule(self):
x = torch.zeros(2, 100, device=device_type)
# Subtest 1: use fp16 on the second child submodule -- does not require
# any additional casting logic
forward_inputs: dict[str, nn.Module] = {}
model = SaveForwardInputsModel(
forward_inputs,
cast_forward_inputs=False,
).to(device_type)
replicate(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
replicate(model)
model(x).sum().backward()
self.assertEqual(forward_inputs[model].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)
# Subtest 2: use fp16 on the second child module, where the user module
# owns the cast
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=True
).to(device_type)
replicate(
model.c2,
mp_policy=MixedPrecisionPolicy(
param_dtype=torch.float16, cast_forward_inputs=False
),
)
replicate(model)
model(x).sum().backward()
self.assertEqual(forward_inputs[model].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)
# Subtest 3: use fp16 on the first child module and specify its output
# dtype so that the second child module does not need to cast
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=False
).to(device_type)
replicate(
model.c1,
mp_policy=MixedPrecisionPolicy(
param_dtype=torch.float16, output_dtype=torch.float32
),
)
replicate(model)
model(x).sum().backward()
self.assertEqual(forward_inputs[model].dtype, torch.float32)
self.assertEqual(forward_inputs[model.c1].dtype, torch.float16)
self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)
@skip_if_lt_x_gpu(1)
def test_submodules_with_external_inputs(self):
self.run_subtests(
{"enable_submodule_cast": [False, True]},
self._test_submodules_with_external_inputs,
)
def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
class ToyModule(nn.Module):
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l = nn.Linear(100, 100)
self.forward_inputs = forward_inputs
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
self.forward_inputs["l2_input_x"] = x
self.forward_inputs["l2_input_y"] = y
return self.l(x)
class ToyModel(nn.Module):
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l1 = nn.Linear(100, 100)
self.l2 = ToyModule(forward_inputs)
self.forward_inputs = forward_inputs
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.forward_inputs["model_input_x"] = x
y = torch.ones(
2, 100, device=device_type.type, dtype=torch.float32
) # external input
return self.l2(self.l1(x), y)
forward_inputs: dict[str, torch.Tensor] = {}
model = ToyModel(forward_inputs).to(device_type)
x = torch.zeros(2, 100, device=device_type.type, dtype=torch.float32)
replicate(
model.l2,
mp_policy=MixedPrecisionPolicy(
param_dtype=torch.float16, cast_forward_inputs=enable_submodule_cast
),
)
replicate(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
model(x).sum().backward()
# If we enable `model.l2` to cast (as default), then `l2_input_y` gets
# cast to fp16, and if we disable, then it stays as fp32.
self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
self.assertEqual(
forward_inputs["l2_input_y"].dtype,
torch.float16 if enable_submodule_cast else torch.float32,
)
@skip_if_lt_x_gpu(1)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_norm_modules_bf16(self):
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
self._test_norm_modules(mp_policy)
@skip_if_lt_x_gpu(1)
def test_norm_modules_fp16(self):
mp_policy = MixedPrecisionPolicy(param_dtype=torch.float16)
self._test_norm_modules(mp_policy)
def _test_norm_modules(self, mp_policy: MixedPrecisionPolicy):
def inner(model: nn.Module, x: torch.Tensor):
# Run forward and backward to check for no type mismatch errors
z = model(x)
self.assertEqual(z.dtype, mp_policy.param_dtype)
z.sum().backward()
# Layer norm
model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 32))
for module in (model[0], model[1], model[2], model):
replicate(module, mp_policy=mp_policy)
inner(model, torch.randn((4, 32)))
# Batch norm 1D
model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 32))
for module in (model[0], model[1], model[2], model):
replicate(module, mp_policy=mp_policy)
inner(model, torch.randn((4, 32)))
# Batch norm 2D: error in backward from buffer dtype mismatch
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
for module in (model[0], model[1], model[2], model):
replicate(module, mp_policy=mp_policy)
if TEST_HPU:
inner(model, torch.randn((3, 1, 9, 9)))
else:
with self.assertRaisesRegex(
RuntimeError,
"Expected running_mean to have type", # Error not seen on HPUs and hence it can be skipped
):
# Errors in batch norm 2D backward
inner(model, torch.randn((3, 1, 9, 9)))
# Batch norm 2D: cast buffers down to lower precision
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
for module in (model[0], model[1], model[2], model):
replicate(module, mp_policy=mp_policy)
# Casting batch norm buffers to the lower precision allows backward
model[1].running_mean = model[1].running_mean.to(mp_policy.param_dtype)
model[1].running_var = model[1].running_var.to(mp_policy.param_dtype)
inner(model, torch.randn((3, 1, 9, 9)))
# Batch norm 2D: use special mixed precision policy
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
bn_mp_policy = MixedPrecisionPolicy(output_dtype=mp_policy.param_dtype)
replicate(model[1], mp_policy=bn_mp_policy)
for module in (model[0], model[2], model):
replicate(module, mp_policy=mp_policy)
inner(model, torch.randn((3, 1, 9, 9)))
@skip_if_lt_x_gpu(1)
def test_clamp_reduce_dtype(self):
# Initialize the model directly in bf16
init_dtype = torch.bfloat16
model = nn.Sequential(
nn.Linear(32, 32, dtype=init_dtype),
nn.Linear(32, 32, dtype=init_dtype),
).to(device_type.type)
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
)
# Check that we did not clamp the reduce dtype
self.assertEqual(mp_policy.reduce_dtype, torch.bfloat16)
for module in model:
replicate((module), mp_policy=mp_policy)
replicate(model, mp_policy=mp_policy)
# Check that the reduce-scatter runs in bf16 even after we change the
# model from bf16 to fp32
model.to(torch.float32)
orig_reduce_scatter = dist.reduce_scatter_tensor
def assert_fn(output: torch.Tensor):
self.assertEqual(output.dtype, torch.bfloat16)
reduce_scatter = functools.partial(
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
with patch_reduce_scatter(reduce_scatter):
inp = torch.randn((4, 32), device=device_type.type)
loss = model(inp).sum()
loss.backward()
@skip_if_lt_x_gpu(1)
def test_dataclass_input(self):
@dataclasses.dataclass
class Input:
x: torch.Tensor
class Model(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._layer = nn.Linear(10, 10)
def forward(self, input: Input):
return self._layer(input.x)
mp_policy = MixedPrecisionPolicy(
torch.bfloat16, torch.bfloat16, torch.bfloat16, True
)
model = Model()
inp = Input(torch.randn(2, 10).cuda())
replicate(model, mp_policy=mp_policy)
loss = model(inp).sum()
loss.backward()
if __name__ == "__main__":
run_tests()

View File

@ -5,7 +5,6 @@ import copy
import functools
import itertools
import unittest
from collections import defaultdict
from collections.abc import Iterable
from typing import Union
@ -18,20 +17,8 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
apply_activation_checkpointing,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import (
CPUOffloadPolicy,
FSDPModule,
OffloadPolicy,
register_fsdp_forward_method,
)
from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy
from torch.distributed.tensor import DTensor, init_device_mesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
@ -39,7 +26,6 @@ from torch.testing._internal.common_fsdp import (
FSDPTest,
FSDPTestMultiThread,
MLP,
MLPStack,
patch_all_gather,
patch_reduce_scatter,
)
@ -856,385 +842,5 @@ class TestReplicateSharedParams(FSDPTest):
self.assertEqual(losses[0], losses[1])
class TestReplicateGradientAccumulation(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_gradient_accumulation(self):
"""
Tests gradient accumulation with/without gradient reduction and
with/without resharding after backward.
"""
shard_size, replicate_size = 1, self.world_size
meshes = init_device_mesh(
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
)
self.run_subtests(
{
"mesh": [meshes],
"reshard_after_forward": [True, False],
# "all": disable reduce-scatter for all modules
# "root_only": disable reduce-scatter for root's linear only
# "some_mlps": disable reduce-scatter for some MLPs
"mode": ["all", "root_only", "some_mlps"],
"reshard_after_backward": [False, True],
"offload_policy": [OffloadPolicy(), CPUOffloadPolicy()],
# For HSDP only:
# `True`: reduce-scatter only (no all-reduce) each microbatch
# until the last microbatch
# `False`: neither reduce-scatter nor all-reduce each
# microbatch until the last microbatch
"reduce_scatter_only": [False, True],
},
self._test_gradient_accumulation,
)
def _test_gradient_accumulation(
self,
mesh: DeviceMesh,
reshard_after_forward: Union[bool, int],
mode: str,
reshard_after_backward: bool,
offload_policy: OffloadPolicy,
reduce_scatter_only: bool, # for HSDP
):
if (
(
not reshard_after_backward
and (reshard_after_forward is not False or mode == "some_mlps")
)
or (
isinstance(offload_policy, CPUOffloadPolicy)
and reshard_after_forward is not True
)
or (
mesh.ndim != 2
) # may eventually need to change once decision on device mesh is made
):
return # skip since not common or applicable
torch.manual_seed(42)
batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
if mode == "some_mlps":
num_mlps_to_disable_reduce_scatter = 2
modules = [nn.Linear(lin_dim, lin_dim)]
modules.extend(MLP(lin_dim) for _ in range(num_mlps))
model = nn.Sequential(*modules)
ref_model = copy.deepcopy(model).to(device_type)
replicate_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
offload_policy=offload_policy,
)
for mlp in model[1:]:
replicate_fn(mlp)
replicate_fn(model) # root gets the 1st linear
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
def set_grad_sync_flag(
module: nn.Module, is_last_microbatch: bool, recurse: bool = True
):
if reduce_scatter_only:
module.set_requires_all_reduce(is_last_microbatch, recurse=recurse)
else:
module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse)
def set_backward_flags(_model: nn.Module, is_last_microbatch: bool):
if mode == "all":
set_grad_sync_flag(_model, is_last_microbatch)
if not reshard_after_backward:
_model.set_reshard_after_backward(is_last_microbatch)
elif mode == "some_mlps":
for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
set_grad_sync_flag(mlp, is_last_microbatch)
if not reshard_after_backward:
mlp.set_reshard_after_backward(is_last_microbatch)
elif mode == "root_only":
set_grad_sync_flag(model, is_last_microbatch, recurse=False)
if not reshard_after_backward:
model.set_reshard_after_backward(is_last_microbatch, recurse=False)
torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(5):
comm_count_list = []
for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
set_backward_flags(model, is_last_microbatch)
inp = torch.randn(batch_size, lin_dim, device=device_type.type)
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
with CommDebugMode() as comm_mode:
losses.append(_model(inp).sum())
losses[-1].backward()
comm_count_list.append(comm_mode.get_comm_counts())
self.assertEqual(losses[0], losses[1])
comm_counts = defaultdict(int)
for comm_count_dict in comm_count_list:
for collective, count in comm_count_dict.items():
comm_counts[collective] += count
all_gather_count = comm_counts[c10d_ops._allgather_base_]
# reduce_scatter_count = comm_counts[c10d_ops._reduce_scatter_base_]
all_reduce_count = comm_counts[c10d_ops.allreduce_]
# Expect one reduce-scatter per MLP plus one for the root's linear
# on the last microbatch
# expected_reduce_scatter_count = 0
expected_all_reduce_count = num_mlps + 1
if mode == "some_mlps":
# Expect additional reduce-scatters for non-disabled MLPs and
# the root's linear
expected_all_reduce_count += (
num_mlps - num_mlps_to_disable_reduce_scatter + 1
) * (num_microbatches - 1)
elif mode == "root_only":
# Expect additional reduce-scatters for all MLPs
expected_all_reduce_count += (num_mlps) * (num_microbatches - 1)
# self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
self.assertEqual(all_reduce_count, expected_all_reduce_count)
# Expect one all-gather per MLP plus one for the root's linear in
# the first microbatch's forward
expected_all_gather_count = 0
self.assertEqual(all_gather_count, expected_all_gather_count)
for param in ref_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
check_sharded_parity(self, ref_model, model)
for _optim in (optim, ref_optim):
_optim.step()
# When `set_to_none=False`, we are exercising mixing
# gradient accumulation with and without communication
_optim.zero_grad(set_to_none=(iter_idx % 2))
@skip_if_lt_x_gpu(2)
def test_1f1b_microbatching(self):
self.run_subtests(
{
"use_explicit_unshard": [False, True],
"reshard_after_backward": [False, True],
},
self._test_1f1b_microbatching,
)
def _test_1f1b_microbatching(
self, use_explicit_unshard: bool, reshard_after_backward: bool
):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
replicate(module, reshard_after_forward=False)
replicate(model, reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
num_microbatches = 3
local_batch_size = 2
torch.manual_seed(42 + self.rank + 1)
inps = [
torch.randint(
0,
model_args.vocab_size,
(local_batch_size, 16),
device=device_type.type,
)
for _ in range(num_microbatches)
]
# Before pipelining, we may prefer to issue all all-gathers ahead of
# time to increase overlap opportunity at no difference in parameter
# memory usage since we do not reshard after forward
if use_explicit_unshard:
for module in model.modules():
if isinstance(module, FSDPModule):
module.unshard(async_op=True)
# Emulate the 1f1b pipeline schedule and only reduce gradients on the
# last microbatch
losses: list[torch.Tensor] = []
ref_losses: list[torch.Tensor] = []
for inp_idx, inp in enumerate(inps):
is_last_microbatch = inp_idx == num_microbatches - 1
model.set_requires_gradient_sync(is_last_microbatch)
model.set_is_last_backward(is_last_microbatch)
if not reshard_after_backward:
model.set_reshard_after_backward(is_last_microbatch)
losses.append(model(inp).sum())
losses[-1].backward()
ref_losses.append(ref_model(inp).sum())
ref_losses[-1].backward()
for param in ref_model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
for loss, ref_loss in zip(losses, ref_losses):
self.assertEqual(loss, ref_loss)
optim.step()
ref_optim.step()
check_sharded_parity(self, ref_model, model)
class TestReplicateCustomForwardMethod(FSDPTest):
@property
def world_size(self) -> int:
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
def test_register_fsdp_forward_method(self):
class VisionTransformer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)
def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
return self.patch_proj(imgs).flatten(2).transpose(1, 2)
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
return self.forward_features(imgs).sum(dim=1)
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
# Run `vit.forward_features`, which is not `forward`!
patch_embeddings = self.vit.forward_features(imgs)
return self.projector(patch_embeddings)
torch.manual_seed(42)
model = Model()
ref_model = copy.deepcopy(model).to(device_type)
replicate(model.vit)
replicate(model.projector)
replicate(model)
register_fsdp_forward_method(model.vit, "forward_features")
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn(4, 3, 224, 224, device=device_type.type)
ref_loss = ref_model(inp).sum()
loss = model(inp).sum()
self.assertEqual(ref_loss, loss)
ref_loss.backward()
loss.backward()
for param in ref_model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
check_sharded_parity(self, ref_model, model)
class TestReplicateTPTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.get_device_module(device_type).device_count())
def init_global_mesh(self) -> DeviceMesh:
return init_device_mesh(
device_type.type,
(2, 1, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)
@skip_if_lt_x_gpu(8)
def test_replicate_tp(self):
global_mesh = self.init_global_mesh()
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
},
functools.partial(self._test_replicate_tp, global_mesh),
)
def _test_replicate_tp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
foreach: bool,
):
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
parallelize_plan = {
# Pass `use_local_output=False` to keep as DTensor to preserve
# uneven activation dims
"0.in_proj": ColwiseParallel(use_local_output=False),
"0.out_proj": RowwiseParallel(use_local_output=False),
"1.in_proj": ColwiseParallel(use_local_output=False),
"1.out_proj": RowwiseParallel(use_local_output=False),
"2.in_proj": ColwiseParallel(use_local_output=False),
"2.out_proj": (RowwiseParallel()),
}
model = parallelize_module(model, tp_mesh, parallelize_plan)
for module in model:
if isinstance(module, nn.LayerNorm):
continue
if use_activation_checkpointing:
checkpoint(module)
replicate(module, device_mesh=dp_mesh)
replicate(model, device_mesh=dp_mesh)
# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
# strided-sharded layers.
for ref_p, p in zip(ref_model.parameters(), model.parameters()):
self.assertIsInstance(p, DTensor)
self.assertEqual(ref_p, p.full_tensor())
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
torch.manual_seed(42 + dp_pg.rank() + 1)
device = device_type
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
losses.append(_model(inp).sum())
losses[-1].backward()
for param in ref_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
for _optim in (ref_optim, optim):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
_optim.step()
self.assertEqual(losses[0], losses[1])
check_sharded_parity(self, ref_model, model)
for _, p in model.named_parameters():
self.assertIsInstance(p, DTensor)
self.assertEqual(p.device_mesh.ndim, 3)
self.assertEqual(len(p.placements), 3)
self.assertEqual(
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
)
if __name__ == "__main__":
run_tests()

View File

@ -5,7 +5,6 @@ import unittest
import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
@ -38,18 +37,6 @@ class SimpleModel(torch.nn.Module):
return self.mlp_1(self.mlp_0(input))
class SimpleModelAnnotated(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.mlp_0 = MLPModule(device)
self.mlp_1 = MLPModule(device)
def forward(self, input):
with fx_traceback.annotate({"pp_stage": 0}):
x = self.mlp_0(input)
return self.mlp_1(x)
def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
# needed for strict export
torch.utils._pytree.register_constant(DTensorSpec)
@ -103,7 +90,7 @@ class DTensorExportTest(TestCase):
)
self.device_type = "cuda"
def _run_test(self, export_fn, test_annotation=False):
def _run_test(self, export_fn):
dp_degree = 2
tp_degree = self.world_size // dp_degree
@ -114,11 +101,7 @@ class DTensorExportTest(TestCase):
mesh_dim_names=["dp", "tp"],
)
model = None
if test_annotation:
model = SimpleModelAnnotated(self.device_type)
else:
model = SimpleModel(self.device_type)
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
@ -148,116 +131,6 @@ class DTensorExportTest(TestCase):
1,
)
if test_annotation:
def has_tag(node):
return "custom" in node.meta and node.meta["custom"] == {"pp_stage": 0}
def marked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if has_tag(node) and node.op == "call_function"
]
def unmarked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if not has_tag(node) and node.op == "call_function"
]
marked_nodes_fw = [
"t",
"addmm",
"view",
"relu",
"view_1",
"t_1",
"div",
"addmm_1",
"all_reduce",
"wait_tensor",
"view_2",
"t_12",
]
unmarked_nodes_fw = [
"view_3",
"t_2",
"addmm_2",
"view_4",
"relu_1",
"view_5",
"t_3",
"div_1",
"addmm_3",
"all_reduce_1",
"wait_tensor_1",
"view_6",
"t_4",
"t_8",
]
marked_nodes_bw = [
"mm_4",
"t_13",
"view_1",
"mm_5",
"t_14",
"sum_3",
"view_9",
"t_15",
"detach",
"detach_1",
"detach_6",
"detach_7",
"threshold_backward_1",
"t_16",
"mm_6",
"t_17",
"sum_4",
"view_10",
"t_18",
]
unmarked_nodes_bw = [
"mm",
"t_5",
"view_5",
"mm_1",
"t_6",
"sum_1",
"view_7",
"t_7",
"detach_2",
"detach_3",
"detach_4",
"detach_5",
"threshold_backward",
"mm_2",
"t_9",
"mm_3",
"t_10",
"sum_2",
"view_8",
"t_11",
"all_reduce_2",
"wait_tensor_2",
]
self.assertEqual(marked_nodes(fw_gm), marked_nodes_fw)
self.assertEqual(unmarked_nodes(fw_gm), unmarked_nodes_fw)
self.assertEqual(marked_nodes(bw_gm), marked_nodes_bw)
self.assertEqual(unmarked_nodes(bw_gm), unmarked_nodes_bw)
self.assertEqual(
set(marked_nodes(joint_gm)), set(marked_nodes_fw + marked_nodes_bw)
)
self.assertEqual(
set(unmarked_nodes(joint_gm)),
set(unmarked_nodes_fw + unmarked_nodes_bw),
)
@parametrize(
"export_fn",
[
@ -277,9 +150,6 @@ class DTensorExportTest(TestCase):
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)
def test_annotate_aot_export_joint_with_descriptors_alone(self):
self._run_test(aot_export_joint_with_descriptors_alone, True)
instantiate_parametrized_tests(DTensorExportTest)

View File

@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import contextlib
import itertools
import torch
@ -356,7 +355,7 @@ class RedistributeTest(DTensorTestBase):
replica_spec = Replicate()
# 1) test replicate -> partial forward
replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec])
with self.assertRaisesRegex(RuntimeError, "Can not redistribute"):
with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])
from torch.distributed.tensor._redistribute import Redistribute
@ -620,38 +619,6 @@ class RedistributeTest(DTensorTestBase):
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(out.placements, [Shard(0), dst])
@with_comms
def test_redistribute_to_partial(self):
mesh = init_device_mesh(self.device_type, (2, 2))
tensor = torch.randn(12, 8, device=self.device_type)
test_cases = [
# Partial to Partial is allowed
([Partial(), Shard(0)], [Partial(), Shard(0)], True),
([Partial(), Shard(0)], [Partial(), Shard(1)], True),
([Shard(0), Partial()], [Replicate(), Partial()], True),
([Shard(0), Partial("prod")], [Replicate(), Partial("prod")], True),
# Non-Partial to Partial is NOT allowed
([Shard(0), Replicate()], [Shard(0), Partial()], False),
([Shard(0), Replicate()], [Replicate(), Partial()], False),
([Shard(0), Shard(1)], [Replicate(), Partial()], False),
# Partial to Partial is allowed only if the reduction op is the same
([Shard(0), Partial("prod")], [Replicate(), Partial("sum")], False),
]
for src, dst, allow in test_cases:
dt = DTensor.from_local(tensor, mesh, src)
raise_context = (
self.assertRaisesRegex(RuntimeError, "Can not redistribute")
if not allow
else contextlib.nullcontext()
)
with raise_context:
out = dt.redistribute(mesh, dst)
self.assertEqual(out.placements, dst)
instantiate_parametrized_tests(RedistributeTest)

View File

@ -910,7 +910,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||flatten
4|aten.view.default||
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1

View File

@ -15228,11 +15228,6 @@ def forward(self, x):
test_serdes=True,
)
@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureRetraceability
@testing.expectedFailureStrictV2
@testing.expectedFailureStrict # annotation needs to be handled in dynamo
@testing.expectedFailureSerDer
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):
@ -15251,22 +15246,17 @@ def forward(self, x):
ep = export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.op in ("placeholder", "output"):
continue
if node.target == torch.ops.aten.add.Tensor:
if node.target == torch.ops.aten.add.default:
self.assertTrue(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
elif node.target == torch.ops.aten.sub.Tensor:
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
elif node.target == torch.ops.aten.mul.Tensor:
if node.target == torch.ops.aten.mul.default:
self.assertTrue(
node.meta["custom"],
{"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
)
elif node.target == torch.ops.aten.div.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta["custom"], {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
if node.target == torch.ops.aten.div.default:
self.assertTrue(node.meta["custom"], {})
def test_dynamic_shapes_serdes_generic(self):
from torch._export.serde.dynamic_shapes import (

View File

@ -13,7 +13,6 @@ import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
from torch._decomp import decomposition_table
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.testing import normalize_gm
from torch._functorch._aot_autograd.descriptors import (
BufferAOTInput,
@ -778,34 +777,36 @@ class inner_f(torch.nn.Module):
inputs = (torch.randn(4, 3),)
for with_export in [False]: # TODO: make dynamo work for annotation
for with_export in [True, False]:
with ExitStack() as stack:
model = None
with fx_traceback.preserve_node_meta():
if with_export:
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
model = _dynamo_graph_capture_for_export(model)(*inputs)
ep = torch.export.export(SimpleLinear(), inputs)
model = ep.module()
else:
model = SimpleLinear()
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, decompositions={}
stack, model, inputs, decompositions=decomposition_table
)
for node in joint_with_descriptors.graph_module.graph.nodes:
if node.op in ("placeholder", "output"):
continue
if node.target != torch.ops.aten.sub.Tensor and node.op not in (
"placeholder",
"output",
if (
node.target
in (
torch.ops.prims.transpose.default,
torch.ops.aten.mm.default,
torch.ops.prims.mul.default,
torch.ops.prims.broadcast_in_dim.default,
torch.ops.prims.add.default,
)
# TODO: add annotation to backward graph nodes
and node.meta.get("partitioner_tag") != "is_backward"
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
elif node.target == torch.ops.aten.sub.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta.get("custom", {}), {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta.get("custom", {}), {})
if __name__ == "__main__":

View File

@ -908,13 +908,11 @@ class TestFakeQuantize(TestCase):
self.assertEqual(fq_module.activation_post_process.quant_min, 0)
self.assertEqual(fq_module.activation_post_process.quant_max, 127)
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
sampled_dtype=st.sampled_from(['bf16', 'fp16', 'fp32']))
def test_fused_moving_avg_obs_fake_quant(self, device, sampled_dtype):
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']))
def test_fused_moving_avg_obs_fake_quant(self, device):
try:
if device == 'cpu':
sampled_dtype = 'fp32'
dtype = {'bf16' : torch.bfloat16, 'fp16' : torch.half, 'fp32' : torch.float32}[sampled_dtype]
sampled_dtype = st.sampled_from(["bf16", "fp32"]) if device == "cuda" else "fp32"
dtype = torch.bfloat16 if sampled_dtype == "bf16" else torch.float32
torch.set_default_dtype(dtype)
with torch.device(device):

View File

@ -1065,17 +1065,15 @@ class TestFakeQuantizeOps(TestCase):
class TestFusedObsFakeQuant(TestCase):
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
sampled_dtype=st.sampled_from(['bf16', 'fp16', 'fp32']),
symmetric_quant=st.booleans(), use_bool=st.booleans())
@settings(deadline=None)
def test_fused_obs_fake_quant_moving_avg(self, device, sampled_dtype, symmetric_quant, use_bool) -> None:
def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant, use_bool) -> None:
"""
Tests the case where we call the fused_obs_fake_quant op multiple times
and update the running_min and max of the activation tensors.
"""
if device == "cpu":
sampled_dtype = "fp32"
dtype = {'bf16' : torch.bfloat16, 'fp16' : torch.half, 'fp32' : torch.float32}[sampled_dtype]
sampled_dtype = st.sampled_from(["bf16", "fp32"]) if device == "cuda" else "fp32"
dtype = torch.bfloat16 if sampled_dtype == "bf16" else torch.float32
in_running_min_ref = out_running_min_ref = torch.tensor(float("inf"), dtype=dtype)
in_running_min_op = torch.tensor(float("inf"), dtype=dtype, device=device)

View File

@ -806,60 +806,6 @@ class TestMatmulCuda(InductorTestCase):
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
@onlyCUDA
@parametrize("ops", [("mm", torch.mm), ("bmm", torch.bmm), ("addmm", torch.addmm), ("baddbmm", torch.baddbmm)])
def test_input_dimension_checking_out_dtype(self, ops):
op_name, op = ops
B = 2
M, N, K = 32, 32, 32
def is_addmm():
return "add" in op_name
def is_batched():
return "bmm" in op_name
if is_batched():
a = torch.randn(B, M, K, device="cuda", dtype=torch.bfloat16)
mismatch_k_b = torch.randn(B, K + 1, N, device="cuda", dtype=torch.bfloat16)
c = torch.randn(B, M, N, device="cuda", dtype=torch.bfloat16)
extra_dim_b = a.clone().unsqueeze(0)
mismatch_k_err = "Expected size for first two dimensions of batch2 tensor to be"
extra_dim_err = "batch2 must be a 3D tensor"
else:
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
mismatch_k_b = torch.randn(K + 1, N, device="cuda", dtype=torch.bfloat16)
c = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
extra_dim_b = a.clone().unsqueeze(0)
mismatch_k_err = "mat1 and mat2 shapes cannot be multiplied"
extra_dim_err = "mat2 must be a matrix, got 3-D tensor"
# Test mismatch K
with self.assertRaisesRegex(RuntimeError, mismatch_k_err):
if is_addmm():
op(c, a, mismatch_k_b, out_dtype=torch.float32)
else:
op(a, mismatch_k_b, out_dtype=torch.float32)
# Test extra dimension
with self.assertRaisesRegex(RuntimeError, extra_dim_err):
if is_addmm():
op(c, a, extra_dim_b, out_dtype=torch.float32)
else:
op(c, extra_dim_b, out_dtype=torch.float32)
if is_batched():
with self.assertRaisesRegex(RuntimeError, "Expected size for first two dimensions of batch2 tensor to be"):
# Test mismatch B for bmm/baddbmm
mismatch_batch_dim_b = torch.randn(B + 1, K, N, device="cuda", dtype=torch.bfloat16)
if is_addmm():
op(c, a, mismatch_batch_dim_b, out_dtype=torch.float32)
else:
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"

View File

@ -726,10 +726,9 @@ class OutputGraph(OutputGraphCommon):
self.export_metadata = ExportMetaData()
# Dict of inlined unspecialized modules to generate the
# dynamo_flat_name_to_original_fqn mapping. The code here follows what
# export was doing earlier.
self.inlined_unspecialized_modules: dict[str, torch.nn.Module] = {}
# Set of inlined unspecialized module names to generate the

# dynamo_flat_name_to_original_fqn mapping.
self.used_inlined_inbuilt_modules_names: OrderedSet[str] = OrderedSet()
def mark_bytecode_tracing_start(self) -> None:
self.compiler_trace_stack.enter_context(
@ -2585,7 +2584,7 @@ class OutputGraph(OutputGraphCommon):
# some of the tensor objects to be held alive for longer than necessary.
self.root_tx = None # type: ignore[assignment]
self.nn_modules.clear()
self.inlined_unspecialized_modules.clear()
self.used_inlined_inbuilt_modules_names.clear()
self.param_name_to_source = None
for node in self.graph.nodes:
@ -2619,9 +2618,9 @@ class OutputGraph(OutputGraphCommon):
) -> None:
name = OutputGraph.module_key_name(source.name())
name = get_unique_name_wrt(
name, self.inlined_unspecialized_modules, self.global_scope
name, self.used_inlined_inbuilt_modules_names, self.global_scope
)
self.inlined_unspecialized_modules[name] = inlined_module
self.used_inlined_inbuilt_modules_names.add(name)
def register_leaf_name(leaf_name: str) -> None:
assert self.param_name_to_source is not None
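The renamed container only needs to answer "is this name taken?", so the naming logic can be sketched with a stand-in helper (illustrative, not dynamo's get_unique_name_wrt implementation):

def unique_name_wrt(candidate: str, *namespaces) -> str:
    """Illustrative stand-in: append a numeric suffix until the name
    collides with none of the given namespaces (sets or dicts of names)."""
    def taken(name):
        return any(name in ns for ns in namespaces)
    if not taken(candidate):
        return candidate
    i = 0
    while taken(f"{candidate}_{i}"):
        i += 1
    return f"{candidate}_{i}"

used_inlined_names = {"sub_mod"}         # plays the role of used_inlined_inbuilt_modules_names
global_scope = {"sub_mod_0": object()}   # plays the role of self.global_scope
name = unique_name_wrt("sub_mod", used_inlined_names, global_scope)
used_inlined_names.add(name)             # only the string is kept, not the module object
print(name)  # sub_mod_1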

View File

@ -420,18 +420,14 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# the descendants of graph inputs corresponding to fwd inputs, didn't
# seem obvious at first glance on how to partition graph inputs into
# fwd vs bwd without relying on string names.
return (
node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
)
return "nn_module_stack" in node.meta and "seq_nr" in node.meta
def _is_backward_node_with_seq_nr(node):
# For now, assume that if nn_module_stack_metadata is not populated,
# this node is from the backward. Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this, same
# as with the forward.
return (
node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
)
return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta
fwd_seq_nr_to_node = {}
for node in fx_g.graph.nodes:
@ -451,10 +447,8 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# fwd_node should always exist, but handle non-existence just in case
fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
if fwd_node is not None:
node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
# TODO: better to change to a specific field of custom?
node.meta["custom"] = fwd_node.meta.get("custom")
def register_buffer_assignment_hook(mod, assigned_buffers):
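The forward/backward pairing above is a two-pass walk keyed on seq_nr; a toy sketch with plain dicts standing in for fx.Node.meta (node names and meta values are made up, the real code keys off "partitioner_tag"/"nn_module_stack" as shown in the hunk):

fwd_nodes = [{"name": "mm", "seq_nr": 7,
              "nn_module_stack": ("model.linear",), "custom": {"pp_stage": 0}}]
bwd_nodes = [{"name": "mm_backward", "seq_nr": 7}]

fwd_by_seq = {n["seq_nr"]: n for n in fwd_nodes}
for n in bwd_nodes:
    fwd = fwd_by_seq.get(n["seq_nr"])
    if fwd is not None:  # fwd should always exist, but stay defensive
        n["fwd_nn_module_stack"] = fwd.get("nn_module_stack")
        n["custom"] = fwd.get("custom")

print(bwd_nodes[0]["fwd_nn_module_stack"], bwd_nodes[0]["custom"])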

View File

@ -476,7 +476,7 @@ def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
# u0==u1 assume the same, no broadcasting!
torch._check(
x == y,
lambda: "sizes assumed to be the same due to unbacked broadcasting semantics",
"sizes assumed to be the same due to unbacked broadcasting semantics",
)
return False
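For context, torch._check takes its message as a zero-argument callable so the string is only built when the check actually fails; a minimal example:

import torch

u = torch.tensor(3)
# passes silently; the lambda is never evaluated
torch._check(u.item() == 3, lambda: f"sizes assumed equal, got {u.item()}")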

View File

@ -83,7 +83,7 @@ def _quantize_weight(float_wt, observer):
torch.qint8,
)
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
elif observer.qscheme == torch.per_channel_affine_float_qparams:
elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
qweight = torch.quantize_per_channel(
float_wt,
wt_scale.to(torch.float),
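Several hunks in this compare toggle between x == value and x in [value]; for a single candidate the two spellings agree, and the list form only matters once more qschemes are expected to share the branch:

import torch

qscheme = torch.per_channel_affine_float_qparams
same = (qscheme == torch.per_channel_affine_float_qparams) == (
    qscheme in [torch.per_channel_affine_float_qparams]
)
assert same  # single-candidate membership and equality agree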

View File

@ -64,7 +64,9 @@ def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool:
def _is_float_qparams(qscheme: "torch.qscheme") -> bool:
return qscheme == torch.per_channel_affine_float_qparams
return qscheme in [
torch.per_channel_affine_float_qparams,
]
class FakeQuantizeBase(ABC, Module):

View File

@ -227,7 +227,7 @@ def is_getattr_tensor_metadata_node(node):
return (
node.op == "call_function"
and node.target == getattr
and node.args[1] == "shape"
and node.args[1] in ["shape"]
)

View File

@ -388,7 +388,7 @@ class UniformQuantizationObserverBase(ObserverBase):
)
else:
zero_point = zero_point.new_full(zero_point.size(), 128)
elif self.dtype == torch.uint16:
elif self.dtype in [torch.uint16]:
zero_point = zero_point.new_full(zero_point.size(), 2**15)
elif self.qscheme == torch.per_channel_affine_float_qparams:
scale = (max_val - min_val) / float(quant_max - quant_min)

View File

@ -237,7 +237,7 @@ def _add_observer_(
for name, child in module.named_children():
# TODO remove Dropout special after codebase stable
if type_before_parametrizations(child) is nn.Dropout:
if type_before_parametrizations(child) in [nn.Dropout]:
continue
elif issubclass(
type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)

View File

@ -598,7 +598,7 @@ class X86InductorQuantizer(Quantizer):
_annotate_nodes_not_quantize(linear_node)
return
input_qspec_map = {}
assert linear_node.target == torch.ops.aten.linear.default
assert linear_node.target in (torch.ops.aten.linear.default,)
has_bias = len(linear_node.args) == 3
input_index = 0
weight_index = 1
@ -1436,9 +1436,8 @@ class X86InductorQuantizer(Quantizer):
"Linear partition cannot have more than one output node"
)
linear_node = partition.output_nodes[0]
if (
linear_node.op != "call_function"
or linear_node.target != torch.ops.aten.linear.default
if linear_node.op != "call_function" or linear_node.target not in (
torch.ops.aten.linear.default,
):
raise ValueError(f"{linear_node} is not an aten linear operator")
# skip annotation if it is already annotated
@ -1468,9 +1467,8 @@ class X86InductorQuantizer(Quantizer):
linear_node, unary_node = self._get_output_nodes_of_partitions(
[linear_partition, unary_partition]
)
if (
linear_node.op != "call_function"
or linear_node.target != torch.ops.aten.linear.default
if linear_node.op != "call_function" or linear_node.target not in (
torch.ops.aten.linear.default,
):
continue
if _skip_annotate([unary_node, linear_node], filter_fn):
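A hedged sketch of the linear-node filter the annotator relies on, run over a small exported graph; depending on the torch version the linear may appear as torch.ops.aten.linear.default or already be decomposed, so the scan just reports what it finds:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)
    def forward(self, x):
        return torch.relu(self.fc(x))

ep = torch.export.export(M(), (torch.randn(2, 4),))
for node in ep.graph_module.graph.nodes:
    if node.op != "call_function":
        continue
    is_linear = node.target in (torch.ops.aten.linear.default,)
    print(node.target, "<- linear candidate" if is_linear else "")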

View File

@ -501,9 +501,9 @@ def calculate_qmin_qmax(
quant_min, quant_max = 0, 255
elif dtype in [torch.qint32, torch.int32]:
quant_min, quant_max = -1 * (2**31), (2**31) - 1
elif dtype == torch.uint16:
elif dtype in [torch.uint16]:
quant_min, quant_max = 0, 2**16 - 1
elif dtype == torch.int16:
elif dtype in [torch.int16]:
quant_min, quant_max = -(2**15), 2**15 - 1
else:
quant_min, quant_max = 0, 15
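The branches above condense into a small lookup table; this is an illustrative restatement, not the library's implementation:

import torch

QRANGES = {
    torch.uint8:  (0, 255),
    torch.int32:  (-(2**31), 2**31 - 1),
    torch.uint16: (0, 2**16 - 1),
    torch.int16:  (-(2**15), 2**15 - 1),
}

def qrange(dtype, default=(0, 15)):
    # unknown dtypes fall through to the 4-bit default branch above
    return QRANGES.get(dtype, default)

assert qrange(torch.int16) == (-32768, 32767)
assert qrange(torch.float32) == (0, 15)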

View File

@ -624,7 +624,7 @@ void broadcast(
")");
ncclComm_t comm = comms[i];
NCCL_CHECK(ncclBcast(
tensors[i].mutable_data_ptr(),
tensors[i].data_ptr(),
numel,
data_type,
0,
@ -669,9 +669,9 @@ void reduce(
ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclReduce(
inputs[i].const_data_ptr(),
inputs[i].data_ptr(),
static_cast<std::remove_cv_t<decltype(i)>>(root) == i
? output.mutable_data_ptr()
? output.data_ptr()
: nullptr,
count,
data_type,
@ -723,8 +723,8 @@ void all_reduce(
ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclAllReduce(
inputs[i].const_data_ptr(),
outputs[i].mutable_data_ptr(),
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
@ -765,8 +765,8 @@ void reduce_scatter(
ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclReduceScatter(
inputs[i].const_data_ptr(),
outputs[i].mutable_data_ptr(),
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
@ -807,18 +807,18 @@ void all_gather(
ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
NCCL_CHECK(ncclAllGather(
inputs[i].const_data_ptr(),
outputs[i].mutable_data_ptr(),
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_comm(comm),
stream));
#else
NCCL_CHECK(ncclAllGather(
inputs[i].const_data_ptr(),
inputs[i].data_ptr(),
count,
data_type,
outputs[i].mutable_data_ptr(),
outputs[i].data_ptr(),
to_nccl_comm(comm),
stream));
#endif
@ -843,7 +843,7 @@ void all2all_single_equal_split(
size_t count = input.numel() / size;
[[maybe_unused]] size_t rankdiff = input.nbytes() / size;
const auto* sendbuff = reinterpret_cast<const char*>(input.const_data_ptr());
auto* recvbuff = reinterpret_cast<char*>(output.mutable_data_ptr());
auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM) || defined(NCCL_ALLTOALL_SUPPORTED)
// NCCL_ALLTOALL_SUPPORTED is used so NCCL can differentiate send/recv
@ -964,7 +964,7 @@ void all2all(
if (_nccl_should_send_recv(input.numel())) {
NCCL_CHECK(ncclSend(
input.const_data_ptr(),
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
r,
@ -973,7 +973,7 @@ void all2all(
}
if (_nccl_should_send_recv(output.numel())) {
NCCL_CHECK(ncclRecv(
output.mutable_data_ptr(),
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
r,
@ -1005,7 +1005,7 @@ void send(
using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclSend(
input.const_data_ptr(),
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
@ -1014,7 +1014,7 @@ void send(
#else
NCCL_CHECK_TIMEOUT(
ncclSend(
input.const_data_ptr(),
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
@ -1041,7 +1041,7 @@ void recv(
using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclRecv(
output.mutable_data_ptr(),
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
@ -1050,7 +1050,7 @@ void recv(
#else
NCCL_CHECK_TIMEOUT(
ncclRecv(
output.mutable_data_ptr(),
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
@ -1091,7 +1091,7 @@ void gather(
if (cur_rank == root) {
for (const auto r : c10::irange(numranks)) {
if (r != root) {
auto* recvbuff = reinterpret_cast<char*>(outputs[r].mutable_data_ptr());
auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
} else {
// on its own rank, simply copy from the input
@ -1152,7 +1152,7 @@ void scatter(
} else {
size_t recv_count = outputs.numel();
auto recv_type = to_nccl_data_type(outputs);
auto* recvbuff = reinterpret_cast<char*>(outputs.mutable_data_ptr());
auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
}
#ifndef NCCL_HAS_COMM_NONBLOCKING

View File

@ -420,7 +420,7 @@ class _CudaKernel:
# navi, CDNA1-CDNA3 allows a max of 64KB shared memory
# CDNA4 allows a max of 160KB shared memory
max_shared_mem = (
65536 if device_props.gcnArchName != "gfx950" else 160 * 1024
65536 if device_props.gcnArchName not in ["gfx950"] else 160 * 1024
)
else:
max_shared_mem = getattr(
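The arch-name branch can be exercised without a GPU; a small stand-in helper (name is ad-hoc) mirroring the 64 KB / 160 KB split described in the comment above:

def rocm_max_shared_mem(gcn_arch_name: str) -> int:
    # CDNA4 (gfx950) allows 160 KB of shared memory; earlier navi/CDNA parts cap at 64 KB.
    return 160 * 1024 if gcn_arch_name == "gfx950" else 65536

assert rocm_max_shared_mem("gfx950") == 160 * 1024
assert rocm_max_shared_mem("gfx90a") == 65536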

View File

@ -171,7 +171,3 @@ def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
):
return x
return x.to(dtype)
def is_bw() -> bool:
return torch._C._current_graph_task_id() != -1
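The removed is_bw() helper keys off the private torch._C._current_graph_task_id(); a quick sketch showing it flips to True inside a backward hook (assuming that private API is present in your build):

import torch

def is_bw() -> bool:
    # -1 means no autograd graph task is currently executing on this thread
    return torch._C._current_graph_task_id() != -1

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

print(is_bw())          # False: ordinary forward/eager code
flags = []
y.register_hook(lambda grad: flags.append(is_bw()))
y.backward()
print(flags)            # [True]: the hook runs inside the backward graph task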

View File

@ -31,7 +31,6 @@ from ._fsdp_common import (
compiled_autograd_enabled,
FSDPMeshInfo,
HSDPMeshInfo,
is_bw,
TrainingState,
)
from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState
@ -273,7 +272,7 @@ class FSDPParamGroup:
the staging buffers for collective comms.
"""
assert isinstance(
self._all_gather_comm, (DefaultAllGather | ProcessGroupAllocAllGather)
self._all_gather_comm, (DefaultAllGather, ProcessGroupAllocAllGather)
), (
"cannot call set_allocate_memory_from_process_group() "
f"when all gather comm is custom: {self._all_gather_comm.__class__.__name__}"
@ -286,7 +285,7 @@ class FSDPParamGroup:
assert isinstance(
self._reduce_scatter_comm,
(DefaultReduceScatter | ProcessGroupAllocReduceScatter),
(DefaultReduceScatter, ProcessGroupAllocReduceScatter),
), (
"cannot call set_allocate_memory_from_process_group() "
f"when reduce scatter comm is custom: {self._reduce_scatter_comm.__class__.__name__}"
@ -446,15 +445,8 @@ class FSDPParamGroup:
if not compiled_autograd_enabled():
logger.debug("%s", self._with_fqn("FSDP::post_forward"))
with record_function(self._with_fqn("FSDP::post_forward")):
if not compiled_autograd_enabled():
# for AC(fully_shard(model)), AC runs fsdp's _pre_forward
# it shouldn't change post_forward_order
if not is_bw():
self.reshard()
self._record_post_forward()
else:
self.reshard()
self._record_post_forward()
self.reshard()
self._record_post_forward()
self._training_state = TrainingState.IDLE
return output
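On the isinstance changes above: with Python 3.10+ both the tuple form and the PEP 604 union form are accepted, so the toggle is stylistic; the class names below are reused from the hunk purely as local placeholders:

class DefaultAllGather: ...
class ProcessGroupAllocAllGather: ...

comm = DefaultAllGather()
assert isinstance(comm, (DefaultAllGather, ProcessGroupAllocAllGather))  # tuple form
assert isinstance(comm, DefaultAllGather | ProcessGroupAllocAllGather)   # union form, 3.10+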

View File

@ -528,10 +528,9 @@ class DTensor(torch.Tensor):
placements = list(placements)
for i, placement in enumerate(placements):
if placement.is_partial() and self.placements[i] != placement:
if placement.is_partial():
raise RuntimeError(
f"Can not redistribute from {self.placements[i]} to {placement}, "
"redistributing to Partial is for internal use only!"
"Can not redistribute to Partial, redistributing to Partial is for internal use only!"
)
elif isinstance(placement, Shard) and placement.dim < 0:
# normalize shard dim to be positive

View File

@ -34,7 +34,7 @@ class LayerNorm(nn.Module):
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format != torch.contiguous_format:
if self.data_format not in [torch.contiguous_format]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)

View File

@ -241,11 +241,8 @@ def _check_input_constraints_pre_hook(self, args, kwargs):
_check_inputs_match(args, kwargs, self._in_spec)
return
# NOTE: for some reason, Dynamo is tracing into this, we should see why and
# put compile at the right place. Until then, we can skip the input
# constraint checks.
if not torch.compiler.is_dynamo_compiling():
_check_input_constraints_for_module(self, args, kwargs)
# NOTE: this call is Dynamo disabled, as it used to be
_check_input_constraints_for_module(self, args, kwargs)
def _unlift_inputs_as_getattr(
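A minimal sketch of the guard pattern being toggled above, using the public torch.compiler.is_dynamo_compiling(); the check function here is a stand-in, not the export hook:

import torch

def expensive_input_check(args):
    # stand-in for _check_input_constraints_for_module
    assert all(isinstance(a, torch.Tensor) for a in args)

def pre_hook(args):
    if not torch.compiler.is_dynamo_compiling():
        # runs in eager mode; skipped while Dynamo is tracing
        expensive_input_check(args)
    return args

pre_hook((torch.randn(2),))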

View File

@ -65,76 +65,3 @@ def get_node_context(node, num_nodes=2) -> str:
break
cur = cur.prev
return "\n".join(node_contexts[::-1])
def map_recorded_events_to_aten_ops_with_stack_trace(graph_module, traced_data):
"""
Maps recorded profiler events to their corresponding aten operations and adds stack traces.
Args:
graph_module: The FX GraphModule
traced_data: Json of profiler events from Chrome trace
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
trace_events = traced_data.get("traceEvents", [])
# Create a mapping from node name to node for easy lookup
node_map = {node.name: node for node in graph_module.graph.nodes}
# Find aten operation events
aten_events = [e for e in trace_events if e.get("cat") == "cpu_op"]
# Map recorded events to aten ops and add stack traces
event_mapping = {}
for recorded_event in trace_events:
if (recorded_event.get("cat") in ["cpu_op"] and
recorded_event.get("name", "").startswith("## ") and
recorded_event.get("name", "").endswith(" ##")):
# Extract node name from "## node_name ##"
node_name = recorded_event["name"][3:-3] # Remove "## " and " ##"
if node_name in node_map:
node = node_map[node_name]
# Find corresponding aten operations within this recorded event's time window
recorded_start = recorded_event["ts"]
recorded_end = recorded_start + recorded_event["dur"]
# Find aten ops that fall within this time window
corresponding_aten_ops = []
for aten_event in aten_events:
aten_start = aten_event["ts"]
aten_end = aten_start + aten_event["dur"]
# Check if aten event overlaps with recorded event
if (aten_start >= recorded_start and aten_start <= recorded_end) or \
(aten_end >= recorded_start and aten_end <= recorded_end) or \
(aten_start <= recorded_start and aten_end >= recorded_end):
corresponding_aten_ops.append(aten_event)
# Add stack trace to recorded event and aten ops
stack_trace = node.meta.get("stack_trace", "No stack trace available")
# Add stack trace to the recorded event
if "args" not in recorded_event:
recorded_event["args"] = {}
recorded_event["args"]["stack_trace"] = stack_trace
# Add stack trace to corresponding aten ops
for aten_op in corresponding_aten_ops:
if "args" not in aten_op:
aten_op["args"] = {}
aten_op["args"]["stack_trace"] = stack_trace
event_mapping[node_name] = {
"recorded_event": recorded_event,
"aten_operations": corresponding_aten_ops,
"node": node,
"stack_trace": stack_trace
}
return event_mapping
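The three-way time-window test in the removed helper is the standard closed-interval overlap check; a compact equivalent:

def overlaps(a_start, a_end, b_start, b_end):
    # two closed intervals overlap iff each one starts before the other ends
    return a_start <= b_end and b_start <= a_end

assert overlaps(10, 15, 5, 20)       # aten event fully inside the "## node ##" window
assert overlaps(3, 7, 5, 20)         # aten event straddling the start of the window
assert not overlaps(30, 40, 5, 20)   # disjoint events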

View File

@ -440,7 +440,6 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -778,13 +777,8 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in ("call_function", "call_method", "call_module")
if do_record:
body.append(f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {node.name} ##'); _rf_{node.name}.__enter__()\n")
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -1215,9 +1209,6 @@ class Graph:
name = self._graph_namespace.create_name(candidate, None)
n = Node(self, name, op, target, args, kwargs, type_expr)
# print(name)
# breakpoint()
if (
self.owning_module is not None
and getattr(self.owning_module, "_create_node_hooks", None) is not None
@ -1642,7 +1633,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1710,7 +1700,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@ -1723,7 +1712,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1734,7 +1722,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:
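The removed record_func path wrapped each emitted node in a torch._C._profiler._RecordFunctionFast region; the same effect can be sketched with the public torch.profiler.record_function API (the node name below is made up):

import torch
from torch.profiler import profile, record_function

x = torch.randn(8)
with profile() as prof:
    # wrap a "node" the same way the removed codegen did, but via the public API
    with record_function("## relu_node ##"):
        torch.relu(x + 1)

print([evt.name for evt in prof.events() if evt.name.startswith("##")])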

View File

@ -161,7 +161,6 @@ class Interpreter:
delay=0,
)
print("running inside interpreter")
for node in self.graph.nodes:
pbar.update(1)
if node in self.env:

View File

@ -95,27 +95,6 @@ bits4x2
bits8
bits16
# torch/headeronly/core/DeviceType.h
DeviceType
kCPU
kCUDA
kHIP
kFPGA
kMAIA
kXLA
kMPS
kMeta
kVulkan
kMetal
kXPU
kHPU
kVE
kLazy
kIPU
kMTIA
kPrivateUse1
COMPILE_TIME_MAX_DEVICE_TYPES
# torch/headeronly/core/ScalarType.h
NumScalarTypes
ScalarType

View File

@ -1,125 +0,0 @@
// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
// If you modify me, keep me synchronized with that file.
#include <torch/headeronly/macros/Export.h>
#include <cstddef>
#include <cstdint>
#include <functional>
namespace c10 {
// These contains all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends except PrivateUse2 and PrivateUse3
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
_(CPU, extra) \
_(CUDA, extra) \
_(HIP, extra) \
_(XLA, extra) \
_(MPS, extra) \
_(IPU, extra) \
_(XPU, extra) \
_(HPU, extra) \
_(VE, extra) \
_(Lazy, extra) \
_(Meta, extra) \
_(MTIA, extra) \
_(PrivateUse1, extra)
enum class DeviceType : int8_t {
CPU = 0,
CUDA = 1, // CUDA.
MKLDNN = 2, // Reserved for explicit MKLDNN
OPENGL = 3, // OpenGL
OPENCL = 4, // OpenCL
IDEEP = 5, // IDEEP.
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
MAIA = 8, // ONNX Runtime / Microsoft
XLA = 9, // XLA / TPU
Vulkan = 10, // Vulkan
Metal = 11, // Metal
XPU = 12, // XPU
MPS = 13, // MPS
Meta = 14, // Meta (tensors with no data)
HPU = 15, // HPU / HABANA
VE = 16, // SX-Aurora / NEC
Lazy = 17, // Lazy Tensors
IPU = 18, // Graphcore IPU
MTIA = 19, // Meta training and inference devices
PrivateUse1 = 20, // PrivateUse1 device
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in c10/core/DeviceType.cpp
// - Change the number below
COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMAIA = DeviceType::MAIA;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMTIA = DeviceType::MTIA;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
static_assert(
COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
"for this constant to reflect the actual number of DeviceTypes we support "
"in PyTorch; it's important that this number is not too large as we "
"use this to allocate stack arrays in some places in our code. If you "
"are indeed just adding the 20th device type, feel free to change "
"the check to 32; but if you are adding some sort of extensible device "
"types registration, please be aware that you are affecting code that "
"this number is small. Try auditing uses of this constant.");
} // namespace c10
namespace std {
template <>
struct hash<c10::DeviceType> {
std::size_t operator()(c10::DeviceType k) const {
return std::hash<int>()(static_cast<int>(k));
}
};
} // namespace std
namespace torch::headeronly {
using c10::COMPILE_TIME_MAX_DEVICE_TYPES;
using c10::DeviceType;
using c10::kCPU;
using c10::kCUDA;
using c10::kFPGA;
using c10::kHIP;
using c10::kHPU;
using c10::kIPU;
using c10::kLazy;
using c10::kMAIA;
using c10::kMeta;
using c10::kMetal;
using c10::kMPS;
using c10::kMTIA;
using c10::kPrivateUse1;
using c10::kVE;
using c10::kVulkan;
using c10::kXLA;
using c10::kXPU;
} // namespace torch::headeronly

View File

@ -427,7 +427,7 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
return torch.tensor(-torch.inf, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8:
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
elif op_name == "logsumexp":
elif op_name in {"logsumexp"}:
if torch.is_floating_point(input):
return torch.tensor(-torch.inf, dtype=dtype, device=device)
elif torch.is_complex(input):

View File

@ -3272,13 +3272,11 @@ def gaussian_nll_loss(
if input.size()[:-1] == var.size():
var = torch.unsqueeze(var, -1)
# This checks if the var is broadcastable to the input and there is only one mismatched dimension.
# This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1.
# This is also a homoscedastic case.
# e.g. input.size = (10, 2, 3), var.size = (10, 2, 1)
# or input.size = (4, 3, 32, 32), var.size = (4, 1, 32, 32)
elif (
input.ndim == var.ndim
and sum(y for x, y in zip(input.size(), var.size()) if x != y) == 1
input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1
): # Heteroscedastic case
pass
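The two heteroscedastic shape predicates appearing in this hunk accept different var shapes; a quick side-by-side (function names are ad-hoc):

import torch

def strict_last_dim(inp, var):
    # sizes match up to the final dimension and var's final dimension is 1
    return inp.size()[:-1] == var.size()[:-1] and var.size(-1) == 1

def single_mismatch(inp, var):
    # same ndim and exactly one mismatched dimension, of size 1 in var
    return inp.ndim == var.ndim and sum(
        y for x, y in zip(inp.size(), var.size()) if x != y
    ) == 1

inp = torch.randn(10, 2, 3)
print(strict_last_dim(inp, torch.randn(10, 2, 1)),
      single_mismatch(inp, torch.randn(10, 2, 1)))    # True True
img = torch.randn(4, 3, 32, 32)
print(strict_last_dim(img, torch.randn(4, 1, 32, 32)),
      single_mismatch(img, torch.randn(4, 1, 32, 32)))  # False True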

View File

@ -432,18 +432,15 @@ def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
('reduction_none', {'reduction': 'none'}),
('homoscedastic', {'homoscedastic': True}),
]
module_inputs = []
for desc, constructor_kwargs in cases:
homoscedastic = constructor_kwargs.pop('homoscedastic', False)
var_input = make_input(1, 3).abs() if homoscedastic else make_input(4, 1).abs()
module_inputs.append(
ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
forward_input=FunctionInput(make_input(4, 3),
make_target(4, 3),
var_input),
forward_input=FunctionInput(make_input(3),
make_target(3),
make_input(1).abs()),
desc=desc,
reference_fn=no_batch_dim_reference_fn)
)

View File

@ -765,7 +765,7 @@ class QuantizationTestCase(TestCase):
and not isinstance(module, _FusedModule)
):
for child in module.children():
if type(child) is nn.Dropout:
if type(child) in [nn.Dropout]:
continue
self.checkObservers(
child, propagate_qconfig_list, prepare_custom_config_dict

View File

@ -192,7 +192,7 @@ class Trainer:
self.hybrid_module = HybridModel(
self.remote_em_rref,
self.remote_net_rref,
self.trainer_group if ddp_mode == DdpMode.INSIDE else None,
self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
)
self.ddp_params, self.non_ddp_params = (
self.hybrid_module.ddp_params,

View File

@ -707,7 +707,7 @@ class DistributedTest:
self.assertNotEqual(args.get("dtype", ""), "")
per_coll_meta[collname].append(args)
if collname == "wait":
if collname in {"wait"}:
continue
self.assertEqual(args["Process Group Description"], "default_pg")
@ -7029,7 +7029,7 @@ class DistributedTest:
self.assertNotEqual(attrs.get("dtype", ""), "")
per_coll_meta[collname].append(attrs)
if collname == "wait":
if collname in {"wait"}:
continue
self.assertEqual(attrs["pg_name"], "0") # yes this is a string

View File

@ -125,7 +125,7 @@ class DebugMode(TorchDispatchMode):
_get_current_dispatch_mode(), FakeTensorMode
):
if self.record_faketensor:
if func != torch.ops.prim.device.default:
if func not in {torch.ops.prim.device.default}:
self.operators.append((func, args, kwargs, self.call_depth + 1))
elif len(types) == 0:
if self.record_realtensor:

View File

@ -103,7 +103,7 @@ class Capture:
def __getattr__(self, attrname):
if attrname == "kwarg" or attrname == "kwargs":
raise RuntimeError("no kwargs!")
if attrname == "__deepcopy__":
if attrname in ["__deepcopy__"]:
raise AttributeError
result = CaptureGetAttr(self, attrname, ctx=self.ctx)
return result
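The __deepcopy__ special case above exists because copy.deepcopy probes the attribute via getattr, and raising AttributeError restores the default copy path. A hedged toy version, generalised to all dunders so the sketch stays robust across Python versions:

import copy

class Capture:
    def __init__(self, ctx=None):
        self.ctx = ctx

    def __getattr__(self, attrname):
        if attrname in ("kwarg", "kwargs"):
            raise RuntimeError("no kwargs!")
        if attrname.startswith("__") and attrname.endswith("__"):
            # AttributeError for protocol lookups (e.g. __deepcopy__, __getstate__)
            # lets copy.deepcopy fall back to its default behaviour.
            raise AttributeError(attrname)
        return ("captured_getattr", attrname)  # toy stand-in for CaptureGetAttr

c = Capture()
print(copy.deepcopy(c) is not c)   # True: the default deepcopy path is used
print(c.some_field)                # ('captured_getattr', 'some_field')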

View File

@ -783,7 +783,7 @@ class _FlopCounterMode(TorchDispatchMode):
return result, flop_counts
def _handle_higher_order_ops(self, func, types, args, kwargs):
if func is not torch.ops.higher_order.cond:
if func not in {torch.ops.higher_order.cond, }:
return NotImplemented
# The flop counter for cond counts the upper bound of flops.