Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-26 16:44:54 +08:00)

Compare commits: codegen_tr...export-rel (1 commit, SHA1 3e3fd97b6a)

.github/ci_commit_pins/vllm.txt (vendored, 2 lines changed)

@@ -1 +1 @@
-78a47f87ce259a48f0391fa9ae15add05ea7432b
+0307428d65acf5cf1a73a70a7722e076bbb83f22
.github/scripts/generate_ci_workflows.py (vendored, 93 lines changed)

@@ -127,6 +127,53 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
    ),
]


ROCM_SMOKE_WORKFLOWS = [
    BinaryBuildWorkflow(
        os=OperatingSystem.LINUX,
        package_type="manywheel",
        build_variant="rocm",
        build_configs=generate_binary_build_matrix.generate_wheels_matrix(
            OperatingSystem.LINUX,
            arches=["6.4"],
            python_versions=["3.10"],
        ),
        ciflow_config=CIFlowConfig(
            labels={
                LABEL_CIFLOW_BINARIES,
                LABEL_CIFLOW_BINARIES_WHEEL,
                LABEL_CIFLOW_ROCM,
            },
            isolated_workflow=True,
        ),
        branches="main",
    ),
]

LINUX_BINARY_SMOKE_WORKFLOWS = [
    BinaryBuildWorkflow(
        os=OperatingSystem.LINUX,
        package_type="manywheel",
        build_configs=generate_binary_build_matrix.generate_wheels_matrix(
            OperatingSystem.LINUX,
            arches=["13.0"],
            python_versions=["3.12"],
        ),
        branches="main",
    ),
    BinaryBuildWorkflow(
        os=OperatingSystem.LINUX,
        package_type="libtorch",
        build_variant=generate_binary_build_matrix.RELEASE,
        build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
            OperatingSystem.LINUX,
            generate_binary_build_matrix.RELEASE,
            arches=["cpu"],
            libtorch_variants=["shared-with-deps"],
        ),
        branches="main",
    ),
]

WINDOWS_BINARY_BUILD_WORKFLOWS = [
    BinaryBuildWorkflow(
        os=OperatingSystem.WINDOWS,

@@ -212,6 +259,39 @@ WINDOWS_BINARY_BUILD_WORKFLOWS = [
    ),
]

WINDOWS_BINARY_SMOKE_WORKFLOWS = [
    BinaryBuildWorkflow(
        os=OperatingSystem.WINDOWS,
        package_type="libtorch",
        build_variant=generate_binary_build_matrix.RELEASE,
        build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
            OperatingSystem.WINDOWS,
            generate_binary_build_matrix.RELEASE,
            arches=["cpu"],
            libtorch_variants=["shared-with-deps"],
        ),
        branches="main",
        ciflow_config=CIFlowConfig(
            isolated_workflow=True,
        ),
    ),
    BinaryBuildWorkflow(
        os=OperatingSystem.WINDOWS,
        package_type="libtorch",
        build_variant=generate_binary_build_matrix.DEBUG,
        build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
            OperatingSystem.WINDOWS,
            generate_binary_build_matrix.DEBUG,
            arches=["cpu"],
            libtorch_variants=["shared-with-deps"],
        ),
        branches="main",
        ciflow_config=CIFlowConfig(
            isolated_workflow=True,
        ),
    ),
]

MACOS_BINARY_BUILD_WORKFLOWS = [
    BinaryBuildWorkflow(
        os=OperatingSystem.MACOS_ARM64,

@@ -292,10 +372,23 @@ def main() -> None:
            jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
            S390X_BINARY_BUILD_WORKFLOWS,
        ),
        (
            # Give rocm it's own workflow file
            jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
            ROCM_SMOKE_WORKFLOWS,
        ),
        (
            jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
            LINUX_BINARY_SMOKE_WORKFLOWS,
        ),
        (
            jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
            WINDOWS_BINARY_BUILD_WORKFLOWS,
        ),
        (
            jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
            WINDOWS_BINARY_SMOKE_WORKFLOWS,
        ),
        (
            jinja_env.get_template("macos_binary_build_workflow.yml.j2"),
            MACOS_BINARY_BUILD_WORKFLOWS,
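For orientation: the script pairs each workflow list above with a Jinja2 template and emits one generated-*.yml file per BinaryBuildWorkflow. The sketch below is illustrative only; WorkflowStub and render_workflows are stand-ins invented for this example, not the actual generator code.

from dataclasses import asdict, dataclass
from pathlib import Path

import jinja2


@dataclass
class WorkflowStub:  # simplified stand-in for BinaryBuildWorkflow (illustrative)
    os: str
    package_type: str
    build_variant: str = ""
    branches: str = "nightly"

    @property
    def build_environment(self) -> str:
        # e.g. "linux-binary-manywheel-rocm", matching the generated file names
        return "-".join(p for p in (self.os, "binary", self.package_type, self.build_variant) if p)


def render_workflows(template_path: Path, workflows: list, out_dir: Path) -> None:
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(str(template_path.parent)),
        undefined=jinja2.StrictUndefined,
    )
    template = env.get_template(template_path.name)
    for wf in workflows:
        # Each workflow object becomes the template context for one YAML file.
        out = out_dir / f"generated-{wf.build_environment}-main.yml"
        out.write_text(template.render(workflow=asdict(wf)))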
.github/workflows/generated-linux-binary-libtorch-release-main.yml (generated, vendored, new file, 87 lines)

@@ -0,0 +1,87 @@
# @generated DO NOT EDIT MANUALLY

# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-libtorch-release


on:
  push:
    branches:
      - main
    tags:
      - 'ciflow/trunk/*'
  workflow_dispatch:

permissions:
  id-token: write

env:
  # Needed for conda builds
  ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
  AWS_DEFAULT_REGION: us-east-1
  BINARY_ENV_FILE: /tmp/env
  BUILD_ENVIRONMENT: linux-binary-libtorch-release
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  PR_NUMBER: ${{ github.event.pull_request.number }}
  PYTORCH_FINAL_PACKAGE_DIR: /artifacts
  PYTORCH_ROOT: /pytorch
  SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
  SKIP_ALL_TESTS: 0
concurrency:
  group: linux-binary-libtorch-release-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

jobs:
  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
  libtorch-cpu-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    uses: ./.github/workflows/_binary-build-linux.yml
    needs: get-label-type
    with:
      PYTORCH_ROOT: /pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      DOCKER_IMAGE: libtorch-cxx11-builder
      DOCKER_IMAGE_TAG_PREFIX: cpu
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: libtorch-cpu-shared-with-deps-release
      build_environment: linux-binary-libtorch-release
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  libtorch-cpu-shared-with-deps-release-test: # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - libtorch-cpu-shared-with-deps-release-build
      - get-label-type
    uses: ./.github/workflows/_binary-test-linux.yml
    with:
      PYTORCH_ROOT: /pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      DOCKER_IMAGE: libtorch-cxx11-builder
      DOCKER_IMAGE_TAG_PREFIX: cpu
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      build_name: libtorch-cpu-shared-with-deps-release
      build_environment: linux-binary-libtorch-release
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runs_on: linux.4xlarge
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/generated-linux-binary-manywheel-main.yml (generated, vendored, new file, 88 lines)

@@ -0,0 +1,88 @@
# @generated DO NOT EDIT MANUALLY

# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-manywheel


on:
  push:
    branches:
      - main
    tags:
      - 'ciflow/trunk/*'
  workflow_dispatch:

permissions:
  id-token: write

env:
  # Needed for conda builds
  ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
  AWS_DEFAULT_REGION: us-east-1
  BINARY_ENV_FILE: /tmp/env
  BUILD_ENVIRONMENT: linux-binary-manywheel
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  PR_NUMBER: ${{ github.event.pull_request.number }}
  PYTORCH_FINAL_PACKAGE_DIR: /artifacts
  PYTORCH_ROOT: /pytorch
  SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
  SKIP_ALL_TESTS: 0
concurrency:
  group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

jobs:
  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
  manywheel-py3_12-cuda13_0-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    uses: ./.github/workflows/_binary-build-linux.yml
    needs: get-label-type
    with:
      PYTORCH_ROOT: /pytorch
      PACKAGE_TYPE: manywheel
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu130
      GPU_ARCH_VERSION: "13.0"
      GPU_ARCH_TYPE: cuda
      DOCKER_IMAGE: manylinux2_28-builder
      DOCKER_IMAGE_TAG_PREFIX: cuda13.0
      DESIRED_PYTHON: "3.12"
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_12-cuda13_0
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_12-cuda13_0-test: # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - manywheel-py3_12-cuda13_0-build
      - get-label-type
    uses: ./.github/workflows/_binary-test-linux.yml
    with:
      PYTORCH_ROOT: /pytorch
      PACKAGE_TYPE: manywheel
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu130
      GPU_ARCH_VERSION: "13.0"
      GPU_ARCH_TYPE: cuda
      DOCKER_IMAGE: manylinux2_28-builder
      DOCKER_IMAGE_TAG_PREFIX: cuda13.0
      DESIRED_PYTHON: "3.12"
      build_name: manywheel-py3_12-cuda13_0
      build_environment: linux-binary-manywheel
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
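Note the PYTORCH_EXTRA_INSTALL_REQUIREMENTS value above packs many PEP 508 requirement strings into a single pipe-separated scalar. The snippet below is an illustrative sketch of how a consumer could split it back into individual requirements; the helper name is hypothetical.

def split_extra_install_requirements(value: str) -> list[str]:
    # "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-..." -> list of requirements
    return [req.strip() for req in value.split("|") if req.strip()]


reqs = split_extra_install_requirements(
    "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
    "nvidia-nccl-cu13==2.28.3; platform_system == 'Linux'"
)
assert reqs[1].startswith("nvidia-nccl-cu13==")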
.github/workflows/generated-linux-binary-manywheel-rocm-main.yml (generated, vendored, new file, 136 lines)

@@ -0,0 +1,136 @@
# @generated DO NOT EDIT MANUALLY

# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-binary-manywheel-rocm


on:
  push:
    branches:
      - main
    tags:
      - 'ciflow/binaries/*'
      - 'ciflow/binaries_wheel/*'
      - 'ciflow/rocm/*'
  workflow_dispatch:

permissions:
  id-token: write

env:
  # Needed for conda builds
  ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
  AWS_DEFAULT_REGION: us-east-1
  BINARY_ENV_FILE: /tmp/env
  BUILD_ENVIRONMENT: linux-binary-manywheel-rocm
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  PR_NUMBER: ${{ github.event.pull_request.number }}
  PYTORCH_FINAL_PACKAGE_DIR: /artifacts
  PYTORCH_ROOT: /pytorch
  SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
  SKIP_ALL_TESTS: 0
concurrency:
  group: linux-binary-manywheel-rocm-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

jobs:
  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
  manywheel-py3_10-rocm6_4-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    uses: ./.github/workflows/_binary-build-linux.yml
    needs: get-label-type
    with:
      PYTORCH_ROOT: /pytorch
      PACKAGE_TYPE: manywheel
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: rocm6.4
      GPU_ARCH_VERSION: "6.4"
      GPU_ARCH_TYPE: rocm
      DOCKER_IMAGE: manylinux2_28-builder
      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
      DESIRED_PYTHON: "3.10"
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      timeout-minutes: 300
      build_name: manywheel-py3_10-rocm6_4
      build_environment: linux-binary-manywheel-rocm
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_10-rocm6_4-test: # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - manywheel-py3_10-rocm6_4-build
      - get-label-type
    runs-on: linux.rocm.gpu.mi250
    timeout-minutes: 240
    env:
      PYTORCH_ROOT: /pytorch
      PACKAGE_TYPE: manywheel
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: rocm6.4
      GPU_ARCH_VERSION: "6.4"
      GPU_ARCH_TYPE: rocm
      SKIP_ALL_TESTS: 1
      DOCKER_IMAGE: manylinux2_28-builder
      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
      DESIRED_PYTHON: "3.10"
    steps:
      - name: Setup ROCm
        uses: ./.github/actions/setup-rocm
      - uses: actions/download-artifact@v4.1.7
        name: Download Build Artifacts
        with:
          name: manywheel-py3_10-rocm6_4
          path: "${{ runner.temp }}/artifacts/"
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: ROCm set GPU_FLAG
        run: |
          echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
      - name: configure aws credentials
        id: aws_creds
        if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
        uses: aws-actions/configure-aws-credentials@v4
        with:
          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
          aws-region: us-east-1
          role-duration-seconds: 18000
      - name: Calculate docker image
        id: calculate-docker-image
        uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
        with:
          docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
          docker-image-name: manylinux2_28-builder
          custom-tag-prefix: rocm6.4
          docker-build-dir: .ci/docker
          working-directory: pytorch
      - name: Pull Docker image
        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
        with:
          docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
      - name: Test Pytorch binary
        uses: ./pytorch/.github/actions/test-pytorch-binary
        env:
          DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
      - name: Teardown ROCm
        uses: ./.github/actions/teardown-rocm
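Because the files above are marked "@generated DO NOT EDIT MANUALLY", a natural safeguard is a check that re-runs the generator and fails if the checked-in copies drift. Whether CI runs exactly this is an assumption; the sketch below only uses the script path shown above plus standard git commands.

import subprocess
import sys

# Regenerate the workflow files in place, then fail if anything under
# .github/workflows changed relative to the checked-in copies.
subprocess.run([sys.executable, ".github/scripts/generate_ci_workflows.py"], check=True)
stale = subprocess.run(["git", "diff", "--exit-code", "--", ".github/workflows"], check=False)
if stale.returncode != 0:
    sys.exit("Generated workflow files are stale; commit the regenerated output.")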
.github/workflows/generated-windows-binary-libtorch-debug-main.yml (generated, vendored, new file, 261 lines)

@@ -0,0 +1,261 @@
# @generated DO NOT EDIT MANUALLY

# Template is at: .github/templates/windows_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: windows-binary-libtorch-debug

on:
  push:
    branches:
      - main
  workflow_dispatch:

env:
  # Needed for conda builds
  ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
  AWS_DEFAULT_REGION: us-east-1
  BUILD_ENVIRONMENT: windows-binary-libtorch-debug
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  PR_NUMBER: ${{ github.event.pull_request.number }}
  SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
  SKIP_ALL_TESTS: 1
  OS: windows
concurrency:
  group: windows-binary-libtorch-debug-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

jobs:
  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
  libtorch-cpu-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: debug
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Build PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
      - uses: actions/upload-artifact@v4.4.0
        if: always()
        with:
          name: libtorch-cpu-shared-with-deps-debug
          retention-days: 14
          if-no-files-found: error
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1

  libtorch-cpu-shared-with-deps-debug-test: # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - libtorch-cpu-shared-with-deps-debug-build
      - get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: debug
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - uses: actions/download-artifact@v4.1.7
        name: Download Build Artifacts
        with:
          name: libtorch-cpu-shared-with-deps-debug
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Test PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1
.github/workflows/generated-windows-binary-libtorch-release-main.yml (generated, vendored, new file, 261 lines)

@@ -0,0 +1,261 @@
# @generated DO NOT EDIT MANUALLY

# Template is at: .github/templates/windows_binary_build_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: windows-binary-libtorch-release

on:
  push:
    branches:
      - main
  workflow_dispatch:

env:
  # Needed for conda builds
  ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
  AWS_DEFAULT_REGION: us-east-1
  BUILD_ENVIRONMENT: windows-binary-libtorch-release
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  PR_NUMBER: ${{ github.event.pull_request.number }}
  SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
  SKIP_ALL_TESTS: 1
  OS: windows
concurrency:
  group: windows-binary-libtorch-release-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

jobs:
  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
  libtorch-cpu-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Build PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
      - uses: actions/upload-artifact@v4.4.0
        if: always()
        with:
          name: libtorch-cpu-shared-with-deps-release
          retention-days: 14
          if-no-files-found: error
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1

  libtorch-cpu-shared-with-deps-release-test: # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - libtorch-cpu-shared-with-deps-release-build
      - get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - uses: actions/download-artifact@v4.1.7
        name: Download Build Artifacts
        with:
          name: libtorch-cpu-shared-with-deps-release
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Test PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1
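A quick sanity check one could run on any of the generated workflow files above is that they parse as YAML and expose the expected top-level keys. Sketch only; the file path is one of the generated files shown above, and note PyYAML's quirk of parsing a bare "on:" key as boolean True.

import yaml

with open(".github/workflows/generated-windows-binary-libtorch-release-main.yml") as f:
    wf = yaml.safe_load(f)

# PyYAML loads the unquoted `on:` trigger key as the boolean True, so accept either spelling.
assert "jobs" in wf and ("on" in wf or True in wf)
assert "libtorch-cpu-shared-with-deps-release-build" in wf["jobs"]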
@@ -1831,37 +1831,6 @@ std::optional<c10::ScalarType> out_dtype) {
  return out;
}

static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
  // ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
  TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");

  const auto batch1_sizes = batch1.sizes();
  const auto batch2_sizes = batch2.sizes();

  int64_t bs = batch1_sizes[0];
  int64_t contraction_size = batch1_sizes[2];
  int64_t res_rows = batch1_sizes[1];
  int64_t res_cols = batch2_sizes[2];
  std::vector<int64_t> output_size {bs, res_rows, res_cols};

  TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size,
              "Expected size for first two dimensions of batch2 tensor to be: [",
              bs, ", ", contraction_size, "] but got: [", batch2_sizes[0], ", ", batch2_sizes[1], "].");

  TORCH_CHECK(batch1.scalar_type() == batch2.scalar_type(), "batch1 and batch2 must have the same dtype");

  TORCH_CHECK(out_dtype == batch1.scalar_type() ||
              (out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
              "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");

  if (!is_bmm && self_baddbmm.has_value()) {
    const auto& self = self_baddbmm.value();
    TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
    TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output");
  }
}

Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
  IntArrayRef batch1_sizes = batch1.sizes();
  IntArrayRef batch2_sizes = batch2.sizes();

@@ -1871,7 +1840,12 @@ Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::Sca
}

Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
  baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true);
  TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");

  TORCH_CHECK(out_dtype == batch1.scalar_type() ||
              (out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
              "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");

  Scalar beta(0.0);
  Scalar alpha(1.0);
  {

@@ -1890,7 +1864,12 @@ Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tenso
}

Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
  baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self);
  TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");

  TORCH_CHECK(out_dtype == batch1.scalar_type() ||
              (out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
              "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");

  {
    NoNamesGuard guard;
    baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);

@@ -1905,12 +1884,6 @@ Tensor _mm_dtype_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarTy
}

Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarType out_dtype, Tensor &out) {
  TORCH_CHECK(self.dim() == 2, "self must be a matrix, got ", self.dim(), "-D tensor");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
  TORCH_CHECK(
      self.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      self.sizes()[0], "x", self.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
  TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "input dtypes must be the same");
  TORCH_CHECK(out_dtype == self.scalar_type() ||

@@ -1930,14 +1903,6 @@ Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& m
}

Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
  TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type());
  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
  TORCH_CHECK(out_dtype == self.scalar_type() ||
              (out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),
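The repeated TORCH_CHECKs above all enforce one rule for the out_dtype overloads: the requested output dtype must equal the input dtype, or be float32 when the inputs are float16/bfloat16. The following is a Python restatement of that rule for illustration only, not a PyTorch API.

import torch


def out_dtype_is_valid(input_dtype: torch.dtype, out_dtype: torch.dtype) -> bool:
    # Mirrors the C++ check: same dtype, or fp32 output for fp16/bf16 inputs.
    if out_dtype == input_dtype:
        return True
    return out_dtype == torch.float32 and input_dtype in (torch.float16, torch.bfloat16)


assert out_dtype_is_valid(torch.bfloat16, torch.float32)
assert not out_dtype_is_valid(torch.float32, torch.bfloat16)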
@@ -101,9 +101,6 @@ __device__ inline bool isinf_device(float v) {
__device__ inline bool isinf_device(c10::BFloat16 v) {
  return ::isinf(static_cast<float>(v));
}
__device__ inline bool isinf_device(at::Half v) {
  return ::isinf(static_cast<float>(v));
}

// CUDA kernel to compute Moving Average Min/Max of the tensor.
// It uses the running_min and running_max along with averaging const, c.

@@ -163,8 +160,8 @@ void _calculate_moving_average(
    std::tie(x_min, x_max) = at::aminmax(x, 1);
    int num_threads = std::min(size, (int64_t)512);
    const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
    AT_DISPATCH_FLOATING_TYPES_AND(
        at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
          scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
          scalar_t* x_max_data = x_max.data_ptr<scalar_t>();

@@ -184,8 +181,8 @@ void _calculate_moving_average(
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    std::tie(x_min, x_max) = at::aminmax(x);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
    AT_DISPATCH_FLOATING_TYPES_AND(
        at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
          scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
          scalar_t* x_max_data = x_max.data_ptr<scalar_t>();

@@ -224,8 +221,8 @@ void _calc_moving_avg_qparams_helper(
  cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
  int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>();
  if (per_row_fq) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
    AT_DISPATCH_FLOATING_TYPES_AND(
        at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
          scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
          scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
          int num_threads = std::min(size, (int64_t)512);

@@ -247,8 +244,8 @@ void _calc_moving_avg_qparams_helper(
        });
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
    AT_DISPATCH_FLOATING_TYPES_AND(
        at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
          scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
          scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
          ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
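For orientation on the paired dispatch macros in the hunks above: AT_DISPATCH_FLOATING_TYPES always covers float32 and float64, and the _AND/_AND2 variants add one or two extra scalar types, so the two forms shown differ only in whether float16 is dispatched. The Python summary below is illustrative only and makes no claim about which side of the diff is old or new.

import torch

FLOATING = {torch.float32, torch.float64}
with_bf16_and_half = FLOATING | {torch.bfloat16, torch.float16}  # ..._AND2(kBFloat16, kHalf, ...)
with_bf16_only = FLOATING | {torch.bfloat16}                     # ..._AND(kBFloat16, ...)

# The only dtype that distinguishes the two dispatch forms is float16.
assert with_bf16_and_half - with_bf16_only == {torch.float16}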
@@ -1,16 +1,100 @@
#pragma once

// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
// If you modify me, keep me synchronized with that file.

#include <c10/macros/Export.h>

// If you modified DeviceType in caffe2/proto/caffe2.proto, please also sync
// your changes into torch/headeronly/core/DeviceType.h.
#include <torch/headeronly/core/DeviceType.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <string>

namespace c10 {

// These contains all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends except PrivateUse2 and PrivateUse3
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
  _(CPU, extra)                                   \
  _(CUDA, extra)                                  \
  _(HIP, extra)                                   \
  _(XLA, extra)                                   \
  _(MPS, extra)                                   \
  _(IPU, extra)                                   \
  _(XPU, extra)                                   \
  _(HPU, extra)                                   \
  _(VE, extra)                                    \
  _(Lazy, extra)                                  \
  _(Meta, extra)                                  \
  _(MTIA, extra)                                  \
  _(PrivateUse1, extra)

enum class DeviceType : int8_t {
  CPU = 0,
  CUDA = 1, // CUDA.
  MKLDNN = 2, // Reserved for explicit MKLDNN
  OPENGL = 3, // OpenGL
  OPENCL = 4, // OpenCL
  IDEEP = 5, // IDEEP.
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA
  MAIA = 8, // ONNX Runtime / Microsoft
  XLA = 9, // XLA / TPU
  Vulkan = 10, // Vulkan
  Metal = 11, // Metal
  XPU = 12, // XPU
  MPS = 13, // MPS
  Meta = 14, // Meta (tensors with no data)
  HPU = 15, // HPU / HABANA
  VE = 16, // SX-Aurora / NEC
  Lazy = 17, // Lazy Tensors
  IPU = 18, // Graphcore IPU
  MTIA = 19, // Meta training and inference devices
  PrivateUse1 = 20, // PrivateUse1 device
  // NB: If you add more devices:
  //  - Change the implementations of DeviceTypeName and isValidDeviceType
  //    in DeviceType.cpp
  //  - Change the number below
  COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};

constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMAIA = DeviceType::MAIA;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMTIA = DeviceType::MTIA;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;

// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
    static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);

static_assert(
    COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
    "Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
    "for this constant to reflect the actual number of DeviceTypes we support "
    "in PyTorch; it's important that this number is not too large as we "
    "use this to allocate stack arrays in some places in our code. If you "
    "are indeed just adding the 20th device type, feel free to change "
    "the check to 32; but if you are adding some sort of extensible device "
    "types registration, please be aware that you are affecting code that "
    "assumes this number is small. Try auditing uses of this constant.");

C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);

C10_API bool isValidDeviceType(DeviceType d);

@@ -24,6 +108,15 @@ C10_API bool is_privateuse1_backend_registered();

} // namespace c10

namespace std {
template <>
struct hash<c10::DeviceType> {
  std::size_t operator()(c10::DeviceType k) const {
    return std::hash<int>()(static_cast<int>(k));
  }
};
} // namespace std

namespace torch {
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::DeviceType;
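The std::hash<c10::DeviceType> specialization above exists so DeviceType can key standard hashed containers in C++. A loose Python-side analogue, shown for orientation only, is that torch.device objects are already hashable and can key a dict in the same way.

import torch

# torch.device is hashable, so per-device bookkeeping can use it directly as a key,
# much like DeviceType keying a std::unordered_map on the C++ side.
per_device_buffers = {torch.device("cpu"): [], torch.device("meta"): []}
per_device_buffers[torch.device("cpu")].append("some buffer")
assert len(per_device_buffers[torch.device("cpu")]) == 1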
@ -1183,16 +1183,6 @@ class DeviceCachingAllocator {
  // ends.
  ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;

  // Incremental reverse-traversal state cached per graph.
  // We never re-traverse nodes we've already seen
  struct GraphReuseContext {
    ska::flat_hash_map<cudaStream_t, ska::flat_hash_set<cudaGraphNode_t>>
        visited;
  };
  ska::flat_hash_map<MempoolId_t, CaptureId_t, MempoolIdHash>
      mempool_to_capture_id;
  ska::flat_hash_map<CaptureId_t, GraphReuseContext> graph_reuse_context;

  // outstanding cuda events
  ska::flat_hash_map<
      cuda::CUDAStream,
@ -1648,70 +1638,44 @@ class DeviceCachingAllocator {
    return block;
  }

  struct CaptureInfo {
    cudaGraph_t graph{};
    CaptureId_t capture_id{0};
    const cudaGraphNode_t* terminals{nullptr};
    size_t num_terminals{0};
    cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone};
  };

  inline CaptureInfo stream_get_capture_info(cudaStream_t stream) {
    CaptureInfo info{};
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
        stream,
        &info.status,
        &info.capture_id,
        &info.graph,
        &info.terminals,
        nullptr,
        &info.num_terminals));
#else
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
        stream,
        &info.status,
        &info.capture_id,
        &info.graph,
        &info.terminals,
        &info.num_terminals));
#endif
    TORCH_INTERNAL_ASSERT(
        info.status != cudaStreamCaptureStatusInvalidated,
        "Invalid stream capture status");

    return info;
  }

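Illustrative note (not from the diff): stream_get_capture_info wraps cudaStreamGetCaptureInfo; from Python, the closest user-visible check is torch.cuda.is_current_stream_capturing(). A minimal sketch, assuming a CUDA-capable device:

    import torch

    assert torch.cuda.is_available()
    g = torch.cuda.CUDAGraph()
    x = torch.zeros(8, device="cuda")

    print(torch.cuda.is_current_stream_capturing())  # False: no capture yet
    with torch.cuda.graph(g):
        # Inside torch.cuda.graph the current stream is in capture mode.
        print(torch.cuda.is_current_stream_capturing())  # True
        x += 1
    g.replay()  # re-runs the captured increment
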
  // Record "free marker" of the CUDA graph for all streams that
  // Insert "free marker" (empty nodes) into the CUDA graph for all streams that
  // have used the block, including the allocation stream. These nodes mark the
  // last use of the block in the capture graph. Returns a vector of the
  // inserted nodes, or an empty vector if any stream is not capturing.
  std::vector<cudaGraphNode_t> record_free_markers(Block* block) {
    // It is possible to have the same marker recorded multiple times, so we use
    // a set to avoid duplicates
    ska::flat_hash_set<cudaGraphNode_t> markers;
    cudaGraph_t owning_graph = nullptr;
  std::vector<cudaGraphNode_t> insert_free_marker(Block* block) {
    std::vector<cudaGraphNode_t> empty_nodes;

    auto try_record = [&](cudaStream_t s) -> bool {
      auto info = stream_get_capture_info(s);
      if (info.status == cudaStreamCaptureStatusNone) {
        return false; // not capturing on this stream -> must defer
      }
    auto try_add_empty_node = [&](cudaStream_t stream) -> bool {
      cudaStreamCaptureStatus status{};
      cudaGraph_t graph{};
      const cudaGraphNode_t* deps = nullptr;
      size_t num_deps = 0;
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
      C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
          stream, &status, nullptr, &graph, &deps, nullptr, &num_deps));
#else
      C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
          stream, &status, nullptr, &graph, &deps, &num_deps));
#endif

      if (owning_graph == nullptr) {
        owning_graph = info.graph;
      }
      TORCH_INTERNAL_ASSERT(
          info.graph == owning_graph,
          "All streams in the same capture should agree on the graph");
          status != cudaStreamCaptureStatusInvalidated,
          "Invalid stream capture status");

      // Use current terminals as the free markers for the stream
      for (size_t i = 0; i < info.num_terminals; ++i) {
        auto terminal = info.terminals[i];
        markers.insert(terminal);
      if (status == cudaStreamCaptureStatusNone) {
        return false;
      }
      owning_graph = info.graph; // all streams in the same capture should agree

      cudaGraphNode_t node{};
      C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps));
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
      C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
          stream, &node, nullptr, 1, cudaStreamSetCaptureDependencies));
#else
      C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
          stream, &node, 1, cudaStreamSetCaptureDependencies));
#endif
      empty_nodes.push_back(node);
      return true;
    };

@ -1719,34 +1683,81 @@ class DeviceCachingAllocator {
    // An empty vector indicates that the block should be deferred for freeing
    // until after capture.

    // Allocation stream
    if (!try_record(block->stream)) {
    // Attempt to add an empty node for the allocation stream.
    if (!try_add_empty_node(block->stream)) {
      return {};
    }
    // Any extra streams that used this block
    // Attempt to add empty nodes for all streams that have used the block.
    for (const auto& s : block->stream_uses) {
      if (!try_record(s.stream())) {
      if (!try_add_empty_node(s.stream())) {
        return {};
      }
    }
    return std::vector<cudaGraphNode_t>(markers.begin(), markers.end());
    return empty_nodes;
  }

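Illustrative note (not from the diff): the free markers above exist for blocks that are freed mid-capture after being used on more than one stream. A minimal Python sketch of that situation, assuming a CUDA device:

    import torch

    side = torch.cuda.Stream()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        tmp = torch.ones(1 << 20, device="cuda")
        side.wait_stream(torch.cuda.current_stream())  # fork the side stream into the capture
        with torch.cuda.stream(side):
            tmp.mul_(2)                                # the block is now used on a second stream
        torch.cuda.current_stream().wait_stream(side)  # join back before capture ends
        tmp.record_stream(side)                        # tell the allocator about that extra use
        del tmp                                        # freed while capture is still running
    g.replay()
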
  // Returns the set of "reusable" free markers in the current
  // Returns the current set of "terminal" nodes in the CUDA graph for a given
  // stream. These represent the current endpoints of the stream, and may
  // include additional nodes if the graph branches. Any new work captured will
  // be attached after one or more of these terminals.
  std::vector<cudaGraphNode_t> get_terminals(cudaStream_t stream) {
    std::vector<cudaGraphNode_t> result;

    cudaStreamCaptureStatus status{};
    cudaGraph_t graph{};
    const cudaGraphNode_t* dependencies = nullptr;
    size_t num_dependencies = 0;

#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
        stream,
        &status,
        nullptr,
        &graph,
        &dependencies,
        nullptr,
        &num_dependencies));
#else
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
        stream, &status, nullptr, &graph, &dependencies, &num_dependencies));
#endif

    TORCH_INTERNAL_ASSERT(
        status == cudaStreamCaptureStatusActive,
        "Invalid stream capture status");

    for (size_t i = 0; i < num_dependencies; i++) {
      auto node = dependencies[i];
      if (node != nullptr) {
        result.push_back(node);
      }
    }

    return result;
  }

  // Returns the set of "reusable" free markers (empty nodes) in the current
  // CUDA graph capture. A free marker is considered reusable if it is a
  // predecessor of every terminal node.
  // This ensures that all future captured work will occur after the free
  // marker, making it safe to reuse.
  void update_visited(
      const CaptureInfo& info,
      ska::flat_hash_set<cudaGraphNode_t>& visited) {
    // This is the versioned cudaGraphNodeGetDependencies helper function.
    auto node_get_dependencies =
        [](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void {
  ska::flat_hash_set<cudaGraphNode_t> get_reusable_empty_nodes(
      cudaStream_t stream) {
    auto terminals = get_terminals(stream);
    if (terminals.empty()) {
      // No terminal nodes found; nothing to free.
      return {};
    }

    auto get_dependencies = [](cudaGraphNode_t node,
                               cudaGraphNode_t* pDependencies,
                               size_t* pNumDependencies) -> void {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
      C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count));
      C10_CUDA_CHECK(cudaGraphNodeGetDependencies(
          node, pDependencies, nullptr, pNumDependencies));
#else
      C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count));
      C10_CUDA_CHECK(
          cudaGraphNodeGetDependencies(node, pDependencies, pNumDependencies));
#endif
    };

@ -1754,43 +1765,62 @@ class DeviceCachingAllocator {
    auto get_parents =
        [&](cudaGraphNode_t node) -> std::vector<cudaGraphNode_t> {
      size_t count = 0;

      node_get_dependencies(node, nullptr, &count);
      get_dependencies(node, nullptr, &count);
      std::vector<cudaGraphNode_t> out(count);
      if (count) {
        node_get_dependencies(node, out.data(), &count);
        get_dependencies(node, out.data(), &count);
        out.resize(count);
      }
      return out;
    };

    // For each terminal node, perform a reverse DFS to count, for each free
    // marker, how many terminals it can reach (i.e., for how many terminals it
    // is a predecessor). A free marker is reusable if it is a predecessor of
    // all terminal nodes.
    std::deque<cudaGraphNode_t> dfs;
    for (size_t i = 0; i < info.num_terminals; ++i) {
      dfs.push_back(info.terminals[i]);
    // Helper to determine if a node is an empty node (used as a free marker).
    auto is_empty_node = [](cudaGraphNode_t n) -> bool {
      cudaGraphNodeType type{};
      C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type));
      return type == cudaGraphNodeTypeEmpty;
    };

    // For each terminal node, perform a reverse DFS to count, for each empty
    // node, how many terminals it can reach (i.e., for how many terminals it is
    // a predecessor). An empty node is reusable if it is a predecessor of all
    // terminal nodes.
    ska::flat_hash_map<cudaGraphNode_t, size_t> num_terminals_reachable;

    for (auto terminal : terminals) {
      ska::flat_hash_set<cudaGraphNode_t> visited;
      ska::flat_hash_set<cudaGraphNode_t> empty_nodes;

      std::function<void(cudaGraphNode_t)> reverse_dfs =
          [&](cudaGraphNode_t node) {
            if (!visited.insert(node).second)
              return;

            if (is_empty_node(node)) {
              num_terminals_reachable[node]++;
              empty_nodes.insert(node);
            }
            auto parents = get_parents(node);
            for (auto p : parents) {
              reverse_dfs(p);
            }
          };

      reverse_dfs(terminal);
    }

    while (!dfs.empty()) {
      auto v = dfs.back();
      dfs.pop_back();

      if (visited.count(v)) {
        continue;
      }
      visited.insert(v);

      auto parents = get_parents(v);
      for (auto p : parents) {
        dfs.push_back(p);
    ska::flat_hash_set<cudaGraphNode_t> reusable_empty_nodes;
    for (auto [node, count] : num_terminals_reachable) {
      if (count == terminals.size()) {
        reusable_empty_nodes.insert(node);
      }
    }

    return reusable_empty_nodes;
  }

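To make the reusability rule above concrete, here is a small illustrative sketch (not from the diff) over a toy dependency graph with made-up node names: a marker is reusable only if every terminal reaches it when walking dependency edges backwards.

    # node -> list of nodes it depends on (its parents in the capture DAG)
    parents = {
        "marker": [],
        "a": ["marker"],
        "b": ["marker"],
        "t1": ["a"],
        "t2": ["b"],
        "t3": [],  # a terminal that does NOT depend on the marker
    }
    terminals = ["t1", "t2", "t3"]

    def terminals_reaching(marker):
        hits = 0
        for t in terminals:
            stack, seen = [t], set()
            while stack:  # reverse DFS from the terminal
                n = stack.pop()
                if n in seen:
                    continue
                seen.add(n)
                stack.extend(parents[n])
            hits += marker in seen
        return hits

    # Reusable only when the marker is a predecessor of every terminal.
    print(terminals_reaching("marker") == len(terminals))  # False here, because of "t3"
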
  // A block is considered reusable during CUDA graph capture if every free
  // marker associated with the block is a predecessor of every
  // marker (empty node) associated with the block is a predecessor of every
  // terminal node.
  //
  // This ensures that any new operation added to the graph will be attached
@ -1799,52 +1829,36 @@ class DeviceCachingAllocator {
  // on every stream, so the block's previous lifetime ends before any new
  // lifetime begins. This check relies solely on the DAG topology and does not
  // require event queries, making it safe to use during capture.
  //
  // This function iterates over all deferred blocks, determines if their empty
  // nodes are reusable according to the above criteria, and frees the block if
  // so.
  void free_safe_blocks_in_capture(
      const std::shared_ptr<GatheredContext>& context,
      cudaStream_t stream) {
    auto info = stream_get_capture_info(stream);
    auto reusable_empty_nodes = get_reusable_empty_nodes(stream);

    // If there are no reusable empty nodes (e.g., not currently capturing),
    // there is nothing to do.
    if (info.status == cudaStreamCaptureStatusNone || info.num_terminals == 0) {
    if (reusable_empty_nodes.empty()) {
      return;
    }
    if (graph_reuse_context.find(info.capture_id) ==
        graph_reuse_context.end()) {
      bool found = false;
      for (auto& entry : captures_underway) {
        if (entry.second(stream)) {
          auto graph_pool = graph_pools.find(entry.first);
          TORCH_INTERNAL_ASSERT(
              graph_pool != graph_pools.end(),
              "Could not find graph pool for capture.");
          auto mempool_id = graph_pool->first;
          graph_reuse_context[info.capture_id] = GraphReuseContext{};
          mempool_to_capture_id[mempool_id] = info.capture_id;
          found = true;
          break;
        }
      }
      TORCH_INTERNAL_ASSERT(
          found, "Could not find memory pool id for capture.");
    }
    auto& graph_context = graph_reuse_context[info.capture_id];
    auto& visited = graph_context.visited[stream];
    update_visited(info, visited);

    std::vector<Block*> blocks_to_erase;
    for (auto& [block, markers] : deferred_blocks) {
      // Skip this block if it has no markers, as we defer its freeing until

    for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
      // Skip this block if it has no empty nodes, as we defer its freeing until
      // after graph capture. Also skip if the block was not allocated on the
      // current stream; such blocks will be freed when
      // free_safe_blocks_in_capture is attempted on that stream.
      if (markers.empty() || block->stream != stream) {
      if (inserted_empty_nodes.empty() || block->stream != stream) {
        continue;
      }

      bool is_reusable = true;
      for (auto m : markers) {
        if (!visited.count(m)) {

      for (const auto& node : inserted_empty_nodes) {
        if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
          is_reusable = false;
          break;
        }
@ -1905,11 +1919,11 @@ class DeviceCachingAllocator {
    if (!block->stream_uses.empty()) {
      if (C10_UNLIKELY(!captures_underway.empty())) {
        if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
          // record_free_markers returns a vector of free markers,
          // insert_free_marker returns a vector of free markers,
          // or an empty vector if any associated stream is not currently
          // capturing. The empty vector means that we will defer the free until
          // capture is finished.
          deferred_blocks.emplace(block, record_free_markers(block));
          deferred_blocks.emplace(block, insert_free_marker(block));
        } else {
          // If graph_capture_record_stream_reuse is not enabled, always defer
          // the free until capture is finished.
@ -2497,21 +2511,6 @@ class DeviceCachingAllocator {
  // Called by CUDAGraph::capture_end
  void endAllocateToPool(MempoolId_t mempool_id) {
    std::lock_guard<std::recursive_mutex> lock(mutex);

    if (CUDAAllocatorConfig::graph_capture_record_stream_reuse() &&
        !graph_reuse_context.empty()) {
      auto capture_id = mempool_to_capture_id[mempool_id];
      auto graph_context = graph_reuse_context[capture_id];
      for (auto& [stream, _] : graph_context.visited) {
        TORCH_INTERNAL_ASSERT(
            stream_get_capture_info(stream).status ==
                cudaStreamCaptureStatusNone,
            "This stream should not be capturing when the capture is ended");
      }
      graph_reuse_context.erase(capture_id);
      mempool_to_capture_id.erase(mempool_id);
    }

    for (auto it = captures_underway.begin(); it != captures_underway.end();
         ++it) {
      if (it->first == mempool_id) {

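Illustrative note (not from the diff): endAllocateToPool runs when a capture that allocated from a private pool ends, for example when two graphs are captured against a shared pool. A minimal sketch, assuming a CUDA device:

    import torch

    pool = torch.cuda.graph_pool_handle()
    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

    with torch.cuda.graph(g1, pool=pool):
        a = torch.randn(256, device="cuda") * 2
    with torch.cuda.graph(g2, pool=pool):  # second capture reuses g1's memory pool
        b = torch.randn(256, device="cuda") + 1

    g1.replay()
    g2.replay()
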
@ -339,16 +339,13 @@ XLA
~~~

- Jack Cao (`JackCaoG <https://github.com/JackCaoG>`__)
- Han Qi (`qihqi <https://github.com/qihqi>`__)
- Yifei Teng (`tengyifei <https://github.com/tengyifei>`__)
- Siyuan Liu (`lsy323 <https://github.com/lsy323>`__)
- Daniel Sohn (`jysohn23 <https://github.com/jysohn23>`__)
- Zach Cain (`zcain117 <https://github.com/zcain117>`__)
- Brian Hirsh (`bdhirsh <https://github.com/bdhirsh>`__)
- (emeritus) Gregory Chanan (`gchanan <https://github.com/gchanan>`__)
- Gregory Chanan (`gchanan <https://github.com/gchanan>`__)
- (emeritus) Ailing Zhang (`ailzhang <https://github.com/ailzhang>`__)
- (emeritus) Davide Libenzi (`dlibenzi <https://github.com/dlibenzi>`__)
- (emeritus) Alex Suhan (`asuhan <https://github.com/asuhan>`__)
- (emeritus) Daniel Sohn (`jysohn23 <https://github.com/jysohn23>`__)
- (emeritus) Zach Cain (`zcain117 <https://github.com/zcain117>`__)

TorchServe
~~~~~~~~~~

@ -613,7 +613,8 @@ Available options:
  CUDA Graph capture by using the graph topology (instead of CUDA events) to determine
  when a freed block is safe to reuse. This can reduce peak memory during long captures that free
  and reallocate buffers across multiple streams, especially when the capture DAG frequently
  reaches joined frontiers.
  reaches joined frontiers. Note: Enabling this option can significantly increase the time spent
  capturing the graph.

.. note::

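Illustrative note (not from the diff): assuming this option is surfaced through PYTORCH_CUDA_ALLOC_CONF under the same name as the C++ accessor shown earlier (graph_capture_record_stream_reuse), enabling it would look roughly like the sketch below; the exact key and value format should be checked against the released documentation.

    import os

    # Assumed option name/format; set before CUDA is initialized.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "graph_capture_record_stream_reuse:True"

    import torch  # imported after setting the env var on purpose

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        x = torch.randn(1024, device="cuda")
        y = x * 2
        del x  # a candidate for topology-based reuse during capture
    g.replay()
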
@ -4,7 +4,6 @@ set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
set(AOTI_ABI_CHECK_TEST_SRCS
  ${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
@ -28,7 +27,7 @@ add_executable(test_aoti_abi_check
target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)

# WARNING: DO NOT LINK torch!!!
# The purpose is to check if the used aten/c10 headers are written in a header-only way
# The purpose is to check if the used aten/c10 headers are writtern in a header-only way
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})

@ -1,35 +0,0 @@
#include <gtest/gtest.h>

#include <torch/headeronly/core/DeviceType.h>

TEST(TestDeviceType, TestDeviceType) {
  using torch::headeronly::DeviceType;
  constexpr DeviceType expected_device_types[] = {
      torch::headeronly::kCPU,
      torch::headeronly::kCUDA,
      DeviceType::MKLDNN,
      DeviceType::OPENGL,
      DeviceType::OPENCL,
      DeviceType::IDEEP,
      torch::headeronly::kHIP,
      torch::headeronly::kFPGA,
      torch::headeronly::kMAIA,
      torch::headeronly::kXLA,
      torch::headeronly::kVulkan,
      torch::headeronly::kMetal,
      torch::headeronly::kXPU,
      torch::headeronly::kMPS,
      torch::headeronly::kMeta,
      torch::headeronly::kHPU,
      torch::headeronly::kVE,
      torch::headeronly::kLazy,
      torch::headeronly::kIPU,
      torch::headeronly::kMTIA,
      torch::headeronly::kPrivateUse1,
  };
  for (int8_t i = 0; i <
       static_cast<int8_t>(torch::headeronly::COMPILE_TIME_MAX_DEVICE_TYPES);
       i++) {
    EXPECT_EQ(static_cast<DeviceType>(i), expected_device_types[i]);
  }
}
@ -15,9 +15,6 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint, replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import (
    FSDPModule,
@ -61,7 +58,6 @@ from torch.testing._internal.common_fsdp import (
)
from torch.testing._internal.common_utils import run_tests, TEST_XPU, xfailIf
from torch.testing._internal.distributed._tensor.common_dtensor import (
    FeedForward,
    ModelArgs,
    Transformer,
    TransformerBlock,
@ -1014,222 +1010,6 @@ class TestFullyShardPrefetch(FSDPTest):
            self.assertEqual(events, expected_backward_events)
            events.clear()

    @skip_if_lt_x_gpu(2)
    def test_set_modules_to_backward_prefetch_inside_ac(self):
        n_layers = 3
        reshard_after_forward = True
        # use checkpoint wrapper instead of torch.utils
        model_args = ModelArgs(n_layers=n_layers, checkpoint_activations=False)
        model = Transformer(model_args)
        apply_activation_checkpointing(
            model, check_fn=lambda m: isinstance(m, TransformerBlock)
        )
        apply_activation_checkpointing(
            model, check_fn=lambda m: isinstance(m, FeedForward)
        )
        fully_shard([model.tok_embeddings, model.pos_embeddings])
        for layer in model.layers:
            # mimic fully_shard(layer.moe.experts)
            fully_shard(
                layer.feed_forward.w1, reshard_after_forward=reshard_after_forward
            )
            fully_shard(layer, reshard_after_forward=reshard_after_forward)
        fully_shard(
            [model.norm, model.output], reshard_after_forward=reshard_after_forward
        )
        fully_shard(model, reshard_after_forward=reshard_after_forward)
        inp = torch.randint(
            0,
            model_args.vocab_size,
            (2, model_args.max_seq_len),
            device=device_type.type,
        )

        def set_backward_prefetch(model: Transformer) -> None:
            # tell pyre model.set_modules_to_backward_prefetch is available
            assert isinstance(model, FSDPModule)
            assert isinstance(model.output, FSDPModule)

            # mimic deepseek MOE
            # prefetch layer - 1 and its feedforward before cpu sync during a2a
            reversed_transformer_blocks = list(reversed(model.layers))
            prev_transformer_blocks = reversed_transformer_blocks[1:] + [None]

            if (
                model.norm is not None
                and model.output is not None
                and len(model.layers) > 0
            ):
                assert isinstance(reversed_transformer_blocks[0], FSDPModule)
                model.output.set_modules_to_backward_prefetch(
                    [reversed_transformer_blocks[0]]
                )

            for transformer_block, prev_transformer_block in zip(
                reversed_transformer_blocks, prev_transformer_blocks
            ):
                assert isinstance(transformer_block, FSDPModule)
                if prev_transformer_block is not None:
                    assert isinstance(prev_transformer_block, FSDPModule)
                    assert hasattr(prev_transformer_block.feed_forward, "w1")
                    assert isinstance(
                        prev_transformer_block.feed_forward.w1, FSDPModule
                    )
                    transformer_block.set_modules_to_backward_prefetch(
                        [
                            prev_transformer_block,
                            prev_transformer_block.feed_forward.w1,
                        ]
                    )
                elif model.tok_embeddings is not None:
                    assert isinstance(model.tok_embeddings, FSDPModule)
                    transformer_block.set_modules_to_backward_prefetch(
                        [model.tok_embeddings]
                    )

        events: list[EventType] = []
        unshard_with_record = self._get_unshard_with_record(
            FSDPParamGroup.unshard, events
        )
        reshard_with_record = self._get_reshard_with_record(
            FSDPParamGroup.reshard, events
        )
        with (
            patch_unshard(unshard_with_record),
            patch_reshard(reshard_with_record),
        ):
            loss = model(inp)
            events.clear()
            loss.sum().backward()
            expected_backward_events = [
                ("unshard", "norm, output", TrainingState.PRE_BACKWARD),
                ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
                ("reshard", "norm, output", TrainingState.POST_BACKWARD),
                # layers.2 prefetch w1
                (
                    "unshard",
                    "layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.PRE_BACKWARD,
                ),
                # layers.2.w1 prefetch layers.1
                ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
                (
                    "reshard",
                    "layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.POST_BACKWARD,
                ),
                ("reshard", "layers.2", TrainingState.POST_BACKWARD),
                (
                    "unshard",
                    "layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.PRE_BACKWARD,
                ),
                ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
                (
                    "reshard",
                    "layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.POST_BACKWARD,
                ),
                ("reshard", "layers.1", TrainingState.POST_BACKWARD),
                (
                    "unshard",
                    "layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.PRE_BACKWARD,
                ),
                (
                    "unshard",
                    "tok_embeddings, pos_embeddings",
                    TrainingState.PRE_BACKWARD,
                ),
                (
                    "reshard",
                    "layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.POST_BACKWARD,
                ),
                ("reshard", "layers.0", TrainingState.POST_BACKWARD),
                (
                    "reshard",
                    "tok_embeddings, pos_embeddings",
                    TrainingState.POST_BACKWARD,
                ),
                (
                    "reshard",
                    "tok_embeddings, pos_embeddings",
                    TrainingState.POST_BACKWARD,
                ),
                ("reshard", "norm, output", TrainingState.POST_BACKWARD),
            ]
            self.assertEqual(events, expected_backward_events)
            events.clear()

            set_backward_prefetch(model)
            loss = model(inp)
            events.clear()
            loss.sum().backward()
            expected_backward_events = expected_backward_events = [
                ("unshard", "norm, output", TrainingState.PRE_BACKWARD),
                # root explicit prefetch layers.2
                ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
                ("reshard", "norm, output", TrainingState.POST_BACKWARD),
                # layers.2 prefetch layers.1 and feed_forward
                ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
                (
                    "unshard",
                    "layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.PRE_BACKWARD,
                ),
                # AC recompute_fn
                (
                    "unshard",
                    "layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.FORWARD,
                ),
                (
                    "reshard",
                    "layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.POST_BACKWARD,
                ),
                ("reshard", "layers.2", TrainingState.POST_BACKWARD),
                # layers.1 prefetch layers.0
                ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
                (
                    "unshard",
                    "layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.PRE_BACKWARD,
                ),
                (
                    "reshard",
                    "layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.POST_BACKWARD,
                ),
                ("reshard", "layers.1", TrainingState.POST_BACKWARD),
                # layers.0 prefetch embeddings
                (
                    "unshard",
                    "tok_embeddings, pos_embeddings",
                    TrainingState.PRE_BACKWARD,
                ),
                (
                    "reshard",
                    "layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
                    TrainingState.POST_BACKWARD,
                ),
                ("reshard", "layers.0", TrainingState.POST_BACKWARD),
                (
                    "reshard",
                    "tok_embeddings, pos_embeddings",
                    TrainingState.POST_BACKWARD,
                ),
                (
                    "reshard",
                    "tok_embeddings, pos_embeddings",
                    TrainingState.POST_BACKWARD,
                ),
                ("reshard", "norm, output", TrainingState.POST_BACKWARD),
            ]
            self.assertEqual(events, expected_backward_events)
            events.clear()

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_multi_module_backward_prefetch(self):
        n_layers = 5
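Illustrative note (not from the diff): the removed test above exercises explicit backward prefetching on fully_shard-wrapped modules. A minimal sketch of that user-facing API, assuming torch.distributed is already initialized with one GPU per rank; the model here is made up.

    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard

    model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64), nn.Linear(64, 64))
    for layer in model:
        fully_shard(layer)
    fully_shard(model)

    # While layer i runs backward, start unsharding layer i - 1.
    layers = list(model)
    for cur, prev in zip(reversed(layers), list(reversed(layers))[1:]):
        cur.set_modules_to_backward_prefetch([prev])
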
|
||||
@ -1,626 +0,0 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
import torch.nn as nn
|
||||
from torch.distributed._composable.replicate_with_fsdp import replicate
|
||||
from torch.distributed.fsdp import MixedPrecisionPolicy
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
|
||||
_get_gradient_divide_factors,
|
||||
)
|
||||
from torch.distributed.tensor import Shard
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl_version,
|
||||
SaveForwardInputsModel,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
check_sharded_parity,
|
||||
FSDPTest,
|
||||
FSDPTestMultiThread,
|
||||
get_devtype,
|
||||
MLP,
|
||||
patch_reduce_scatter,
|
||||
reduce_scatter_with_assert,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfRocmVersionLessThan,
|
||||
TEST_HPU,
|
||||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
|
||||
class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(2, torch.get_device_module(device_type).device_count())
|
||||
|
||||
def _init_models_and_optims(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
param_dtype: Optional[torch.dtype],
|
||||
reduce_dtype: Optional[torch.dtype],
|
||||
use_shard_placement_fn,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
|
||||
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
|
||||
largest_dim = -1
|
||||
largest_dim_size = -1
|
||||
for dim, dim_size in enumerate(param.shape):
|
||||
if dim_size > largest_dim_size:
|
||||
largest_dim = dim
|
||||
largest_dim_size = dim_size
|
||||
assert largest_dim >= 0, f"{param.shape}"
|
||||
return Shard(largest_dim)
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=param_dtype, reduce_dtype=reduce_dtype
|
||||
)
|
||||
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mp_policy=mp_policy,
|
||||
shard_placement_fn=shard_placement_fn,
|
||||
)
|
||||
for mlp in model:
|
||||
replicate_fn(mlp)
|
||||
replicate_fn(model)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
|
||||
return ref_model, ref_optim, model, optim
|
||||
|
||||
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
|
||||
use_shard_placement_fn_vals = [False]
|
||||
if self.world_size == 2:
|
||||
# For world size >2, gradient elements get reduced in different
|
||||
# orders for the baseline vs. dim-1 sharding, leading to numeric
|
||||
# differences for bf16 reduction, so only test world size 2.
|
||||
use_shard_placement_fn_vals.append(True)
|
||||
return use_shard_placement_fn_vals
|
||||
|
||||
@skipIfRocmVersionLessThan((7, 0))
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_compute_dtype(self):
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"param_dtype": [torch.bfloat16, torch.float16],
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_compute_dtype,
|
||||
)
|
||||
|
||||
def _test_compute_dtype(
|
||||
self,
|
||||
param_dtype: torch.dtype,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
use_shard_placement_fn: bool,
|
||||
):
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=None,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, param_dtype)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
predivide_factor, postdivide_factor, _, _ = _get_gradient_divide_factors(
|
||||
self.process_group, all_reduce_group=None, reduce_dtype=param_dtype
|
||||
)
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
|
||||
for iter_idx in range(10):
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
fsdp_loss = model(inp).sum()
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
fsdp_loss.backward()
|
||||
optim.step()
|
||||
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
|
||||
ref_loss.backward()
|
||||
for param in ref_model_bf16.parameters():
|
||||
# Use reduce-scatter -> all-gather as all-reduce because for
|
||||
# world size >=4, NCCL all-reduce shows numeric differences
|
||||
# compared with NCCL reduce-scatter
|
||||
if predivide_factor is not None and predivide_factor > 1:
|
||||
param.grad.div_(predivide_factor)
|
||||
elif predivide_factor is None:
|
||||
param.grad.div_(self.world_size)
|
||||
output = torch.zeros_like(torch.chunk(param.grad, self.world_size)[0])
|
||||
dist.reduce_scatter_tensor(output, param.grad)
|
||||
dist.all_gather_into_tensor(param.grad, output)
|
||||
if postdivide_factor is not None and postdivide_factor > 1:
|
||||
param.grad.div_(postdivide_factor)
|
||||
for param_fp32, param_bf16 in zip(
|
||||
ref_model.parameters(), ref_model_bf16.parameters()
|
||||
):
|
||||
param_fp32.grad = param_bf16.grad.to(param_fp32.dtype)
|
||||
param_bf16.grad = None
|
||||
ref_optim.step() # fp32 optimizer step
|
||||
for param_fp32, param_bf16 in zip(
|
||||
ref_model.parameters(), ref_model_bf16.parameters()
|
||||
):
|
||||
param_bf16.detach().copy_(param_fp32)
|
||||
|
||||
self.assertEqual(fsdp_loss, ref_loss)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
@skipIfRocmVersionLessThan((7, 0))
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_reduce_dtype(self):
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": [False, True],
|
||||
},
|
||||
self._test_reduce_dtype_fp32_reduce,
|
||||
)
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_reduce_dtype_bf16_reduce,
|
||||
)
|
||||
|
||||
def _test_reduce_dtype_fp32_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
):
|
||||
if (
|
||||
self.world_size > 2
|
||||
and isinstance(reshard_after_forward, int)
|
||||
and use_shard_placement_fn
|
||||
):
|
||||
return
|
||||
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, reduce_dtype)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
|
||||
for iter_idx in range(10):
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
fsdp_loss = model(inp).sum()
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
fsdp_loss.backward()
|
||||
optim.step()
|
||||
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
|
||||
ref_loss.backward()
|
||||
for param in ref_model_bf16.parameters():
|
||||
param.grad.data = param.grad.to(torch.float32)
|
||||
dist.all_reduce(param.grad) # fp32 reduction
|
||||
param.grad.div_(self.world_size)
|
||||
for param_fp32, param_bf16 in zip(
|
||||
ref_model.parameters(), ref_model_bf16.parameters()
|
||||
):
|
||||
param_fp32.grad = param_bf16.grad
|
||||
param_bf16.grad = None
|
||||
ref_optim.step() # fp32 optimizer step
|
||||
for param_fp32, param_bf16 in zip(
|
||||
ref_model.parameters(), ref_model_bf16.parameters()
|
||||
):
|
||||
param_bf16.detach().copy_(param_fp32)
|
||||
|
||||
self.assertEqual(fsdp_loss, ref_loss)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
def _test_reduce_dtype_bf16_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
):
|
||||
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
group = dist.distributed_c10d._get_default_group()
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, reduce_dtype)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
|
||||
for iter_idx in range(10):
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
fsdp_loss = model(inp).sum()
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
fsdp_loss.backward()
|
||||
optim.step()
|
||||
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
ref_loss = ref_model(inp).sum()
|
||||
ref_loss.backward()
|
||||
for param in ref_model.parameters():
|
||||
param_grad = param.grad.to(reduce_dtype)
|
||||
# Use reduce-scatter -> all-gather to implement all-reduce
|
||||
# since for world size >2, bf16 all-reduce and reduce-scatter
|
||||
# have numeric differences
|
||||
sharded_grad = funcol.reduce_scatter_tensor(
|
||||
param_grad, scatter_dim=0, reduceOp="avg", group=group
|
||||
) # bf16 reduction
|
||||
param.grad = funcol.all_gather_tensor(
|
||||
sharded_grad, gather_dim=0, group=group
|
||||
).to(param.dtype) # upcast to fp32
|
||||
ref_optim.step() # fp32 optimizer step
|
||||
|
||||
self.assertEqual(fsdp_loss, ref_loss)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_grad_acc_with_reduce_dtype(self):
|
||||
"""
|
||||
Tests that gradient accumulation without reduce-scatter when using
|
||||
bf16 compute and fp32 reduction accumulates the unsharded gradients in
|
||||
fp32.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{"reshard_after_forward": [True, False]},
|
||||
self._test_grad_acc_with_reduce_dtype,
|
||||
)
|
||||
|
||||
def _test_grad_acc_with_reduce_dtype(self, reshard_after_forward: bool):
|
||||
torch.manual_seed(42)
|
||||
param_dtype, reduce_dtype = (torch.bfloat16, torch.float32)
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=param_dtype, reduce_dtype=reduce_dtype
|
||||
)
|
||||
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
|
||||
# To emulate the mixed precision implementation where forward/backward
|
||||
# compute use bf16 and optimizer uses fp32, we maintain both an fp32
|
||||
# and a bf16 copy of the reference model
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
for mlp in model:
|
||||
replicate(
|
||||
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
replicate(
|
||||
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, reduce_dtype)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
device = device_type
|
||||
# Train on the same input to avoid loss explosion
|
||||
num_microbatches = 4
|
||||
inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
|
||||
for iter_idx in range(10):
|
||||
microbatch_inps = torch.chunk(inp, 4)
|
||||
for microbatch_idx in range(num_microbatches):
|
||||
is_last_microbatch = microbatch_idx == num_microbatches - 1
|
||||
model.set_requires_gradient_sync(is_last_microbatch)
|
||||
model.set_reshard_after_backward(
|
||||
is_last_microbatch or reshard_after_forward
|
||||
)
|
||||
losses: list[torch.Tensor] = []
|
||||
for _model in (ref_model_compute, model):
|
||||
losses.append(
|
||||
_model(microbatch_inps[microbatch_idx].detach()).sum()
|
||||
)
|
||||
self.assertEqual(losses[-1].dtype, param_dtype)
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
losses[-1].backward()
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
# Manually accumulate gradients into the base reference model
|
||||
# from the compute reference model in fp32
|
||||
for ref_param, ref_param_compute in zip(
|
||||
ref_model.parameters(), ref_model_compute.parameters()
|
||||
):
|
||||
self.assertTrue(ref_param_compute.grad is not None)
|
||||
self.assertEqual(ref_param.dtype, torch.float32)
|
||||
if ref_param.grad is not None:
|
||||
ref_param.grad += ref_param_compute.grad
|
||||
else:
|
||||
ref_param.grad = ref_param_compute.grad.to(ref_param.dtype)
|
||||
ref_param_compute.grad = None
|
||||
# Manually reduce gradients for the reference model on the last
|
||||
# microbatch to implement data parallelism
|
||||
if is_last_microbatch:
|
||||
for ref_param in ref_model.parameters():
|
||||
self.assertTrue(ref_param.grad is not None)
|
||||
dist.all_reduce(ref_param.grad)
|
||||
ref_param.grad /= self.world_size
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
ref_optim.step()
|
||||
optim.step()
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
# Manually copy parameters from the base reference model to the
|
||||
# compute reference model to run the optimizer step for the latter
|
||||
for ref_param, ref_param_compute in zip(
|
||||
ref_model.parameters(), ref_model_compute.parameters()
|
||||
):
|
||||
ref_param_compute.detach().copy_(ref_param)
|
||||
|
||||
|
||||
class TestReplicateMixedPrecisionCasts(FSDPTestMultiThread):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_float16_on_one_submodule(self):
|
||||
x = torch.zeros(2, 100, device=device_type)
|
||||
|
||||
# Subtest 1: use fp16 on the second child submodule -- does not require
|
||||
# any additional casting logic
|
||||
forward_inputs: dict[str, nn.Module] = {}
|
||||
model = SaveForwardInputsModel(
|
||||
forward_inputs,
|
||||
cast_forward_inputs=False,
|
||||
).to(device_type)
|
||||
replicate(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
|
||||
replicate(model)
|
||||
model(x).sum().backward()
|
||||
self.assertEqual(forward_inputs[model].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)
|
||||
|
||||
# Subtest 2: use fp16 on the second child module, where the user module
|
||||
# owns the cast
|
||||
forward_inputs: dict[nn.Module, torch.Tensor] = {}
|
||||
model = SaveForwardInputsModel(
|
||||
forward_inputs=forward_inputs, cast_forward_inputs=True
|
||||
).to(device_type)
|
||||
replicate(
|
||||
model.c2,
|
||||
mp_policy=MixedPrecisionPolicy(
|
||||
param_dtype=torch.float16, cast_forward_inputs=False
|
||||
),
|
||||
)
|
||||
replicate(model)
|
||||
model(x).sum().backward()
|
||||
self.assertEqual(forward_inputs[model].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)
|
||||
|
||||
# Subtest 3: use fp16 on the first child module and specify its output
|
||||
# dtype so that the second child module does not need to cast
|
||||
forward_inputs: dict[nn.Module, torch.Tensor] = {}
|
||||
model = SaveForwardInputsModel(
|
||||
forward_inputs=forward_inputs, cast_forward_inputs=False
|
||||
).to(device_type)
|
||||
replicate(
|
||||
model.c1,
|
||||
mp_policy=MixedPrecisionPolicy(
|
||||
param_dtype=torch.float16, output_dtype=torch.float32
|
||||
),
|
||||
)
|
||||
replicate(model)
|
||||
model(x).sum().backward()
|
||||
self.assertEqual(forward_inputs[model].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c1].dtype, torch.float16)
|
||||
self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_submodules_with_external_inputs(self):
|
||||
self.run_subtests(
|
||||
{"enable_submodule_cast": [False, True]},
|
||||
self._test_submodules_with_external_inputs,
|
||||
)
|
||||
|
||||
def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
|
||||
class ToyModule(nn.Module):
|
||||
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
|
||||
super().__init__()
|
||||
self.l = nn.Linear(100, 100)
|
||||
self.forward_inputs = forward_inputs
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
self.forward_inputs["l2_input_x"] = x
|
||||
self.forward_inputs["l2_input_y"] = y
|
||||
return self.l(x)
|
||||
|
||||
class ToyModel(nn.Module):
|
||||
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
|
||||
super().__init__()
|
||||
self.l1 = nn.Linear(100, 100)
|
||||
self.l2 = ToyModule(forward_inputs)
|
||||
self.forward_inputs = forward_inputs
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
self.forward_inputs["model_input_x"] = x
|
||||
y = torch.ones(
|
||||
2, 100, device=device_type.type, dtype=torch.float32
|
||||
) # external input
|
||||
return self.l2(self.l1(x), y)
|
||||
|
||||
forward_inputs: dict[str, torch.Tensor] = {}
|
||||
model = ToyModel(forward_inputs).to(device_type)
|
||||
x = torch.zeros(2, 100, device=device_type.type, dtype=torch.float32)
|
||||
replicate(
|
||||
model.l2,
|
||||
mp_policy=MixedPrecisionPolicy(
|
||||
param_dtype=torch.float16, cast_forward_inputs=enable_submodule_cast
|
||||
),
|
||||
)
|
||||
replicate(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
|
||||
model(x).sum().backward()
|
||||
|
||||
# If we enable `model.l2` to cast (as default), then `l2_input_y` gets
# cast to fp16, and if we disable, then it stays as fp32.
self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
|
||||
self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
|
||||
self.assertEqual(
|
||||
forward_inputs["l2_input_y"].dtype,
|
||||
torch.float16 if enable_submodule_cast else torch.float32,
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_norm_modules_bf16(self):
|
||||
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
|
||||
self._test_norm_modules(mp_policy)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_norm_modules_fp16(self):
|
||||
mp_policy = MixedPrecisionPolicy(param_dtype=torch.float16)
|
||||
self._test_norm_modules(mp_policy)
|
||||
|
||||
def _test_norm_modules(self, mp_policy: MixedPrecisionPolicy):
|
||||
def inner(model: nn.Module, x: torch.Tensor):
|
||||
# Run forward and backward to check for no type mismatch errors
|
||||
z = model(x)
|
||||
self.assertEqual(z.dtype, mp_policy.param_dtype)
|
||||
z.sum().backward()
|
||||
|
||||
# Layer norm
|
||||
model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 32))
|
||||
for module in (model[0], model[1], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
inner(model, torch.randn((4, 32)))
|
||||
|
||||
# Batch norm 1D
|
||||
model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 32))
|
||||
for module in (model[0], model[1], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
inner(model, torch.randn((4, 32)))
|
||||
|
||||
# Batch norm 2D: error in backward from buffer dtype mismatch
|
||||
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
|
||||
for module in (model[0], model[1], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
if TEST_HPU:
|
||||
inner(model, torch.randn((3, 1, 9, 9)))
|
||||
else:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected running_mean to have type", # Error not seen on HPUs and hence it can be skipped
|
||||
):
|
||||
# Errors in batch norm 2D backward
|
||||
inner(model, torch.randn((3, 1, 9, 9)))
|
||||
|
||||
# Batch norm 2D: cast buffers down to lower precision
|
||||
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
|
||||
for module in (model[0], model[1], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
# Casting batch norm buffers to the lower precision allows backward
|
||||
model[1].running_mean = model[1].running_mean.to(mp_policy.param_dtype)
|
||||
model[1].running_var = model[1].running_var.to(mp_policy.param_dtype)
|
||||
inner(model, torch.randn((3, 1, 9, 9)))
|
||||
|
||||
# Batch norm 2D: use special mixed precision policy
|
||||
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
|
||||
bn_mp_policy = MixedPrecisionPolicy(output_dtype=mp_policy.param_dtype)
|
||||
replicate(model[1], mp_policy=bn_mp_policy)
|
||||
for module in (model[0], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
inner(model, torch.randn((3, 1, 9, 9)))
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_clamp_reduce_dtype(self):
|
||||
# Initialize the model directly in bf16
|
||||
init_dtype = torch.bfloat16
|
||||
model = nn.Sequential(
|
||||
nn.Linear(32, 32, dtype=init_dtype),
|
||||
nn.Linear(32, 32, dtype=init_dtype),
|
||||
).to(device_type.type)
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
|
||||
)
|
||||
# Check that we did not clamp the reduce dtype
|
||||
self.assertEqual(mp_policy.reduce_dtype, torch.bfloat16)
|
||||
for module in model:
|
||||
replicate((module), mp_policy=mp_policy)
|
||||
replicate(model, mp_policy=mp_policy)
|
||||
|
||||
# Check that the reduce-scatter runs in bf16 even after we change the
|
||||
# model from bf16 to fp32
|
||||
model.to(torch.float32)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, torch.bfloat16)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
inp = torch.randn((4, 32), device=device_type.type)
|
||||
loss = model(inp).sum()
|
||||
loss.backward()
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_dataclass_input(self):
|
||||
@dataclasses.dataclass
|
||||
class Input:
|
||||
x: torch.Tensor
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._layer = nn.Linear(10, 10)
|
||||
|
||||
def forward(self, input: Input):
|
||||
return self._layer(input.x)
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
torch.bfloat16, torch.bfloat16, torch.bfloat16, True
|
||||
)
|
||||
model = Model()
|
||||
inp = Input(torch.randn(2, 10).cuda())
|
||||
|
||||
replicate(model, mp_policy=mp_policy)
|
||||
loss = model(inp).sum()
|
||||
loss.backward()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -5,7 +5,6 @@ import copy
|
||||
import functools
|
||||
import itertools
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
@ -18,20 +17,8 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
_CHECKPOINT_PREFIX,
|
||||
apply_activation_checkpointing,
|
||||
)
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
FSDPModule,
|
||||
OffloadPolicy,
|
||||
register_fsdp_forward_method,
|
||||
)
|
||||
from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy
|
||||
from torch.distributed.tensor import DTensor, init_device_mesh
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
from torch.distributed.tensor.parallel import (
|
||||
ColwiseParallel,
|
||||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
check_sharded_parity,
|
||||
@ -39,7 +26,6 @@ from torch.testing._internal.common_fsdp import (
|
||||
FSDPTest,
|
||||
FSDPTestMultiThread,
MLP,
MLPStack,
patch_all_gather,
patch_reduce_scatter,
)
@ -856,385 +842,5 @@ class TestReplicateSharedParams(FSDPTest):
self.assertEqual(losses[0], losses[1])


class TestReplicateGradientAccumulation(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.get_device_module(device_type).device_count())

@skip_if_lt_x_gpu(2)
def test_gradient_accumulation(self):
"""
Tests gradient accumulation with/without gradient reduction and
with/without resharding after backward.
"""

shard_size, replicate_size = 1, self.world_size
meshes = init_device_mesh(
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
)
self.run_subtests(
{
"mesh": [meshes],
"reshard_after_forward": [True, False],
# "all": disable reduce-scatter for all modules
# "root_only": disable reduce-scatter for root's linear only
# "some_mlps": disable reduce-scatter for some MLPs
"mode": ["all", "root_only", "some_mlps"],
"reshard_after_backward": [False, True],
"offload_policy": [OffloadPolicy(), CPUOffloadPolicy()],
# For HSDP only:
# `True`: reduce-scatter only (no all-reduce) each microbatch
# until the last microbatch
# `False`: neither reduce-scatter nor all-reduce each
# microbatch until the last microbatch
"reduce_scatter_only": [False, True],
},
self._test_gradient_accumulation,
)

def _test_gradient_accumulation(
self,
mesh: DeviceMesh,
reshard_after_forward: Union[bool, int],
mode: str,
reshard_after_backward: bool,
offload_policy: OffloadPolicy,
reduce_scatter_only: bool, # for HSDP
):
if (
(
not reshard_after_backward
and (reshard_after_forward is not False or mode == "some_mlps")
)
or (
isinstance(offload_policy, CPUOffloadPolicy)
and reshard_after_forward is not True
)
or (
mesh.ndim != 2
) # may eventually need to change once decision on device mesh is made
):
return # skip since not common or applicable

torch.manual_seed(42)
batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
if mode == "some_mlps":
num_mlps_to_disable_reduce_scatter = 2
modules = [nn.Linear(lin_dim, lin_dim)]
modules.extend(MLP(lin_dim) for _ in range(num_mlps))
model = nn.Sequential(*modules)
ref_model = copy.deepcopy(model).to(device_type)
replicate_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
offload_policy=offload_policy,
)
for mlp in model[1:]:
replicate_fn(mlp)
replicate_fn(model) # root gets the 1st linear
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

def set_grad_sync_flag(
module: nn.Module, is_last_microbatch: bool, recurse: bool = True
):
if reduce_scatter_only:
module.set_requires_all_reduce(is_last_microbatch, recurse=recurse)
else:
module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse)

def set_backward_flags(_model: nn.Module, is_last_microbatch: bool):
if mode == "all":
set_grad_sync_flag(_model, is_last_microbatch)
if not reshard_after_backward:
_model.set_reshard_after_backward(is_last_microbatch)
elif mode == "some_mlps":
for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
set_grad_sync_flag(mlp, is_last_microbatch)
if not reshard_after_backward:
mlp.set_reshard_after_backward(is_last_microbatch)
elif mode == "root_only":
set_grad_sync_flag(model, is_last_microbatch, recurse=False)
if not reshard_after_backward:
model.set_reshard_after_backward(is_last_microbatch, recurse=False)

torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(5):
comm_count_list = []

for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
set_backward_flags(model, is_last_microbatch)
inp = torch.randn(batch_size, lin_dim, device=device_type.type)
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
with CommDebugMode() as comm_mode:
losses.append(_model(inp).sum())
losses[-1].backward()
comm_count_list.append(comm_mode.get_comm_counts())
self.assertEqual(losses[0], losses[1])

comm_counts = defaultdict(int)
for comm_count_dict in comm_count_list:
for collective, count in comm_count_dict.items():
comm_counts[collective] += count

all_gather_count = comm_counts[c10d_ops._allgather_base_]
# reduce_scatter_count = comm_counts[c10d_ops._reduce_scatter_base_]
all_reduce_count = comm_counts[c10d_ops.allreduce_]

# Expect one reduce-scatter per MLP plus one for the root's linear
# on the last microbatch
# expected_reduce_scatter_count = 0
expected_all_reduce_count = num_mlps + 1

if mode == "some_mlps":
# Expect additional reduce-scatters for non-disabled MLPs and
# the root's linear
expected_all_reduce_count += (
num_mlps - num_mlps_to_disable_reduce_scatter + 1
) * (num_microbatches - 1)
elif mode == "root_only":
# Expect additional reduce-scatters for all MLPs
expected_all_reduce_count += (num_mlps) * (num_microbatches - 1)

# self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
self.assertEqual(all_reduce_count, expected_all_reduce_count)

# Expect one all-gather per MLP plus one for the root's linear in
# the first microbatch's forward
expected_all_gather_count = 0

self.assertEqual(all_gather_count, expected_all_gather_count)

for param in ref_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
check_sharded_parity(self, ref_model, model)
for _optim in (optim, ref_optim):
_optim.step()
# When `set_to_none=False`, we are exercising mixing
# gradient accumulation with and without communication
_optim.zero_grad(set_to_none=(iter_idx % 2))

@skip_if_lt_x_gpu(2)
def test_1f1b_microbatching(self):
self.run_subtests(
{
"use_explicit_unshard": [False, True],
"reshard_after_backward": [False, True],
},
self._test_1f1b_microbatching,
)

def _test_1f1b_microbatching(
self, use_explicit_unshard: bool, reshard_after_backward: bool
):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
replicate(module, reshard_after_forward=False)
replicate(model, reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

num_microbatches = 3
local_batch_size = 2
torch.manual_seed(42 + self.rank + 1)
inps = [
torch.randint(
0,
model_args.vocab_size,
(local_batch_size, 16),
device=device_type.type,
)
for _ in range(num_microbatches)
]

# Before pipelining, we may prefer to issue all all-gathers ahead of
# time to increase overlap opportunity at no difference in parameter
# memory usage since we do not reshard after forward
if use_explicit_unshard:
for module in model.modules():
if isinstance(module, FSDPModule):
module.unshard(async_op=True)

# Emulate the 1f1b pipeline schedule and only reduce gradients on the
# last microbatch
losses: list[torch.Tensor] = []
ref_losses: list[torch.Tensor] = []
for inp_idx, inp in enumerate(inps):
is_last_microbatch = inp_idx == num_microbatches - 1
model.set_requires_gradient_sync(is_last_microbatch)
model.set_is_last_backward(is_last_microbatch)
if not reshard_after_backward:
model.set_reshard_after_backward(is_last_microbatch)
losses.append(model(inp).sum())
losses[-1].backward()
ref_losses.append(ref_model(inp).sum())
ref_losses[-1].backward()
for param in ref_model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

for loss, ref_loss in zip(losses, ref_losses):
self.assertEqual(loss, ref_loss)
optim.step()
ref_optim.step()
check_sharded_parity(self, ref_model, model)


class TestReplicateCustomForwardMethod(FSDPTest):
@property
def world_size(self) -> int:
return min(torch.get_device_module(device_type).device_count(), 2)

@skip_if_lt_x_gpu(2)
def test_register_fsdp_forward_method(self):
class VisionTransformer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)

def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
return self.patch_proj(imgs).flatten(2).transpose(1, 2)

def forward(self, imgs: torch.Tensor) -> torch.Tensor:
return self.forward_features(imgs).sum(dim=1)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)

def forward(self, imgs: torch.Tensor) -> torch.Tensor:
# Run `vit.forward_features`, which is not `forward`!
patch_embeddings = self.vit.forward_features(imgs)
return self.projector(patch_embeddings)

torch.manual_seed(42)
model = Model()
ref_model = copy.deepcopy(model).to(device_type)
replicate(model.vit)
replicate(model.projector)
replicate(model)
register_fsdp_forward_method(model.vit, "forward_features")

torch.manual_seed(42 + self.rank + 1)
inp = torch.randn(4, 3, 224, 224, device=device_type.type)
ref_loss = ref_model(inp).sum()
loss = model(inp).sum()
self.assertEqual(ref_loss, loss)
ref_loss.backward()
loss.backward()
for param in ref_model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
check_sharded_parity(self, ref_model, model)


class TestReplicateTPTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.get_device_module(device_type).device_count())

def init_global_mesh(self) -> DeviceMesh:
return init_device_mesh(
device_type.type,
(2, 1, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)

@skip_if_lt_x_gpu(8)
def test_replicate_tp(self):
global_mesh = self.init_global_mesh()
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
},
functools.partial(self._test_replicate_tp, global_mesh),
)

def _test_replicate_tp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
foreach: bool,
):
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`

torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).to(device_type)

ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)

parallelize_plan = {
# Pass `use_local_output=False` to keep as DTensor to preserve
# uneven activation dims
"0.in_proj": ColwiseParallel(use_local_output=False),
"0.out_proj": RowwiseParallel(use_local_output=False),
"1.in_proj": ColwiseParallel(use_local_output=False),
"1.out_proj": RowwiseParallel(use_local_output=False),
"2.in_proj": ColwiseParallel(use_local_output=False),
"2.out_proj": (RowwiseParallel()),
}

model = parallelize_module(model, tp_mesh, parallelize_plan)

for module in model:
if isinstance(module, nn.LayerNorm):
continue
if use_activation_checkpointing:
checkpoint(module)
replicate(module, device_mesh=dp_mesh)
replicate(model, device_mesh=dp_mesh)

# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
# strided-sharded layers.
for ref_p, p in zip(ref_model.parameters(), model.parameters()):
self.assertIsInstance(p, DTensor)
self.assertEqual(ref_p, p.full_tensor())

optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

torch.manual_seed(42 + dp_pg.rank() + 1)
device = device_type
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
losses.append(_model(inp).sum())
losses[-1].backward()

for param in ref_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

for _optim in (ref_optim, optim):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
_optim.step()
self.assertEqual(losses[0], losses[1])
check_sharded_parity(self, ref_model, model)

for _, p in model.named_parameters():
self.assertIsInstance(p, DTensor)
self.assertEqual(p.device_mesh.ndim, 3)
self.assertEqual(len(p.placements), 3)
self.assertEqual(
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
)


if __name__ == "__main__":
run_tests()

@ -5,7 +5,6 @@ import unittest

import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
@ -38,18 +37,6 @@ class SimpleModel(torch.nn.Module):
return self.mlp_1(self.mlp_0(input))


class SimpleModelAnnotated(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.mlp_0 = MLPModule(device)
self.mlp_1 = MLPModule(device)

def forward(self, input):
with fx_traceback.annotate({"pp_stage": 0}):
x = self.mlp_0(input)
return self.mlp_1(x)


def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
# needed for strict export
torch.utils._pytree.register_constant(DTensorSpec)
@ -103,7 +90,7 @@ class DTensorExportTest(TestCase):
)
self.device_type = "cuda"

def _run_test(self, export_fn, test_annotation=False):
def _run_test(self, export_fn):
dp_degree = 2
tp_degree = self.world_size // dp_degree

@ -114,11 +101,7 @@ class DTensorExportTest(TestCase):
mesh_dim_names=["dp", "tp"],
)

model = None
if test_annotation:
model = SimpleModelAnnotated(self.device_type)
else:
model = SimpleModel(self.device_type)
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
@ -148,116 +131,6 @@ class DTensorExportTest(TestCase):
1,
)

if test_annotation:

def has_tag(node):
return "custom" in node.meta and node.meta["custom"] == {"pp_stage": 0}

def marked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if has_tag(node) and node.op == "call_function"
]

def unmarked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if not has_tag(node) and node.op == "call_function"
]

marked_nodes_fw = [
"t",
"addmm",
"view",
"relu",
"view_1",
"t_1",
"div",
"addmm_1",
"all_reduce",
"wait_tensor",
"view_2",
"t_12",
]
unmarked_nodes_fw = [
"view_3",
"t_2",
"addmm_2",
"view_4",
"relu_1",
"view_5",
"t_3",
"div_1",
"addmm_3",
"all_reduce_1",
"wait_tensor_1",
"view_6",
"t_4",
"t_8",
]

marked_nodes_bw = [
"mm_4",
"t_13",
"view_1",
"mm_5",
"t_14",
"sum_3",
"view_9",
"t_15",
"detach",
"detach_1",
"detach_6",
"detach_7",
"threshold_backward_1",
"t_16",
"mm_6",
"t_17",
"sum_4",
"view_10",
"t_18",
]
unmarked_nodes_bw = [
"mm",
"t_5",
"view_5",
"mm_1",
"t_6",
"sum_1",
"view_7",
"t_7",
"detach_2",
"detach_3",
"detach_4",
"detach_5",
"threshold_backward",
"mm_2",
"t_9",
"mm_3",
"t_10",
"sum_2",
"view_8",
"t_11",
"all_reduce_2",
"wait_tensor_2",
]

self.assertEqual(marked_nodes(fw_gm), marked_nodes_fw)
self.assertEqual(unmarked_nodes(fw_gm), unmarked_nodes_fw)

self.assertEqual(marked_nodes(bw_gm), marked_nodes_bw)
self.assertEqual(unmarked_nodes(bw_gm), unmarked_nodes_bw)

self.assertEqual(
set(marked_nodes(joint_gm)), set(marked_nodes_fw + marked_nodes_bw)
)
self.assertEqual(
set(unmarked_nodes(joint_gm)),
set(unmarked_nodes_fw + unmarked_nodes_bw),
)

@parametrize(
"export_fn",
[
@ -277,9 +150,6 @@ class DTensorExportTest(TestCase):
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)

def test_annotate_aot_export_joint_with_descriptors_alone(self):
self._run_test(aot_export_joint_with_descriptors_alone, True)


instantiate_parametrized_tests(DTensorExportTest)


@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import contextlib
import itertools

import torch
@ -356,7 +355,7 @@ class RedistributeTest(DTensorTestBase):
replica_spec = Replicate()
# 1) test replicate -> partial forward
replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec])
with self.assertRaisesRegex(RuntimeError, "Can not redistribute"):
with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])

from torch.distributed.tensor._redistribute import Redistribute
@ -620,38 +619,6 @@ class RedistributeTest(DTensorTestBase):
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(out.placements, [Shard(0), dst])

@with_comms
def test_redistribute_to_partial(self):
mesh = init_device_mesh(self.device_type, (2, 2))

tensor = torch.randn(12, 8, device=self.device_type)

test_cases = [
# Partial to Partial is allowed
([Partial(), Shard(0)], [Partial(), Shard(0)], True),
([Partial(), Shard(0)], [Partial(), Shard(1)], True),
([Shard(0), Partial()], [Replicate(), Partial()], True),
([Shard(0), Partial("prod")], [Replicate(), Partial("prod")], True),
# Non-Partial to Partial is NOT allowed
([Shard(0), Replicate()], [Shard(0), Partial()], False),
([Shard(0), Replicate()], [Replicate(), Partial()], False),
([Shard(0), Shard(1)], [Replicate(), Partial()], False),
# Partial to partial is allowed only if the reduction op is the same
([Shard(0), Partial("prod")], [Replicate(), Partial("sum")], False),
]

for src, dst, allow in test_cases:
dt = DTensor.from_local(tensor, mesh, src)
raise_context = (
self.assertRaisesRegex(RuntimeError, "Can not redistribute")
if not allow
else contextlib.nullcontext()
)

with raise_context:
out = dt.redistribute(mesh, dst)
self.assertEqual(out.placements, dst)


instantiate_parametrized_tests(RedistributeTest)


@ -910,7 +910,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||flatten
4|aten.view.default||
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1

@ -15228,11 +15228,6 @@ def forward(self, x):
test_serdes=True,
)

@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureRetraceability
@testing.expectedFailureStrictV2
@testing.expectedFailureStrict # annotation needs to be handled in dynamo
@testing.expectedFailureSerDer
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):
@ -15251,22 +15246,17 @@ def forward(self, x):
ep = export(m, (torch.randn(10),))

for node in ep.graph.nodes:
if node.op in ("placeholder", "output"):
continue
if node.target == torch.ops.aten.add.Tensor:
if node.target == torch.ops.aten.add.default:
self.assertTrue(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
elif node.target == torch.ops.aten.sub.Tensor:
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
elif node.target == torch.ops.aten.mul.Tensor:
if node.target == torch.ops.aten.mul.default:
self.assertTrue(
node.meta["custom"],
{"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
)
elif node.target == torch.ops.aten.div.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta["custom"], {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
if node.target == torch.ops.aten.div.default:
self.assertTrue(node.meta["custom"], {})

def test_dynamic_shapes_serdes_generic(self):
from torch._export.serde.dynamic_shapes import (

@ -13,7 +13,6 @@ import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
from torch._decomp import decomposition_table
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.testing import normalize_gm
from torch._functorch._aot_autograd.descriptors import (
BufferAOTInput,
@ -778,34 +777,36 @@ class inner_f(torch.nn.Module):

inputs = (torch.randn(4, 3),)

for with_export in [False]: # TODO: make dynamo work for annotation
for with_export in [True, False]:
with ExitStack() as stack:
model = None
with fx_traceback.preserve_node_meta():
if with_export:
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
model = _dynamo_graph_capture_for_export(model)(*inputs)
ep = torch.export.export(SimpleLinear(), inputs)
model = ep.module()
else:
model = SimpleLinear()

joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, decompositions={}
stack, model, inputs, decompositions=decomposition_table
)

for node in joint_with_descriptors.graph_module.graph.nodes:
if node.op in ("placeholder", "output"):
continue
if node.target != torch.ops.aten.sub.Tensor and node.op not in (
"placeholder",
"output",
if (
node.target
in (
torch.ops.prims.transpose.default,
torch.ops.aten.mm.default,
torch.ops.prims.mul.default,
torch.ops.prims.broadcast_in_dim.default,
torch.ops.prims.add.default,
)
# TODO: add annotation to backward graph nodes
and node.meta.get("partitioner_tag") != "is_backward"
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
elif node.target == torch.ops.aten.sub.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta.get("custom", {}), {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta.get("custom", {}), {})


if __name__ == "__main__":

@ -908,13 +908,11 @@ class TestFakeQuantize(TestCase):
self.assertEqual(fq_module.activation_post_process.quant_min, 0)
self.assertEqual(fq_module.activation_post_process.quant_max, 127)

@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
sampled_dtype=st.sampled_from(['bf16', 'fp16', 'fp32']))
def test_fused_moving_avg_obs_fake_quant(self, device, sampled_dtype):
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']))
def test_fused_moving_avg_obs_fake_quant(self, device):
try:
if device == 'cpu':
sampled_dtype = 'fp32'
dtype = {'bf16' : torch.bfloat16, 'fp16' : torch.half, 'fp32' : torch.float32}[sampled_dtype]
sampled_dtype = st.sampled_from(["bf16", "fp32"]) if device == "cuda" else "fp32"
dtype = torch.bfloat16 if sampled_dtype == "bf16" else torch.float32
torch.set_default_dtype(dtype)

with torch.device(device):

@ -1065,17 +1065,15 @@ class TestFakeQuantizeOps(TestCase):

class TestFusedObsFakeQuant(TestCase):
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
sampled_dtype=st.sampled_from(['bf16', 'fp16', 'fp32']),
symmetric_quant=st.booleans(), use_bool=st.booleans())
@settings(deadline=None)
def test_fused_obs_fake_quant_moving_avg(self, device, sampled_dtype, symmetric_quant, use_bool) -> None:
def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant, use_bool) -> None:
"""
Tests the case where we call the fused_obs_fake_quant op multiple times
and update the running_min and max of the activation tensors.
"""
if device == "cpu":
sampled_dtype = "fp32"
dtype = {'bf16' : torch.bfloat16, 'fp16' : torch.half, 'fp32' : torch.float32}[sampled_dtype]
sampled_dtype = st.sampled_from(["bf16", "fp32"]) if device == "cuda" else "fp32"
dtype = torch.bfloat16 if sampled_dtype == "bf16" else torch.float32

in_running_min_ref = out_running_min_ref = torch.tensor(float("inf"), dtype=dtype)
in_running_min_op = torch.tensor(float("inf"), dtype=dtype, device=device)

@ -806,60 +806,6 @@ class TestMatmulCuda(InductorTestCase):

torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum

@onlyCUDA
@parametrize("ops", [("mm", torch.mm), ("bmm", torch.bmm), ("addmm", torch.addmm), ("baddbmm", torch.baddbmm)])
def test_input_dimension_checking_out_dtype(self, ops):
op_name, op = ops
B = 2
M, N, K = 32, 32, 32

def is_addmm():
return "add" in op_name

def is_batched():
return "bmm" in op_name

if is_batched():
a = torch.randn(B, M, K, device="cuda", dtype=torch.bfloat16)
mismatch_k_b = torch.randn(B, K + 1, N, device="cuda", dtype=torch.bfloat16)
c = torch.randn(B, M, N, device="cuda", dtype=torch.bfloat16)
extra_dim_b = a.clone().unsqueeze(0)

mismatch_k_err = "Expected size for first two dimensions of batch2 tensor to be"
extra_dim_err = "batch2 must be a 3D tensor"
else:
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
mismatch_k_b = torch.randn(K + 1, N, device="cuda", dtype=torch.bfloat16)
c = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
extra_dim_b = a.clone().unsqueeze(0)

mismatch_k_err = "mat1 and mat2 shapes cannot be multiplied"
extra_dim_err = "mat2 must be a matrix, got 3-D tensor"

# Test mismatch K
with self.assertRaisesRegex(RuntimeError, mismatch_k_err):
if is_addmm():
op(c, a, mismatch_k_b, out_dtype=torch.float32)
else:
op(a, mismatch_k_b, out_dtype=torch.float32)

# Test extra dimension
with self.assertRaisesRegex(RuntimeError, extra_dim_err):
if is_addmm():
op(c, a, extra_dim_b, out_dtype=torch.float32)
else:
op(c, extra_dim_b, out_dtype=torch.float32)

if is_batched():
with self.assertRaisesRegex(RuntimeError, "Expected size for first two dimensions of batch2 tensor to be"):
# Test mismatch B for bmm/baddbmm
mismatch_batch_dim_b = torch.randn(B + 1, K, N, device="cuda", dtype=torch.bfloat16)
if is_addmm():
op(c, a, mismatch_batch_dim_b, out_dtype=torch.float32)
else:
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)


f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"

@ -726,10 +726,9 @@ class OutputGraph(OutputGraphCommon):

self.export_metadata = ExportMetaData()

# Dict of inlined unspecialized modules to generate the
# dynamo_flat_name_to_original_fqn mapping. The code here follows what
# export was doing earlier.
self.inlined_unspecialized_modules: dict[str, torch.nn.Module] = {}
# Set of inlined unspecialized modules names to generate the
# dynamo_flat_name_to_original_fqn mapping.
self.used_inlined_inbuilt_modules_names: OrderedSet[str] = OrderedSet()

def mark_bytecode_tracing_start(self) -> None:
self.compiler_trace_stack.enter_context(
@ -2585,7 +2584,7 @@ class OutputGraph(OutputGraphCommon):
# some of the tensor objects to be held alive for longer than necessary.
self.root_tx = None # type: ignore[assignment]
self.nn_modules.clear()
self.inlined_unspecialized_modules.clear()
self.used_inlined_inbuilt_modules_names.clear()
self.param_name_to_source = None

for node in self.graph.nodes:
@ -2619,9 +2618,9 @@ class OutputGraph(OutputGraphCommon):
) -> None:
name = OutputGraph.module_key_name(source.name())
name = get_unique_name_wrt(
name, self.inlined_unspecialized_modules, self.global_scope
name, self.used_inlined_inbuilt_modules_names, self.global_scope
)
self.inlined_unspecialized_modules[name] = inlined_module
self.used_inlined_inbuilt_modules_names.add(name)

def register_leaf_name(leaf_name: str) -> None:
assert self.param_name_to_source is not None

@ -420,18 +420,14 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# the descendants of graph inputs corresponding to fwd inputs, didn't
# seem obvious at first glance on how to partition graph inputs into
# fwd vs bwd without relying on string names.
return (
node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
)
return "nn_module_stack" in node.meta and "seq_nr" in node.meta

def _is_backward_node_with_seq_nr(node):
# For now, assume that if nn_module_stack_metadata is not populated,
# this node is from the backward. Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this, same
# as with the forward.
return (
node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
)
return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta

fwd_seq_nr_to_node = {}
for node in fx_g.graph.nodes:
@ -451,10 +447,8 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# fwd_node should always exist, but handle non-existence just in case
fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
if fwd_node is not None:
node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
# TODO: better to change to a specific field of custom?
node.meta["custom"] = fwd_node.meta.get("custom")


def register_buffer_assignment_hook(mod, assigned_buffers):

@ -476,7 +476,7 @@ def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
# u0==u1 assume the same, no broadcasting!
torch._check(
x == y,
lambda: "sizes assumed to be the same due to unbacked broadcasting semantics",
"sizes assumed to be the same due to unbacked broadcasting semantics",
)

return False

@ -83,7 +83,7 @@ def _quantize_weight(float_wt, observer):
torch.qint8,
)
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
elif observer.qscheme == torch.per_channel_affine_float_qparams:
elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
qweight = torch.quantize_per_channel(
float_wt,
wt_scale.to(torch.float),

@ -64,7 +64,9 @@ def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool:


def _is_float_qparams(qscheme: "torch.qscheme") -> bool:
return qscheme == torch.per_channel_affine_float_qparams
return qscheme in [
torch.per_channel_affine_float_qparams,
]


class FakeQuantizeBase(ABC, Module):

@ -227,7 +227,7 @@ def is_getattr_tensor_metadata_node(node):
return (
node.op == "call_function"
and node.target == getattr
and node.args[1] == "shape"
and node.args[1] in ["shape"]
)



@ -388,7 +388,7 @@ class UniformQuantizationObserverBase(ObserverBase):
)
else:
zero_point = zero_point.new_full(zero_point.size(), 128)
elif self.dtype == torch.uint16:
elif self.dtype in [torch.uint16]:
zero_point = zero_point.new_full(zero_point.size(), 2**15)
elif self.qscheme == torch.per_channel_affine_float_qparams:
scale = (max_val - min_val) / float(quant_max - quant_min)

@ -237,7 +237,7 @@ def _add_observer_(

for name, child in module.named_children():
# TODO remove Dropout special after codebase stable
if type_before_parametrizations(child) is nn.Dropout:
if type_before_parametrizations(child) in [nn.Dropout]:
continue
elif issubclass(
type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)

@ -598,7 +598,7 @@ class X86InductorQuantizer(Quantizer):
_annotate_nodes_not_quantize(linear_node)
return
input_qspec_map = {}
assert linear_node.target == torch.ops.aten.linear.default
assert linear_node.target in (torch.ops.aten.linear.default,)
has_bias = len(linear_node.args) == 3
input_index = 0
weight_index = 1
@ -1436,9 +1436,8 @@ class X86InductorQuantizer(Quantizer):
"Linear partition cannot have more than one output node"
)
linear_node = partition.output_nodes[0]
if (
linear_node.op != "call_function"
or linear_node.target != torch.ops.aten.linear.default
if linear_node.op != "call_function" or linear_node.target not in (
torch.ops.aten.linear.default,
):
raise ValueError(f"{linear_node} is not an aten linear operator")
# skip annotation if it is already annotated
@ -1468,9 +1467,8 @@ class X86InductorQuantizer(Quantizer):
linear_node, unary_node = self._get_output_nodes_of_partitions(
[linear_partition, unary_partition]
)
if (
linear_node.op != "call_function"
or linear_node.target != torch.ops.aten.linear.default
if linear_node.op != "call_function" or linear_node.target not in (
torch.ops.aten.linear.default,
):
continue
if _skip_annotate([unary_node, linear_node], filter_fn):

@ -501,9 +501,9 @@ def calculate_qmin_qmax(
quant_min, quant_max = 0, 255
elif dtype in [torch.qint32, torch.int32]:
quant_min, quant_max = -1 * (2**31), (2**31) - 1
elif dtype == torch.uint16:
elif dtype in [torch.uint16]:
quant_min, quant_max = 0, 2**16 - 1
elif dtype == torch.int16:
elif dtype in [torch.int16]:
quant_min, quant_max = -(2**15), 2**15 - 1
else:
quant_min, quant_max = 0, 15

@ -624,7 +624,7 @@ void broadcast(
")");
ncclComm_t comm = comms[i];
NCCL_CHECK(ncclBcast(
tensors[i].mutable_data_ptr(),
tensors[i].data_ptr(),
numel,
data_type,
0,
@ -669,9 +669,9 @@ void reduce(

ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclReduce(
inputs[i].const_data_ptr(),
inputs[i].data_ptr(),
static_cast<std::remove_cv_t<decltype(i)>>(root) == i
? output.mutable_data_ptr()
? output.data_ptr()
: nullptr,
count,
data_type,
@ -723,8 +723,8 @@ void all_reduce(

ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclAllReduce(
inputs[i].const_data_ptr(),
outputs[i].mutable_data_ptr(),
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
@ -765,8 +765,8 @@ void reduce_scatter(

ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclReduceScatter(
inputs[i].const_data_ptr(),
outputs[i].mutable_data_ptr(),
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
@ -807,18 +807,18 @@ void all_gather(
ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
NCCL_CHECK(ncclAllGather(
inputs[i].const_data_ptr(),
outputs[i].mutable_data_ptr(),
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_comm(comm),
stream));
#else
NCCL_CHECK(ncclAllGather(
inputs[i].const_data_ptr(),
inputs[i].data_ptr(),
count,
data_type,
outputs[i].mutable_data_ptr(),
outputs[i].data_ptr(),
to_nccl_comm(comm),
stream));
#endif
@ -843,7 +843,7 @@ void all2all_single_equal_split(
size_t count = input.numel() / size;
[[maybe_unused]] size_t rankdiff = input.nbytes() / size;
const auto* sendbuff = reinterpret_cast<const char*>(input.const_data_ptr());
auto* recvbuff = reinterpret_cast<char*>(output.mutable_data_ptr());
auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM) || defined(NCCL_ALLTOALL_SUPPORTED)
// NCCL_ALLTOALL_SUPPORTED is used so NCCL can differentiate send/recv
@ -964,7 +964,7 @@ void all2all(

if (_nccl_should_send_recv(input.numel())) {
NCCL_CHECK(ncclSend(
input.const_data_ptr(),
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
r,
@ -973,7 +973,7 @@ void all2all(
}
if (_nccl_should_send_recv(output.numel())) {
NCCL_CHECK(ncclRecv(
output.mutable_data_ptr(),
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
r,
@ -1005,7 +1005,7 @@ void send(
using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclSend(
input.const_data_ptr(),
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
@ -1014,7 +1014,7 @@ void send(
#else
NCCL_CHECK_TIMEOUT(
ncclSend(
input.const_data_ptr(),
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
@ -1041,7 +1041,7 @@ void recv(
using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclRecv(
output.mutable_data_ptr(),
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
@ -1050,7 +1050,7 @@ void recv(
#else
NCCL_CHECK_TIMEOUT(
ncclRecv(
output.mutable_data_ptr(),
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
@ -1091,7 +1091,7 @@ void gather(
if (cur_rank == root) {
for (const auto r : c10::irange(numranks)) {
if (r != root) {
auto* recvbuff = reinterpret_cast<char*>(outputs[r].mutable_data_ptr());
auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
} else {
// on its own rank, simply copy from the input
@ -1152,7 +1152,7 @@ void scatter(
} else {
size_t recv_count = outputs.numel();
auto recv_type = to_nccl_data_type(outputs);
auto* recvbuff = reinterpret_cast<char*>(outputs.mutable_data_ptr());
auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
}
#ifndef NCCL_HAS_COMM_NONBLOCKING

@ -420,7 +420,7 @@ class _CudaKernel:
# navi, CDNA1-CDNA3 allows a max of 64KB shared memory
# CDNA4 allows a max of 160KB shared memory
max_shared_mem = (
65536 if device_props.gcnArchName != "gfx950" else 160 * 1024
65536 if device_props.gcnArchName not in ["gfx950"] else 160 * 1024
)
else:
max_shared_mem = getattr(

@ -171,7 +171,3 @@ def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
):
return x
return x.to(dtype)


def is_bw() -> bool:
return torch._C._current_graph_task_id() != -1

@ -31,7 +31,6 @@ from ._fsdp_common import (
compiled_autograd_enabled,
FSDPMeshInfo,
HSDPMeshInfo,
is_bw,
TrainingState,
)
from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState
@ -273,7 +272,7 @@ class FSDPParamGroup:
the staging buffers for collective comms.
"""
assert isinstance(
self._all_gather_comm, (DefaultAllGather | ProcessGroupAllocAllGather)
self._all_gather_comm, (DefaultAllGather, ProcessGroupAllocAllGather)
), (
"cannot call set_allocate_memory_from_process_group() "
f"when all gather comm is custom: {self._all_gather_comm.__class__.__name__}"
@ -286,7 +285,7 @@ class FSDPParamGroup:

assert isinstance(
self._reduce_scatter_comm,
(DefaultReduceScatter | ProcessGroupAllocReduceScatter),
(DefaultReduceScatter, ProcessGroupAllocReduceScatter),
), (
"cannot call set_allocate_memory_from_process_group() "
f"when reduce scatter comm is custom: {self._reduce_scatter_comm.__class__.__name__}"
@ -446,15 +445,8 @@ class FSDPParamGroup:
if not compiled_autograd_enabled():
logger.debug("%s", self._with_fqn("FSDP::post_forward"))
with record_function(self._with_fqn("FSDP::post_forward")):
if not compiled_autograd_enabled():
# for AC(fully_shard(model)), AC runs fsdp's _pre_forward
# it shouldn't change post_forward_order
if not is_bw():
self.reshard()
self._record_post_forward()
else:
self.reshard()
self._record_post_forward()
self.reshard()
self._record_post_forward()
self._training_state = TrainingState.IDLE
return output


@ -528,10 +528,9 @@ class DTensor(torch.Tensor):

placements = list(placements)
for i, placement in enumerate(placements):
if placement.is_partial() and self.placements[i] != placement:
if placement.is_partial():
raise RuntimeError(
f"Can not redistribute from {self.placements[i]} to {placement}, "
"redistributing to Partial is for internal use only!"
"Can not redistribute to Partial, redistributing to Partial is for internal use only!"
)
elif isinstance(placement, Shard) and placement.dim < 0:
# normalize shard dim to be positive

@ -34,7 +34,7 @@ class LayerNorm(nn.Module):
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format != torch.contiguous_format:
if self.data_format not in [torch.contiguous_format]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)


@ -241,11 +241,8 @@ def _check_input_constraints_pre_hook(self, args, kwargs):
_check_inputs_match(args, kwargs, self._in_spec)
return

# NOTE: for some reason, Dynamo is tracing into this, we should see why and
# put compile at the right place. Until then, we can skip the input
# constraint checks.
if not torch.compiler.is_dynamo_compiling():
_check_input_constraints_for_module(self, args, kwargs)
# NOTE: this call is Dynamo disabled, as it used to be
_check_input_constraints_for_module(self, args, kwargs)


def _unlift_inputs_as_getattr(

@ -65,76 +65,3 @@ def get_node_context(node, num_nodes=2) -> str:
break
cur = cur.prev
return "\n".join(node_contexts[::-1])


def map_recorded_events_to_aten_ops_with_stack_trace(graph_module, traced_data):
"""
Maps recorded profiler events to their corresponding aten operations and adds stack traces.

Args:
graph_module: The FX GraphModule
traced_data: Json of profiler events from Chrome trace

Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
trace_events = traced_data.get("traceEvents", [])

# Create a mapping from node name to node for easy lookup
node_map = {node.name: node for node in graph_module.graph.nodes}


# Find aten operation events
aten_events = [e for e in trace_events if e.get("cat") == "cpu_op"]

# Map recorded events to aten ops and add stack traces
event_mapping = {}

for recorded_event in trace_events:
if (recorded_event.get("cat") in ["cpu_op"] and
recorded_event.get("name", "").startswith("## ") and
recorded_event.get("name", "").endswith(" ##")):
# Extract node name from "## node_name ##"
node_name = recorded_event["name"][3:-3]  # Remove "## " and " ##"

if node_name in node_map:
node = node_map[node_name]

# Find corresponding aten operations within this recorded event's time window
recorded_start = recorded_event["ts"]
recorded_end = recorded_start + recorded_event["dur"]

# Find aten ops that fall within this time window
corresponding_aten_ops = []
for aten_event in aten_events:
aten_start = aten_event["ts"]
aten_end = aten_start + aten_event["dur"]

# Check if aten event overlaps with recorded event
if (aten_start >= recorded_start and aten_start <= recorded_end) or \
(aten_end >= recorded_start and aten_end <= recorded_end) or \
(aten_start <= recorded_start and aten_end >= recorded_end):
corresponding_aten_ops.append(aten_event)

# Add stack trace to recorded event and aten ops
stack_trace = node.meta.get("stack_trace", "No stack trace available")

# Add stack trace to the recorded event
if "args" not in recorded_event:
recorded_event["args"] = {}
recorded_event["args"]["stack_trace"] = stack_trace

# Add stack trace to corresponding aten ops
for aten_op in corresponding_aten_ops:
if "args" not in aten_op:
aten_op["args"] = {}
aten_op["args"]["stack_trace"] = stack_trace

event_mapping[node_name] = {
"recorded_event": recorded_event,
"aten_operations": corresponding_aten_ops,
"node": node,
"stack_trace": stack_trace
}

return event_mapping

@ -440,7 +440,6 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -778,13 +777,8 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in ("call_function", "call_method", "call_module")
if do_record:
body.append(f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {node.name} ##'); _rf_{node.name}.__enter__()\n")
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")

if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -1215,9 +1209,6 @@ class Graph:
name = self._graph_namespace.create_name(candidate, None)
n = Node(self, name, op, target, args, kwargs, type_expr)

# print(name)
# breakpoint()

if (
self.owning_module is not None
and getattr(self.owning_module, "_create_node_hooks", None) is not None
@ -1642,7 +1633,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1710,7 +1700,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)

def _python_code(
@ -1723,7 +1712,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1734,7 +1722,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)

def __str__(self) -> str:

@ -161,7 +161,6 @@ class Interpreter:
delay=0,
)

print("running inside interpreter")
for node in self.graph.nodes:
pbar.update(1)
if node in self.env:

@ -95,27 +95,6 @@ bits4x2
bits8
bits16

# torch/headeronly/core/DeviceType.h
DeviceType
kCPU
kCUDA
kHIP
kFPGA
kMAIA
kXLA
kMPS
kMeta
kVulkan
kMetal
kXPU
kHPU
kVE
kLazy
kIPU
kMTIA
kPrivateUse1
COMPILE_TIME_MAX_DEVICE_TYPES

# torch/headeronly/core/ScalarType.h
NumScalarTypes
ScalarType

@ -1,125 +0,0 @@
// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
// If you modify me, keep me synchronized with that file.

#include <torch/headeronly/macros/Export.h>

#include <cstddef>
#include <cstdint>
#include <functional>

namespace c10 {

// These contains all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends except PrivateUse2 and PrivateUse3
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
_(CPU, extra) \
_(CUDA, extra) \
_(HIP, extra) \
_(XLA, extra) \
_(MPS, extra) \
_(IPU, extra) \
_(XPU, extra) \
_(HPU, extra) \
_(VE, extra) \
_(Lazy, extra) \
_(Meta, extra) \
_(MTIA, extra) \
_(PrivateUse1, extra)

enum class DeviceType : int8_t {
CPU = 0,
CUDA = 1, // CUDA.
MKLDNN = 2, // Reserved for explicit MKLDNN
OPENGL = 3, // OpenGL
OPENCL = 4, // OpenCL
IDEEP = 5, // IDEEP.
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
MAIA = 8, // ONNX Runtime / Microsoft
XLA = 9, // XLA / TPU
Vulkan = 10, // Vulkan
Metal = 11, // Metal
XPU = 12, // XPU
MPS = 13, // MPS
Meta = 14, // Meta (tensors with no data)
HPU = 15, // HPU / HABANA
VE = 16, // SX-Aurora / NEC
Lazy = 17, // Lazy Tensors
IPU = 18, // Graphcore IPU
MTIA = 19, // Meta training and inference devices
PrivateUse1 = 20, // PrivateUse1 device
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in c10/core/DeviceType.cpp
// - Change the number below
COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};

constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMAIA = DeviceType::MAIA;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMTIA = DeviceType::MTIA;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;

// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);

static_assert(
COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
"for this constant to reflect the actual number of DeviceTypes we support "
"in PyTorch; it's important that this number is not too large as we "
"use this to allocate stack arrays in some places in our code. If you "
"are indeed just adding the 20th device type, feel free to change "
"the check to 32; but if you are adding some sort of extensible device "
"types registration, please be aware that you are affecting code that "
"this number is small. Try auditing uses of this constant.");

} // namespace c10

namespace std {
template <>
struct hash<c10::DeviceType> {
std::size_t operator()(c10::DeviceType k) const {
return std::hash<int>()(static_cast<int>(k));
}
};
} // namespace std

namespace torch::headeronly {
using c10::COMPILE_TIME_MAX_DEVICE_TYPES;
using c10::DeviceType;
using c10::kCPU;
using c10::kCUDA;
using c10::kFPGA;
using c10::kHIP;
using c10::kHPU;
using c10::kIPU;
using c10::kLazy;
using c10::kMAIA;
using c10::kMeta;
using c10::kMetal;
using c10::kMPS;
using c10::kMTIA;
using c10::kPrivateUse1;
using c10::kVE;
using c10::kVulkan;
using c10::kXLA;
using c10::kXPU;
} // namespace torch::headeronly
@ -427,7 +427,7 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
return torch.tensor(-torch.inf, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8:
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
elif op_name == "logsumexp":
elif op_name in {"logsumexp"}:
if torch.is_floating_point(input):
return torch.tensor(-torch.inf, dtype=dtype, device=device)
elif torch.is_complex(input):

@ -3272,13 +3272,11 @@ def gaussian_nll_loss(
if input.size()[:-1] == var.size():
var = torch.unsqueeze(var, -1)

# This checks if the var is broadcastable to the input and there is only one mismatched dimension.
# This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1.
# This is also a homoscedastic case.
# e.g. input.size = (10, 2, 3), var.size = (10, 2, 1)
# or input.size = (4, 3, 32, 32), var.size = (4, 1, 32, 32)
elif (
input.ndim == var.ndim
and sum(y for x, y in zip(input.size(), var.size()) if x != y) == 1
input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1
): # Heteroscedastic case
pass


@ -432,18 +432,15 @@ def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
('reduction_none', {'reduction': 'none'}),
('homoscedastic', {'homoscedastic': True}),
]

module_inputs = []
for desc, constructor_kwargs in cases:
homoscedastic = constructor_kwargs.pop('homoscedastic', False)
var_input = make_input(1, 3).abs() if homoscedastic else make_input(4, 1).abs()
module_inputs.append(
ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
forward_input=FunctionInput(make_input(4, 3),
make_target(4, 3),
var_input),
forward_input=FunctionInput(make_input(3),
make_target(3),
make_input(1).abs()),
desc=desc,
reference_fn=no_batch_dim_reference_fn)
)

@ -765,7 +765,7 @@ class QuantizationTestCase(TestCase):
and not isinstance(module, _FusedModule)
):
for child in module.children():
if type(child) is nn.Dropout:
if type(child) in [nn.Dropout]:
continue
self.checkObservers(
child, propagate_qconfig_list, prepare_custom_config_dict

@ -192,7 +192,7 @@ class Trainer:
self.hybrid_module = HybridModel(
self.remote_em_rref,
self.remote_net_rref,
self.trainer_group if ddp_mode == DdpMode.INSIDE else None,
self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
)
self.ddp_params, self.non_ddp_params = (
self.hybrid_module.ddp_params,

@ -707,7 +707,7 @@ class DistributedTest:
self.assertNotEqual(args.get("dtype", ""), "")

per_coll_meta[collname].append(args)
if collname == "wait":
if collname in {"wait"}:
continue

self.assertEqual(args["Process Group Description"], "default_pg")
@ -7029,7 +7029,7 @@ class DistributedTest:
self.assertNotEqual(attrs.get("dtype", ""), "")

per_coll_meta[collname].append(attrs)
if collname == "wait":
if collname in {"wait"}:
continue

self.assertEqual(attrs["pg_name"], "0") # yes this is a string

@ -125,7 +125,7 @@ class DebugMode(TorchDispatchMode):
_get_current_dispatch_mode(), FakeTensorMode
):
if self.record_faketensor:
if func != torch.ops.prim.device.default:
if func not in {torch.ops.prim.device.default}:
self.operators.append((func, args, kwargs, self.call_depth + 1))
elif len(types) == 0:
if self.record_realtensor:

@ -103,7 +103,7 @@ class Capture:
def __getattr__(self, attrname):
if attrname == "kwarg" or attrname == "kwargs":
raise RuntimeError("no kwargs!")
if attrname == "__deepcopy__":
if attrname in ["__deepcopy__"]:
raise AttributeError
result = CaptureGetAttr(self, attrname, ctx=self.ctx)
return result

@ -783,7 +783,7 @@ class _FlopCounterMode(TorchDispatchMode):
return result, flop_counts

def _handle_higher_order_ops(self, func, types, args, kwargs):
if func is not torch.ops.higher_order.cond:
if func not in {torch.ops.higher_order.cond, }:
return NotImplemented

# The flop counter for cond counts the upper bound of flops.