Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-22 14:15:01 +08:00)

Compare commits: ciflow/tru ... annotate_f (18 commits)
| SHA1 |
|---|
| 98826fd37b |
| 585b9dbb5e |
| d795fb225a |
| 7df9aca529 |
| d4a713cd9c |
| 5daef30b26 |
| 6dedd34c31 |
| a303d6dda9 |
| 7669ac9402 |
| 86fd4fc23e |
| 99097b6d89 |
| a214371008 |
| 7d87d7052e |
| 1a34ff4e04 |
| fe5ccb1a74 |
| 85586d7efc |
| e1d71a6b35 |
| d61a9b88cf |
.github/scripts/generate_binary_build_matrix.py (vendored): 12 changed lines
@@ -241,7 +241,11 @@ def generate_libtorch_matrix(
         arches += CUDA_ARCHES
         arches += ROCM_ARCHES
     elif os == "windows":
-        arches += CUDA_ARCHES
+        # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
+        # in 2.10
+        windows_cuda_arches = CUDA_ARCHES.copy()
+        windows_cuda_arches.remove("12.9")
+        arches += windows_cuda_arches
     if libtorch_variants is None:
         libtorch_variants = [
             "shared-with-deps",

@@ -305,7 +309,11 @@ def generate_wheels_matrix(
     if os == "linux":
         arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES
     elif os == "windows":
-        arches += CUDA_ARCHES + XPU_ARCHES
+        # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
+        # in 2.10
+        windows_cuda_arches = CUDA_ARCHES.copy()
+        windows_cuda_arches.remove("12.9")
+        arches += windows_cuda_arches + XPU_ARCHES
     elif os == "linux-aarch64":
         # Separate new if as the CPU type is different and
         # uses different build/test scripts
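The effect of the change above is easy to see in isolation: Windows keeps every CUDA architecture except 12.9, which stays Linux-only until the cleanup planned for 2.10. A minimal sketch, assuming illustrative contents for CUDA_ARCHES and ROCM_ARCHES rather than the repository's actual lists:

CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"]  # illustrative values only
ROCM_ARCHES = ["6.4"]                            # illustrative values only


def libtorch_arches(os: str) -> list[str]:
    arches: list[str] = []
    if os == "linux":
        arches += CUDA_ARCHES + ROCM_ARCHES
    elif os == "windows":
        # Mirror of the new logic: drop CUDA 12.9 from the Windows matrix.
        windows_cuda_arches = CUDA_ARCHES.copy()
        windows_cuda_arches.remove("12.9")
        arches += windows_cuda_arches
    return arches


print(libtorch_arches("linux"))    # ['12.6', '12.8', '12.9', '13.0', '6.4']
print(libtorch_arches("windows"))  # ['12.6', '12.8', '13.0']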
.github/workflows/_linux-build.yml (vendored): 2 changed lines

@@ -37,7 +37,7 @@ on:
      runner:
        required: false
        type: string
-        default: "linux.2xlarge"
+        default: "linux.c7i.2xlarge"
        description: |
          Label of the runner this job should run on.
      test-matrix:
.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml (generated, vendored): 250 lines removed. The hunk below drops the libtorch-cuda12_9-shared-with-deps-debug build, test, and upload jobs from the generated nightly workflow.
@ -788,256 +788,6 @@ jobs:
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
libtorch-cuda12_9-shared-with-deps-debug-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: debug
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Build PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
|
||||
- uses: actions/upload-artifact@v4.4.0
|
||||
if: always()
|
||||
with:
|
||||
name: libtorch-cuda12_9-shared-with-deps-debug
|
||||
retention-days: 14
|
||||
if-no-files-found: error
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
|
||||
libtorch-cuda12_9-shared-with-deps-debug-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-cuda12_9-shared-with-deps-debug-build
|
||||
- get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: debug
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: libtorch-cuda12_9-shared-with-deps-debug
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Test PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
libtorch-cuda12_9-shared-with-deps-debug-upload: # Uploading
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
needs: libtorch-cuda12_9-shared-with-deps-debug-test
|
||||
with:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
LIBTORCH_CONFIG: debug
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
build_name: libtorch-cuda12_9-shared-with-deps-debug
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
libtorch-cuda13_0-shared-with-deps-debug-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
|
.github/workflows/generated-windows-binary-libtorch-release-nightly.yml (generated, vendored): 250 lines removed. The hunk below drops the corresponding libtorch-cuda12_9-shared-with-deps-release build, test, and upload jobs.
@ -788,256 +788,6 @@ jobs:
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
libtorch-cuda12_9-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Build PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
|
||||
- uses: actions/upload-artifact@v4.4.0
|
||||
if: always()
|
||||
with:
|
||||
name: libtorch-cuda12_9-shared-with-deps-release
|
||||
retention-days: 14
|
||||
if-no-files-found: error
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
|
||||
libtorch-cuda12_9-shared-with-deps-release-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-cuda12_9-shared-with-deps-release-build
|
||||
- get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: libtorch-cuda12_9-shared-with-deps-release
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Test PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
needs: libtorch-cuda12_9-shared-with-deps-release-test
|
||||
with:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu129
|
||||
GPU_ARCH_VERSION: "12.9"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
build_name: libtorch-cuda12_9-shared-with-deps-release
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
libtorch-cuda13_0-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
|
.github/workflows/generated-windows-binary-wheel-nightly.yml (generated, vendored): 1666 changed lines. File diff suppressed because it is too large.
.github/workflows/lint.yml (vendored): 4 changed lines

@@ -118,9 +118,9 @@ jobs:
            CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
            echo "Running all other linters"
            if [ "$CHANGED_FILES" = '*' ]; then
-              ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
+              ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
            else
-              ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
+              ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
            fi

  quick-checks:
.lintrunner.toml:

@@ -209,6 +209,46 @@ command = [
     '@{{PATHSFILE}}'
 ]

+
+[[linter]]
+code = 'PYREFLY'
+include_patterns = [
+    'torch/**/*.py',
+    'torch/**/*.pyi',
+    'torchgen/**/*.py',
+    'torchgen/**/*.pyi',
+    'functorch/**/*.py',
+    'functorch/**/*.pyi',
+]
+exclude_patterns = []
+command = [
+    'python3',
+    'tools/linter/adapters/pyrefly_linter.py',
+    '--config=pyrefly.toml',
+]
+init_command = [
+    'python3',
+    'tools/linter/adapters/pip_init.py',
+    '--dry-run={{DRYRUN}}',
+    'numpy==2.1.0 ; python_version >= "3.12"',
+    'expecttest==0.3.0',
+    'pyrefly==0.36.2',
+    'sympy==1.13.3',
+    'types-requests==2.27.25',
+    'types-pyyaml==6.0.2',
+    'types-tabulate==0.8.8',
+    'types-protobuf==5.29.1.20250403',
+    'types-setuptools==79.0.0.20250422',
+    'types-jinja2==2.11.9',
+    'types-colorama==0.4.6',
+    'filelock==3.18.0',
+    'junitparser==2.1.1',
+    'rich==14.1.0',
+    'optree==0.17.0',
+    'types-openpyxl==3.1.5.20250919',
+    'types-python-dateutil==2.9.0.20251008'
+]
+
 [[linter]]
 code = 'CLANGTIDY'
 include_patterns = [
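For readers unfamiliar with lintrunner adapters, the command above points at a small Python wrapper rather than at pyrefly itself. Below is a hypothetical, simplified sketch of what such a wrapper could look like; the real tools/linter/adapters/pyrefly_linter.py, its flags, and its output format may differ, and the pyrefly invocation shown is an assumption.

# Hypothetical sketch only: a minimal lintrunner-style wrapper around a type checker.
import subprocess
import sys


def main() -> int:
    # lintrunner invokes the adapter roughly as:
    #   python3 pyrefly_linter.py --config=pyrefly.toml [paths...]
    config = "pyrefly.toml"
    paths = []
    for arg in sys.argv[1:]:
        if arg.startswith("--config="):
            config = arg.split("=", 1)[1]
        else:
            paths.append(arg)

    # Assumed CLI: `pyrefly check` honoring a TOML config; adjust to the real tool.
    proc = subprocess.run(
        ["pyrefly", "check", f"--config={config}", *paths],
        capture_output=True,
        text=True,
    )
    sys.stdout.write(proc.stdout)
    sys.stderr.write(proc.stderr)
    return proc.returncode


if __name__ == "__main__":
    raise SystemExit(main())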
aten/src/ATen/autocast_mode.cpp:

@@ -2,6 +2,7 @@

 #include <mutex>
 #include <ATen/CachedTensorUtils.h>
+#include <c10/core/GradMode.h>
 #include <c10/util/flat_hash_map.h>

 namespace at::autocast {

@@ -36,10 +37,29 @@ namespace {
 using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 using val_type = std::tuple<weakref_type, Tensor>;

-ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
-  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
-  return cached_casts;
+// We maintain separate caches for gradient-enabled and gradient-disabled modes.
+// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
+// are not incorrectly reused in gradient-enabled contexts.
+// This fixes issue #158232 while maintaining optimal performance for both modes.
+static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
+  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
+  return cached_casts_grad_enabled;
+}
+
+static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
+  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
+  return cached_casts_grad_disabled;
+}
+
+// Helper function to get the appropriate cache based on current gradient mode.
+// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
+// preventing incorrect cache hits when gradient mode changes.
+static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
+  return at::GradMode::is_enabled() ?
+    get_cached_casts_grad_enabled() :
+    get_cached_casts_grad_disabled();
 }

 std::mutex cached_casts_mutex;

@@ -86,7 +106,9 @@ thread_local bool cache_enabled = true;

 void clear_cache() {
   const std::lock_guard<std::mutex> lock(cached_casts_mutex);
-  get_cached_casts().clear();
+  // Clear both caches to ensure consistent behavior regardless of current gradient mode
+  get_cached_casts_grad_enabled().clear();
+  get_cached_casts_grad_disabled().clear();
 }

 int increment_nesting() {

@@ -121,6 +143,11 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
   if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
     // See cached_casts declaration above for detailed strategy.
+    //
+    // We maintain separate caches for gradient-enabled and gradient-disabled modes
+    // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
+    // with torch.autocast(), while maintaining optimal performance for both training and inference.
+    // This fixes issue #158232 without any performance regression.
     bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                           arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                           arg.is_leaf() && !arg.is_view() && cache_enabled &&
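The failure mode these split caches address is visible from Python, so a short sketch is enough to show it. This mirrors the regression test added near the end of this page (test_autocast_nograd_caching_issue_158232): before the change, the weight cast performed under no_grad() was cached and then reused in the grad-enabled region.

import torch

model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)

with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
    with torch.no_grad():
        _ = model(inp)   # weight is cast to bf16 while gradients are disabled
    out = model(inp)     # the same cast is needed again, now with gradients enabled
    # With a single shared cache, the grad-less cached cast was reused here and the
    # backward below failed with "element 0 of tensors does not require grad".
    out.sum().backward()

print(model.weight.grad is not None)  # True once the caches are split by grad mode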
Scaled GEMM dispatch (MXFP4 support):

@@ -1759,6 +1759,7 @@ enum class ScaledGemmImplementation {
   MXFP8_MXFP8 = 6,
   NVFP4_NVFP4 = 7,
   NVFP4_NVFP4_SINGLE_SCALE = 8,
+  MXFP4_MXFP4 = 9,
 };

 /**

@@ -1955,10 +1956,39 @@ bool check_mxfp8_recipe(c10::ScalarType type_a,
   return true;
 }

+/**
+ * Both inputs must be fp4
+ * A, B must have 1 scale each, {Blockwise_1x32, e8m0}
+ */
+bool check_mxfp4_recipe(c10::ScalarType type_a,
+                        std::vector<ScalingType>& recipe_a,
+                        ArrayRef<Tensor>& scales_a,
+                        c10::ScalarType type_b,
+                        std::vector<ScalingType>& recipe_b,
+                        ArrayRef<Tensor>& scales_b) {
+  // both types must be fp4
+  if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
+    return false;
+  }
+
+  // 1 scales, 1 recipes for each input
+  if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
+    return false;
+  }
+
+  // Need {Blockwise_1x32, e8m0} for A & B
+  if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
+  if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
+  if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
+  if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
+
+  return true;
+}
+
 using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
 using namespace std::placeholders;

-std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> scale_kernel_dispatch = {{
+std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> scale_kernel_dispatch = {{
   { "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
   { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
   { "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),

@@ -1969,7 +1999,8 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
     ScaledGemmImplementation::BLOCK_1x128_1x128},
   { "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
   { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
-  { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
+  { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
+  { "mxfp4_mxfp4", check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};

 Tensor&
 _scaled_tensorwise_tensorwise(

@@ -2187,15 +2218,22 @@ _scaled_mxfp8_mxfp8(
   TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
       mat_a.scalar_type(), mat_b.scalar_type());

+#ifdef USE_ROCM
+  auto scale_a_elems = ceil_div<int64_t>(mat_a.size(0), 32) * mat_a.size(1);
+  auto scale_b_elems = ceil_div<int64_t>(mat_b.size(1), 32) * mat_b.size(0);
+#else
   auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1), 32), 4);
   auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0), 32), 4);
+#endif
   TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
       "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
   TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
       "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());

+#ifndef USE_ROCM
   TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
   TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
+#endif

   TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
       "For Blockwise scaling both scales should be contiguous");

@@ -2225,6 +2263,56 @@ _scaled_mxfp8_mxfp8(
   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
 }

+
+Tensor&
+_scaled_mxfp4_mxfp4(
+    const Tensor& mat_a, const Tensor& mat_b,
+    const Tensor& scale_a, const SwizzleType swizzle_a,
+    const Tensor& scale_b, const SwizzleType swizzle_b,
+    const std::optional<Tensor>& bias,
+    const c10::ScalarType out_dtype,
+    Tensor& out) {
+#ifndef USE_ROCM
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
+#endif
+  // Restrictions:
+  // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
+  TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
+      mat_a.scalar_type(), mat_b.scalar_type());
+
+  auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
+  auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
+  TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
+      "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
+  TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
+      "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
+
+  TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
+      "For Blockwise scaling both scales should be contiguous");
+
+  TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
+
+  auto scaling_choice_a = ScalingType::BlockWise1x32;
+  auto scaling_choice_b = ScalingType::BlockWise1x32;
+
+#if ROCM_VERSION >= 70000
+  TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
+      "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
+
+  TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 &&
+      mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0,
+      "Matrix dimensions must be multiples of 32 for block-wise scaling");
+
+  TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
+      out.scalar_type() == ScalarType::Half,
+      "Block-wise scaling only supports BFloat16 or Half output types");
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
+#endif
+
+  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
+}
+
 Tensor&
 _scaled_nvfp4_nvfp4(
     const Tensor& mat_a, const Tensor& mat_b,

@@ -2468,6 +2556,8 @@ _scaled_mm_cuda_v2_out(
     TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
+  } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
+    return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
   } else {
     TORCH_CHECK_VALUE(false, "Invalid state - found an implementation, but not really");
   }
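A quick sanity check on the scale-shape arithmetic in _scaled_mxfp4_mxfp4 above: each Float4_e2m1fn_x2 element packs two FP4 values, and BlockWise1x32 scaling uses one e8m0 scale per 32 values along the packed dimension. A small sketch reproducing the element-count formulas (the shapes used are illustrative only):

def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def mxfp4_scale_numels(mat_a_size, mat_b_size):
    # Mirrors the checks above: the factor of 2 accounts for the x2 packing of FP4 values.
    scale_a_elems = ceil_div(2 * mat_a_size[0], 32) * mat_a_size[1]
    scale_b_elems = ceil_div(2 * mat_b_size[1], 32) * mat_b_size[0]
    return scale_a_elems, scale_b_elems


# Illustrative packed shapes: A is 128x64, B is 64x128.
print(mxfp4_scale_numels((128, 64), (64, 128)))  # (512, 512)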
Learnable fake-quantize backward (per-tensor affine), bfloat16 support:

@@ -184,15 +184,23 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
       0 & \text{ else }
     \end{cases}
   */
-  float scale_val = scale[0].item<float>();
-  float inv_scale_val = 1.0f / scale_val;
-  int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false);
-
-  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
-  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
-  TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
-  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
-  TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size");
+  bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
+
+  at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
+  at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
+  at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
+  at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;
+
+  float scale_val = scale_[0].item<float>();
+  float inv_scale_val = 1.0f / scale_val;
+  int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point_, quant_min, quant_max, false);
+
+  TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
+  TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
+  TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
+  TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);
+  TORCH_CHECK(X_.numel() == dY_.numel(), "`X` and `dY` are not the same size");
   TORCH_CHECK(
       quant_min <= 0 && quant_max >= 0,
       "`quant_min` should be less than or \

@@ -200,28 +208,28 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
   TORCH_CHECK(
       zero_point_val >= quant_min && zero_point_val <= quant_max,
       "`zero_point` must be between `quant_min` and `quant_max`.");
-  if (X.numel() <= 0) {
+  if (X_.numel() <= 0) {
     return std::make_tuple(X, scale, zero_point);
   }

-  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
-  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
-  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
+  auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
+  auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
+  auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);

   auto iter = TensorIteratorConfig()
     .add_output(dX)
     .add_output(dScale_vec)
     .add_output(dZeroPoint_vec)
-    .add_input(X)
-    .add_input(dY)
+    .add_input(X_)
+    .add_input(dY_)
     .build();

   fake_quant_grad_learnable_tensor_stub(
-    X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);
+    X_.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);

   // The total sums over the scale and zero point gradient vectors are what will be returned in the end.
-  auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device());
-  auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device());
+  auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device());
+  auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device());

   return std::make_tuple(dX, dScale, dZeroPoint);
 }
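The same upcast-then-compute pattern appears again in the Python reference later on this page (_fake_quantize_learnable_per_tensor_affine_grad_reference). A small, self-contained sketch of the idea, independent of the actual operators:

import torch


def upcast_if_bf16(*tensors: torch.Tensor):
    # bfloat16 inputs are promoted to float32 so a float32-only kernel or reference
    # can run; results can be downcast back to bfloat16 afterwards if desired.
    if tensors[0].dtype is torch.bfloat16:
        return tuple(t.to(torch.float32) for t in tensors)
    return tensors


X = torch.randn(4, 3, dtype=torch.bfloat16)
dY = torch.randn(4, 3, dtype=torch.bfloat16)
X_, dY_ = upcast_if_bf16(X, dY)
assert X_.dtype is torch.float32 and dY_.dtype is torch.float32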
Buck build definitions (define_buck_targets):

@@ -1729,10 +1729,8 @@ def define_buck_targets(
             "torch/csrc/jit/backends/backend_debug_info.cpp",
             "torch/csrc/jit/backends/backend_interface.cpp",
         ],
-        compiler_flags = get_pt_compiler_flags() + select({
-            "DEFAULT": [],
-            "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
-        }),
+        compiler_flags = get_pt_compiler_flags(),
+        fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
         # @lint-ignore BUCKLINT link_whole
         link_whole = True,
         linker_flags = get_no_as_needed_linker_flag(),

@@ -2025,9 +2023,6 @@ def define_buck_targets(
             "ovr_config//os:android-x86_64": [
                 "-mssse3",
            ],
-        }) + select({
-            "DEFAULT": [],
-            "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
         }),
         exported_preprocessor_flags = get_aten_preprocessor_flags(),
         exported_deps = [
cmake/External/aotriton.cmake (vendored): 3 changed lines

@@ -244,7 +244,8 @@ if(NOT __AOTRITON_INCLUDED)
     else()
       set(__AOTRITON_SYSTEM_ROCM "${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}")
       list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_SYSTEM_ROCM}" __AOTRITON_RUNTIME_INDEX)
-      if(${__AOTRITON_RUNTIME_INDEX} LESS 0)
+      # Always build aotriton runtime from source on Windows due to lack of pre-built binaries
+      if(${__AOTRITON_RUNTIME_INDEX} LESS 0 OR WIN32)
         message(STATUS "Cannot find AOTriton runtime for ROCM ${__AOTRITON_SYSTEM_ROCM}. \
 Build runtime from source")
         aotriton_build_from_source(ON aotriton_runtime)
pyrefly.toml:

@@ -1,5 +1,7 @@
 # A Pyrefly configuration for PyTorch
+# Based on https://github.com/pytorch/pytorch/blob/main/mypy.ini
+python-version = "3.12"

 project-includes = [
     "torch",
     "caffe2",

@@ -36,6 +38,7 @@ project-excludes = [
     "torch/nn/modules/rnn.py", # only remove when parsing errors are fixed
     "torch/_inductor/codecache.py",
     "torch/distributed/elastic/metrics/__init__.py",
+    "torch/_inductor/fx_passes/bucketing.py",
     # ====
     "benchmarks/instruction_counts/main.py",
     "benchmarks/instruction_counts/definitions/setup.py",
@ -266,6 +266,41 @@ def forward(self, b_parametrizations_buffer_original0, x):
|
||||
compiled_out = compiled_fn(mesh)
|
||||
self.assertEqual(opt_fn, compiled_out)
|
||||
|
||||
def test_get_local_rank_compile(self):
|
||||
mesh = init_device_mesh(
|
||||
self.device_type, (self.world_size,), mesh_dim_names=("dp",)
|
||||
)
|
||||
|
||||
def fn_with_str_arg(x):
|
||||
local_rank = x.device_mesh.get_local_rank("dp")
|
||||
return x * local_rank
|
||||
|
||||
x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
|
||||
ref = fn_with_str_arg(x)
|
||||
|
||||
opt_fn = torch.compile(fn_with_str_arg, backend="aot_eager", fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
def fn_with_int_arg(x):
|
||||
local_rank = x.device_mesh.get_local_rank(0)
|
||||
return x * local_rank
|
||||
|
||||
ref2 = fn_with_int_arg(x)
|
||||
opt_fn2 = torch.compile(fn_with_int_arg, backend="aot_eager", fullgraph=True)
|
||||
res2 = opt_fn2(x)
|
||||
self.assertEqual(res2, ref2)
|
||||
|
||||
def fn_without_arg(x):
|
||||
# will fail if device_mesh.ndim > 1
|
||||
local_rank = x.device_mesh.get_local_rank()
|
||||
return x + local_rank
|
||||
|
||||
ref3 = fn_without_arg(x)
|
||||
opt_fn3 = torch.compile(fn_without_arg, backend="aot_eager", fullgraph=True)
|
||||
res3 = opt_fn3(x)
|
||||
self.assertEqual(res3, ref3)
|
||||
|
||||
def test_fakify_dtensor(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
||||
|
CuTe layout test (test_remap_to_tensor):

@@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
     def test_remap_to_tensor(self):
         """Test the remap_to_tensor method for various scenarios."""
         # Test 1: Consecutive ranks, full world - should return logical groups directly
-        original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
+        original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
         layout1 = _Layout((2, 2), (2, 1))  # row-major 2x2
         result1 = layout1.remap_to_tensor(original_mesh)
         expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
         self.assertEqual(result1, expected1)

         # Test 2: Non-consecutive ranks - should map to actual ranks
-        original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
+        original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
         layout2 = _Layout((2, 2), (2, 1))
         result2 = layout2.remap_to_tensor(original_mesh)
         expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)

@@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
         self.assertEqual(result5, expected5)

         # Test 6: Tensor Cute representation of a 2D mesh
-        original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
+        original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
         layout6 = _Layout((2, 2), (1, 2))  # column-major style
         result6 = layout6.remap_to_tensor(original_mesh)
         expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
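The expected values above follow directly from the CuTe-style index formula index(i, j) = i * stride0 + j * stride1 applied to the flattened rank list. A small illustrative sketch for the 2-D cases shown; the helper below is not the _Layout implementation, just a way to reproduce the expected tensors:

import torch


def remap_2d_sketch(flat_ranks: torch.Tensor, sizes, strides) -> torch.Tensor:
    # Enumerate coordinates in row-major order, compute their CuTe index, gather ranks.
    coords = torch.cartesian_prod(torch.arange(sizes[0]), torch.arange(sizes[1]))
    idx = (coords * torch.tensor(strides)).sum(dim=-1)
    return flat_ranks[idx].reshape(1, *sizes)


flat = torch.tensor([0, 2, 1, 3], dtype=torch.int)
print(remap_2d_sketch(flat, (2, 2), (1, 2)))   # tensor([[[0, 1], [2, 3]]]), the column-major case
print(remap_2d_sketch(torch.tensor([10, 20, 30, 40], dtype=torch.int), (2, 2), (2, 1)))
# tensor([[[10, 20], [30, 40]]]), the row-major case with non-consecutive ranks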
@ -1804,6 +1804,63 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
correct = f(*inputs, **self.get_world_trs())
|
||||
assert same(out, correct), f"{out} va {correct}"
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@unittest.skipIf(not SM80OrLater, "bfloat16")
|
||||
@parametrize("bucket_mode", ["all_custom_ops_multidtype"])
|
||||
def test_all_gather_bucket_multidtype(self, bucket_mode):
|
||||
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
|
||||
# do some unrelated matmuls
|
||||
y = torch.mm(x, w)
|
||||
|
||||
group_name = (
|
||||
torch.distributed.distributed_c10d._get_default_group().group_name
|
||||
)
|
||||
|
||||
ag_0_w = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_0, group_size, group_name
|
||||
)
|
||||
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_w)
|
||||
ag_0_out = ag_0_out * 2
|
||||
|
||||
ag_1_w = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_1, group_size, group_name
|
||||
)
|
||||
|
||||
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_w)
|
||||
|
||||
return y, ag_0_out, ag_1_out
|
||||
|
||||
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
|
||||
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.bfloat16)
|
||||
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
|
||||
inputs = [x, w, ag_0, ag_1]
|
||||
correct = func(*inputs, **self.get_world_trs())
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"bucket_all_gathers_fx": bucket_mode,
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
(
|
||||
FileCheck()
|
||||
.check_count(
|
||||
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
|
||||
count=1,
|
||||
exactly=True,
|
||||
)
|
||||
.run(code)
|
||||
)
|
||||
out = compiled(*inputs, **self.get_world_trs())
|
||||
_, y_ag0, y_ag1 = out
|
||||
assert y_ag0.dtype == ag_0.dtype
|
||||
assert y_ag1.dtype == ag_1.dtype
|
||||
|
||||
assert same(out, correct), f"{out} va {correct}"
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@unittest.skipIf(not SM80OrLater, "bfloat16")
|
||||
@parametrize("bucket_mode", ["all", "all_custom_ops"])
|
||||
|
Async TP test (test_fused_all_gather_matmul):

@@ -294,7 +294,7 @@ class AsyncTPTest(MultiProcContinuousTest):
         not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
     )
     @skip_if_lt_x_gpu(2)
-    @parametrize("gather_dim", [0, 1])
+    @parametrize("gather_dim", [0, 1, 2])
     def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
         self._init_process()

@@ -306,7 +306,10 @@ class AsyncTPTest(MultiProcContinuousTest):
         rank = self.rank

         torch.manual_seed(42 + rank)
-        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
+        A_shard_shape = [BATCH, M, K]
+        A_shard_shape[gather_dim] //= self.world_size
+
+        A_shard = torch.rand(A_shard_shape, device="cuda")
         Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]

         ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(

@@ -523,7 +526,7 @@ class AsyncTPTest(MultiProcContinuousTest):
         BATCH = 8
         M = 64
         N = 16
-        K = 32
+        K = 1024
         group = dist.group.WORLD
         rank = self.rank
@ -57,7 +57,7 @@ def graph_capture(model, inputs, with_export):
|
||||
with ExitStack() as stack:
|
||||
joint_with_descriptors = aot_export_joint_with_descriptors(
|
||||
stack,
|
||||
model,
|
||||
gm,
|
||||
inputs,
|
||||
)
|
||||
return joint_with_descriptors.graph_module
|
||||
@ -922,6 +922,46 @@ class inner_f(torch.nn.Module):
|
||||
in custom_metadata
|
||||
)
|
||||
|
||||
def test_preserve_annotate_function(self):
|
||||
"""Test basic annotate_fn usage"""
|
||||
|
||||
@fx_traceback.annotate_fn({"pp_stage": 1})
|
||||
def example_function(x):
|
||||
return x * x
|
||||
|
||||
class SimpleLinear(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(3, 2)
|
||||
|
||||
def forward(self, x):
|
||||
with fx_traceback.annotate({"pp_stage": 0}):
|
||||
y = self.linear(x)
|
||||
y = example_function(y)
|
||||
return y - 1
|
||||
|
||||
inputs = (torch.randn(4, 3),)
|
||||
model = SimpleLinear()
|
||||
|
||||
for with_export in [True, False]:
|
||||
graph_module = graph_capture(model, inputs, with_export)
|
||||
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
|
||||
self.assertExpectedInline(
|
||||
str(custom_metadata),
|
||||
"""\
|
||||
('call_function', 't', {'pp_stage': 0})
|
||||
('call_function', 'addmm', {'pp_stage': 0})
|
||||
('call_function', 'mul', {'pp_stage': 1})
|
||||
('call_function', 'mul_1', {'pp_stage': 1})
|
||||
('call_function', 'mul_2', {'pp_stage': 1})
|
||||
('call_function', 't_1', {'pp_stage': 0})
|
||||
('call_function', 'mm', {'pp_stage': 0})
|
||||
('call_function', 't_2', {'pp_stage': 0})
|
||||
('call_function', 'sum_1', {'pp_stage': 0})
|
||||
('call_function', 'view', {'pp_stage': 0})
|
||||
('call_function', 't_3', {'pp_stage': 0})""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -589,6 +589,31 @@ class LoopOrderingTest(TestCase):
|
||||
".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
|
||||
).run(code[0])
|
||||
|
||||
@inductor_config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "TRITON",
|
||||
"test_configs.max_mm_configs": 4,
|
||||
}
|
||||
)
|
||||
@skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune")
|
||||
def test_interaction_with_multi_template(self):
|
||||
"""
|
||||
Skip MultiTemplateBuffer during loop reordering
|
||||
"""
|
||||
|
||||
@torch.compile
|
||||
def f(x, y):
|
||||
return (x @ y), x + 1
|
||||
|
||||
N = 2
|
||||
x = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
|
||||
y = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
|
||||
|
||||
out, code = run_and_get_code(f, x, y)
|
||||
# didn't fuse due to small savings
|
||||
FileCheck().check_count("@triton.jit", 2, exactly=True).run(code[0])
|
||||
|
||||
def test_fuse_with_scalar_shared_memory(self):
|
||||
"""
|
||||
Make sure if we can fuse two nodes sharing a scalar before,
|
||||
|
@ -51,11 +51,18 @@ def _fake_quantize_per_tensor_affine_grad_reference(dY, X, scale, zero_point, qu
|
||||
return res.to(dtype)
|
||||
|
||||
# Reference method for the gradients of the fake quantize operator
|
||||
def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max, device):
|
||||
def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max, device, dtype):
|
||||
r"""This method references the following literatures for back propagation on scale and zero point.
|
||||
- https://arxiv.org/pdf/1902.08153.pdf
|
||||
- https://arxiv.org/pdf/1903.08066.pdf
|
||||
"""
|
||||
|
||||
if dtype is torch.bfloat16:
|
||||
dY = dY.to(dtype=torch.float32)
|
||||
X = X.to(dtype=torch.float32)
|
||||
scale = scale.to(dtype=torch.float32)
|
||||
zero_point = zero_point.to(dtype=torch.float32)
|
||||
|
||||
zero_point_rounded = int((zero_point + 0.5).clamp(quant_min, quant_max).item())
|
||||
Xq = torch.round(X * (1.0 / scale) + zero_point_rounded)
|
||||
|
||||
@ -87,6 +94,12 @@ def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero
|
||||
|
||||
grad_scale = (grad_scale * dY).sum().unsqueeze(dim=0)
|
||||
grad_zp = (grad_zp * dY).sum().unsqueeze(dim=0)
|
||||
|
||||
if dtype is torch.bfloat16:
|
||||
grad_X = grad_X.to(torch.bfloat16)
|
||||
grad_scale = grad_scale.to(torch.bfloat16)
|
||||
grad_zp = grad_zp.to(torch.bfloat16)
|
||||
|
||||
return grad_X, grad_scale, grad_zp
|
||||
|
||||
|
||||
@ -467,7 +480,7 @@ class TestFakeQuantizeOps(TestCase):
|
||||
self._test_learnable_forward_per_tensor(
|
||||
X, 'cuda', scale_base, zero_point_base)
|
||||
|
||||
def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base):
|
||||
def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base, dtype=torch.float32):
|
||||
r"""Tests the backward method with additional backprop support for scale and zero point.
|
||||
"""
|
||||
X_base = torch.tensor(X).to(device)
|
||||
@ -475,7 +488,7 @@ class TestFakeQuantizeOps(TestCase):
|
||||
for n_bits in (4, 8):
|
||||
quant_min, quant_max = 0, 2 ** n_bits - 1
|
||||
|
||||
X = X_base.clone().float().to(device)
|
||||
X = X_base.clone().to(device)
|
||||
X.requires_grad_()
|
||||
scale_base = scale_base.to(device)
|
||||
zero_point_base = zero_point_base.to(device)
|
||||
@ -488,7 +501,7 @@ class TestFakeQuantizeOps(TestCase):
|
||||
X, scale, zero_point, quant_min, quant_max, grad_factor).to(device)
|
||||
dout = torch.rand_like(X, dtype=torch.float).to(device)
|
||||
dX, dScale, dZeroPoint = _fake_quantize_learnable_per_tensor_affine_grad_reference(
|
||||
dout, X, scale, zero_point, quant_min, quant_max, device)
|
||||
dout, X, scale, zero_point, quant_min, quant_max, device, dtype)
|
||||
Y_prime.backward(dout)
|
||||
|
||||
expected_dX = dX.to(device).detach()
|
||||
@ -525,17 +538,20 @@ class TestFakeQuantizeOps(TestCase):
|
||||
self._test_learnable_backward_per_tensor(
|
||||
X, 'cpu', scale_base, zero_point_base)
|
||||
|
||||
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
|
||||
elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
|
||||
qparams=hu.qparams(dtypes=torch.quint8)))
|
||||
@unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
|
||||
def test_learnable_backward_per_tensor_cuda(self, X):
|
||||
torch.random.manual_seed(NP_RANDOM_SEED)
|
||||
X, (_, _, _) = X
|
||||
scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
|
||||
zero_point_base = torch.normal(mean=0, std=128, size=(1,))
|
||||
self._test_learnable_backward_per_tensor(
|
||||
X, 'cuda', scale_base, zero_point_base)
|
||||
def test_learnable_backward_per_tensor_cuda(self):
|
||||
# setting seed to avoid increasing tolerance due to cases where
|
||||
# difference in Python vs CPP downcasting causes tensor mismatches
|
||||
# e.g. 27.87704 vs 27.8408 before downcasting, 27.7500 vs 27.8750 after downcasting for Python vs CPP op
|
||||
torch.random.manual_seed(12)
|
||||
x_shape = (2, 1)
|
||||
|
||||
for dtype in [torch.bfloat16, torch.float32]:
|
||||
X_base = torch.randn(x_shape, dtype=dtype, device='cuda')
|
||||
scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100).to(dtype=dtype)
|
||||
zero_point_base = torch.normal(mean=0, std=128, size=(1,)).to(dtype=dtype)
|
||||
self._test_learnable_backward_per_tensor(
|
||||
X_base, 'cuda', scale_base, zero_point_base, dtype)
|
||||
|
||||
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
|
||||
X=hu.tensor(shapes=hu.array_shapes(1, 5,),
|
||||
|
@ -384,6 +384,143 @@ class TestTorchAutocast(TestCase):
|
||||
with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
|
||||
torch.autocast(device_type=dev)
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
def test_autocast_nograd_caching_issue_158232(self):
|
||||
"""
|
||||
Regression test for issue #158232: autocast + no_grad incompatibility
|
||||
|
||||
When torch.no_grad() is nested inside torch.autocast(), the autocast cache
|
||||
must not cache tensors created in the no_grad context, because they lack
|
||||
gradient tracking. If cached, subsequent operations in gradient-enabled mode
|
||||
would incorrectly use the no-gradient cached version.
|
||||
|
||||
Before fix: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
|
||||
After fix: Should work correctly
|
||||
"""
|
||||
model = torch.nn.Linear(2, 2)
|
||||
inp = torch.randn(8, 2)
|
||||
|
||||
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
|
||||
# First forward pass in no_grad context (e.g., shape inference)
|
||||
with torch.no_grad():
|
||||
out1 = model(inp)
|
||||
self.assertFalse(
|
||||
out1.requires_grad, "Output in no_grad should not require grad"
|
||||
)
|
||||
|
||||
# Second forward pass with gradients enabled (e.g., training)
|
||||
out2 = model(inp)
|
||||
self.assertTrue(
|
||||
out2.requires_grad,
|
||||
"Output should require gradients after exiting no_grad",
|
||||
)
|
||||
self.assertIsNotNone(
|
||||
out2.grad_fn, "Output should have grad_fn after exiting no_grad"
|
||||
)
|
||||
|
||||
# Backward pass should work
|
||||
loss = out2.mean()
|
||||
loss.backward()
|
||||
|
||||
# Verify gradients were computed
|
||||
self.assertIsNotNone(model.weight.grad)
|
||||
self.assertIsNotNone(model.bias.grad)
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
def test_autocast_inference_mode_interaction(self):
|
||||
"""
|
||||
Test that autocast works correctly with torch.inference_mode()
|
||||
|
||||
InferenceMode is a stricter version of no_grad that provides additional
|
||||
performance optimizations. Verify it doesn't break with autocast.
|
||||
"""
|
||||
model = torch.nn.Linear(2, 2)
|
||||
inp = torch.randn(8, 2)
|
||||
|
||||
# Test 1: inference_mode inside autocast
|
||||
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
|
||||
with torch.inference_mode():
|
||||
out1 = model(inp)
|
||||
self.assertFalse(out1.requires_grad)
|
||||
self.assertEqual(out1.dtype, torch.bfloat16)
|
||||
|
||||
# After exiting inference_mode, gradients should work
|
||||
out2 = model(inp)
|
||||
self.assertTrue(out2.requires_grad)
|
||||
out2.mean().backward()
|
||||
|
||||
# Test 2: autocast inside inference_mode
|
||||
with torch.inference_mode():
|
||||
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
|
||||
out = model(inp)
|
||||
self.assertFalse(out.requires_grad)
|
||||
self.assertEqual(out.dtype, torch.bfloat16)
|
||||
|
||||
def test_autocast_caching_still_works_with_gradients(self):
|
||||
"""
|
||||
Verify that autocast caching still functions correctly when gradients ARE enabled.
|
||||
|
||||
This test ensures the fix for #158232 didn't break normal caching behavior.
|
||||
We can't directly observe cache hits, but we verify that repeated operations
|
||||
with gradients enabled work correctly.
|
||||
"""
|
||||
model = torch.nn.Linear(2, 2)
|
||||
inp = torch.randn(8, 2)
|
||||
|
||||
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
|
||||
# Multiple forward passes with gradients enabled
|
||||
out1 = model(inp)
|
||||
out2 = model(inp)
|
||||
out3 = model(inp)
|
||||
|
||||
# All should have gradients
|
||||
self.assertTrue(out1.requires_grad)
|
||||
self.assertTrue(out2.requires_grad)
|
||||
self.assertTrue(out3.requires_grad)
|
||||
|
||||
# All should have grad_fn
|
||||
self.assertIsNotNone(out1.grad_fn)
|
||||
self.assertIsNotNone(out2.grad_fn)
|
||||
self.assertIsNotNone(out3.grad_fn)
|
||||
|
||||
# Backward should work on all
|
||||
out1.mean().backward(retain_graph=True)
|
||||
out2.mean().backward(retain_graph=True)
|
||||
out3.mean().backward()
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
def test_autocast_mixed_grad_contexts(self):
|
||||
"""
|
||||
Test complex nesting of gradient contexts within autocast.
|
||||
|
||||
This ensures the gradient mode check works correctly across
|
||||
multiple transitions between gradient-enabled and disabled states.
|
||||
"""
|
||||
model = torch.nn.Linear(2, 2)
|
||||
inp = torch.randn(8, 2)
|
||||
|
||||
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
|
||||
# Pass 1: no_grad
|
||||
with torch.no_grad():
|
||||
out1 = model(inp)
|
||||
self.assertFalse(out1.requires_grad)
|
||||
|
||||
# Pass 2: gradients enabled
|
||||
out2 = model(inp)
|
||||
self.assertTrue(out2.requires_grad)
|
||||
|
||||
# Pass 3: no_grad again
|
||||
with torch.no_grad():
|
||||
out3 = model(inp)
|
||||
self.assertFalse(out3.requires_grad)
|
||||
|
||||
# Pass 4: gradients enabled again
|
||||
out4 = model(inp)
|
||||
self.assertTrue(out4.requires_grad)
|
||||
|
||||
# Backward on gradient-enabled outputs
|
||||
(out2.mean() + out4.mean()).backward()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -265,6 +265,12 @@ class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
)
class ParallelForkServerPerfTest(TestCase):

    @unittest.skipIf(
        sys.version_info >= (3, 13, 8),
        "Python 3.13.8+ changed forkserver module caching behavior",
        # https://docs.python.org/3.13/whatsnew/changelog.html
        # gh-126631
    )
    def test_forkserver_perf(self):

        start_method = 'forkserver'
@ -152,15 +152,34 @@ def infer_scale_swizzle(mat, scale):
|
||||
):
|
||||
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
# MX
|
||||
# MXFP4 w/o swizzle
|
||||
if (
|
||||
scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
|
||||
scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
||||
if not torch.version.hip:
|
||||
# MXFP8 w/ swizzle
|
||||
if (
|
||||
scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
else:
|
||||
# MXFP8 w/o swizzle
|
||||
if (
|
||||
scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
||||
return None, None
|
||||
|
||||
@ -1211,7 +1230,7 @@ class TestFP8Matmul(TestCase):
|
||||
|
||||
self.assertEqual(no_carveout, no_carveout_again)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability == (10, 0):
|
||||
if capability in {(10, 0), (10, 3), (12, 0), (12, 1)}:
|
||||
# expected failure
|
||||
# CUTLASS only supports SM carveout via green contexts on SM100
|
||||
self.assertEqual(no_carveout, carveout_66)
|
||||
@ -1489,7 +1508,7 @@ class TestFP8Matmul(TestCase):
|
||||
assert sqnr.item() > approx_match_sqnr_target
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
|
||||
@parametrize("recipe", ["mxfp8", "nvfp4"])
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
|
||||
M, K, N = (1024, 512, 2048)
|
||||
BLOCK_SIZE_K = 16 if recipe == "nvfp4" else 32
|
||||
@ -1503,7 +1522,7 @@ class TestFP8Matmul(TestCase):
|
||||
if recipe == "mxfp8":
|
||||
x_lowp = x.to(e4m3_type)
|
||||
y_lowp = y.to(e4m3_type).t()
|
||||
else: # nvfp4
|
||||
else: # nvfp4 #mxfp4
|
||||
x_lowp = _bfloat16_to_float4_e2m1fn_x2(x.bfloat16())
|
||||
y_lowp = _bfloat16_to_float4_e2m1fn_x2(y.bfloat16()).t()
|
||||
|
||||
@ -1517,7 +1536,10 @@ class TestFP8Matmul(TestCase):
|
||||
if recipe == "nvfp4"
|
||||
else ScalingType.BlockWise1x32
|
||||
)
|
||||
swizzle = SwizzleType.SWIZZLE_32_4_4
|
||||
if torch.version.hip:
|
||||
swizzle = SwizzleType.NO_SWIZZLE
|
||||
else:
|
||||
swizzle = SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
# Test wrong scale tensor size for scale_a with correct dtype
|
||||
with self.assertRaisesRegex(
|
||||
|
258
tools/linter/adapters/pyrefly_linter.py
Normal file
@ -0,0 +1,258 @@
from __future__ import annotations

import argparse
import json
import logging
import os
import re
import subprocess
import sys
import time
from enum import Enum
from typing import NamedTuple


class LintSeverity(str, Enum):
    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    path: str | None
    line: int | None
    char: int | None
    code: str
    severity: LintSeverity
    name: str
    original: str | None
    replacement: str | None
    description: str | None


# Note: This regex pattern is kept for reference but not used for pyrefly JSON parsing
RESULTS_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    (?:(?P<column>-?\d+):)?
    \s(?P<severity>\S+?):?
    \s(?P<message>.*)
    \s(?P<code>\[.*\])
    $
    """
)

# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    \s(?P<severity>\S+?):?
    \s(?P<message>INTERNAL\sERROR.*)
    $
    """
)


def run_command(
    args: list[str],
    *,
    extra_env: dict[str, str] | None,
    retries: int,
) -> subprocess.CompletedProcess[bytes]:
    logging.debug("$ %s", " ".join(args))
    start_time = time.monotonic()
    try:
        return subprocess.run(
            args,
            capture_output=True,
        )
    finally:
        end_time = time.monotonic()
        logging.debug("took %dms", (end_time - start_time) * 1000)


# Severity mapping (currently only used for stderr internal errors)
# Pyrefly JSON output doesn't include severity, so all errors default to ERROR
severities = {
    "error": LintSeverity.ERROR,
    "note": LintSeverity.ADVICE,
}


def check_pyrefly_installed(code: str) -> list[LintMessage]:
    cmd = ["pyrefly", "--version"]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        return []
    except subprocess.CalledProcessError as e:
        msg = e.stderr.decode(errors="replace")
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=f"Could not run '{' '.join(cmd)}': {msg}",
            )
        ]


def in_github_actions() -> bool:
    return bool(os.getenv("GITHUB_ACTIONS"))


def check_files(
    code: str,
    config: str,
) -> list[LintMessage]:
    try:
        pyrefly_commands = [
            "pyrefly",
            "check",
            "--config",
            config,
            "--output-format=json",
        ]
        proc = run_command(
            [*pyrefly_commands],
            extra_env={},
            retries=0,
        )
    except OSError as err:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
            )
        ]
    stdout = str(proc.stdout, "utf-8").strip()
    stderr = str(proc.stderr, "utf-8").strip()
    if proc.returncode not in (0, 1):
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=stderr,
            )
        ]

    # Parse JSON output from pyrefly
    try:
        if stdout:
            result = json.loads(stdout)
            errors = result.get("errors", [])
        else:
            errors = []
        # For now filter out deprecated warnings and only report type errors as warnings
        # until we remove mypy
        errors = [error for error in errors if error["name"] != "deprecated"]
        rc = [
            LintMessage(
                path=error["path"],
                name=error["name"],
                description=error.get(
                    "description", error.get("concise_description", "")
                ),
                line=error["line"],
                char=error["column"],
                code=code,
                severity=LintSeverity.ADVICE,
                # uncomment and replace when we switch to pyrefly
                # severity=LintSeverity.ADVICE if error["name"] == "deprecated" else LintSeverity.ERROR,
                original=None,
                replacement=None,
            )
            for error in errors
        ]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="json-parse-error",
                original=None,
                replacement=None,
                description=f"Failed to parse pyrefly JSON output: {e}",
            )
        ]

    # Still check stderr for internal errors
    rc += [
        LintMessage(
            path=match["file"],
            name="INTERNAL ERROR",
            description=match["message"],
            line=int(match["line"]),
            char=None,
            code=code,
            severity=severities.get(match["severity"], LintSeverity.ERROR),
            original=None,
            replacement=None,
        )
        for match in INTERNAL_ERROR_RE.finditer(stderr)
    ]
    return rc


def main() -> None:
    parser = argparse.ArgumentParser(
        description="pyrefly wrapper linter.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--code",
        default="PYREFLY",
        help="the code this lint should report as",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "--config",
        required=True,
        help="path to an mypy .ini config file",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    lint_messages = check_pyrefly_installed(args.code) + check_files(
        args.code, args.config
    )
    for lint_message in lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)


if __name__ == "__main__":
    main()
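For orientation, here is a minimal sketch, not part of the diff, of the one-JSON-object-per-line format that main() above emits for each finding; the path and description are made-up examples:

    msg = LintMessage(
        path="torch/example.py",  # hypothetical file
        line=12,
        char=5,
        code="PYREFLY",
        severity=LintSeverity.ADVICE,
        name="bad-argument-type",
        original=None,
        replacement=None,
        description="example pyrefly finding",
    )
    # Same shape as the adapter's output loop: one JSON object per line on stdout.
    print(json.dumps(msg._asdict()))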
@ -299,7 +299,11 @@ class DeviceMeshVariable(DistributedVariable):
        if name == "get_rank":
            return ConstantVariable.create(self.value.get_rank())
        if name == "get_local_rank":
            return ConstantVariable.create(self.value.get_local_rank())
            const_args = [x.as_python_constant() for x in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            return ConstantVariable.create(
                self.value.get_local_rank(*const_args, **const_kwargs)
            )
        if name == "get_group":
            const_args = [x.as_python_constant() for x in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
@ -76,6 +76,7 @@ class StreamVariable(VariableTracker):
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        # pyrefly: ignore # read-only
        self.device = device

    def python_type(self) -> type:
@ -1492,6 +1492,7 @@ def _aot_stage2a_partition(
|
||||
|
||||
# apply joint_gm callback here
|
||||
if callable(torch._functorch.config.joint_custom_pass):
|
||||
# pyrefly: ignore # bad-assignment
|
||||
fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
|
||||
|
||||
static_lifetime_input_indices = fw_metadata.static_input_indices
|
||||
@ -1761,6 +1762,7 @@ def _aot_stage2b_bw_compile(
|
||||
# tensor which is wrong.
|
||||
|
||||
ph_size = ph_arg.size()
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
if len(ph_size) == 0 and len(real_stride) > 0:
|
||||
# Fix for 0-dimensional tensors: When a tensor becomes 0-d
|
||||
# (e.g., via squeeze), its stride should be () not (1,).
|
||||
|
@ -628,6 +628,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
|
||||
position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
|
||||
]
|
||||
# add the scale nodes to the output find the first sym_node in the output
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
idx = find_first_sym_node(output_updated_args)
|
||||
scale_nodes = tensor_scale_nodes + sym_scale_nodes
|
||||
if scale_nodes:
|
||||
|
@ -1,7 +1,8 @@
|
||||
import collections
|
||||
import logging
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -17,16 +18,24 @@ from torch.utils._ordered_set import OrderedSet
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]
|
||||
|
||||
|
||||
# Helper functions moved to top for better organization
|
||||
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
|
||||
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]: # type: ignore[name-defined]
|
||||
_, group_size, group_name = node.args
|
||||
dtype = node.meta["val"].dtype
|
||||
assert isinstance(group_name, str)
|
||||
return (group_name, dtype)
|
||||
|
||||
|
||||
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
|
||||
def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
|
||||
_, group_size, group_name = node.args
|
||||
assert isinstance(group_name, str)
|
||||
return (group_name,)
|
||||
|
||||
|
||||
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: # type: ignore[name-defined]
|
||||
_, reduce_op, group_size, group_name = node.args
|
||||
dtype = node.meta["val"].dtype
|
||||
assert isinstance(group_name, str)
|
||||
@ -53,6 +62,11 @@ def bucket_key(node: torch.fx.Node) -> object | None:
|
||||
return None
|
||||
|
||||
|
||||
def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype: # type: ignore[name-defined]
|
||||
assert len(dtypes) > 0
|
||||
return min(dtypes, key=operator.attrgetter("itemsize"))
|
||||
|
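A quick illustration, not part of the diff, of what pick_bucket_dtype above returns: the dtype with the smallest element size in the list is chosen as the shared bucket dtype for a mixed-dtype bucket.

    import torch

    # bfloat16 (2 bytes) has a smaller itemsize than float32 (4 bytes)
    assert pick_bucket_dtype([torch.float32, torch.bfloat16]) == torch.bfloat16
    assert pick_bucket_dtype([torch.float32]) == torch.float32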
||||
|
||||
def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
|
||||
"""
|
||||
Determine the size of a bucket based on its ID.
|
||||
@ -69,15 +83,15 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
|
||||
def bucket_all_gather(
|
||||
gm: torch.fx.GraphModule,
|
||||
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
|
||||
mode: str | None = None,
|
||||
mode: BucketMode = "default",
|
||||
) -> None:
|
||||
if bucket_cap_mb_by_bucket_idx is None:
|
||||
from torch._inductor.fx_passes.bucketing import (
|
||||
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
|
||||
bucket_cap_mb_by_bucket_idx_default,
|
||||
)
|
||||
|
||||
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
|
||||
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
|
||||
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx, None, mode)
|
||||
if len(ag_buckets) == 0:
|
||||
return
|
||||
merge_all_gather(gm, ag_buckets, mode)
|
||||
@ -86,15 +100,17 @@ def bucket_all_gather(
|
||||
def bucket_reduce_scatter(
|
||||
gm: torch.fx.GraphModule,
|
||||
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
|
||||
mode: str | None = None,
|
||||
mode: BucketMode = "default",
|
||||
) -> None:
|
||||
if bucket_cap_mb_by_bucket_idx is None:
|
||||
from torch._inductor.fx_passes.bucketing import (
|
||||
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
|
||||
bucket_cap_mb_by_bucket_idx_default,
|
||||
)
|
||||
|
||||
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
|
||||
rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx)
|
||||
rs_buckets = bucket_reduce_scatter_by_mb(
|
||||
gm, bucket_cap_mb_by_bucket_idx, None, mode
|
||||
)
|
||||
if len(rs_buckets) == 0:
|
||||
return
|
||||
merge_reduce_scatter(gm, rs_buckets, mode)
|
||||
@ -252,6 +268,7 @@ def bucket_all_gather_by_mb(
|
||||
gm: torch.fx.GraphModule,
|
||||
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
|
||||
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
|
||||
mode: BucketMode = "default",
|
||||
) -> list[list[torch.fx.Node]]:
|
||||
"""
|
||||
Identifies all all_gather nodes and groups them into buckets,
|
||||
@ -271,11 +288,15 @@ def bucket_all_gather_by_mb(
|
||||
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
|
||||
"""
|
||||
|
||||
group_key_fn = (
|
||||
_ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key
|
||||
)
|
||||
|
||||
return greedy_bucket_collective_by_mb(
|
||||
gm,
|
||||
bucket_cap_mb_by_bucket_idx,
|
||||
is_all_gather_into_tensor,
|
||||
_ag_group_key,
|
||||
group_key_fn,
|
||||
filter_wait_node,
|
||||
)
|
||||
|
||||
@ -284,6 +305,7 @@ def bucket_reduce_scatter_by_mb(
|
||||
gm: torch.fx.GraphModule,
|
||||
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
|
||||
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
|
||||
mode: BucketMode = "default",
|
||||
) -> list[list[torch.fx.Node]]:
|
||||
"""
|
||||
Identifies all reduce_scatter nodes and groups them into buckets,
|
||||
@ -301,6 +323,10 @@ def bucket_reduce_scatter_by_mb(
|
||||
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
|
||||
"""
|
||||
|
||||
assert "multidtype" not in mode, (
|
||||
"reduce scatter bucketing does not support multidtype"
|
||||
)
|
||||
|
||||
return greedy_bucket_collective_by_mb(
|
||||
gm,
|
||||
bucket_cap_mb_by_bucket_idx,
|
||||
@ -439,13 +465,17 @@ def _pre_bucket_all_gather(
|
||||
dtype: torch.dtype, # type: ignore[name-defined]
|
||||
rank: int,
|
||||
) -> torch.Tensor:
|
||||
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
|
||||
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
|
||||
bucket_dtype_size_bytes = dtype.itemsize
|
||||
ins_split_sizes = [
|
||||
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
|
||||
]
|
||||
ag_input_numel = sum(ins_split_sizes)
|
||||
device = ag_ins[0].device
|
||||
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
|
||||
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
|
||||
foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
|
||||
ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
|
||||
ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins]
|
||||
torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
|
||||
return new_ag_out
|
||||
|
||||
@ -457,7 +487,11 @@ def _pre_bucket_all_gather_fake(
|
||||
dtype: torch.dtype, # type: ignore[name-defined]
|
||||
rank: int,
|
||||
) -> torch.Tensor:
|
||||
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
|
||||
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
|
||||
bucket_dtype_size_bytes = dtype.itemsize
|
||||
ins_split_sizes = [
|
||||
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
|
||||
]
|
||||
ag_input_numel = sum(ins_split_sizes)
|
||||
device = ag_ins[0].device
|
||||
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
|
||||
@ -468,14 +502,28 @@ _pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)
|
||||
|
||||
|
||||
def all_gather_merge_fn_to_trace_custom_ops(
|
||||
ag_ins: list[torch.Tensor],
|
||||
_ag_ins: list[torch.Tensor],
|
||||
group_size: int,
|
||||
group_name: str,
|
||||
dtype: torch.dtype, # type: ignore[name-defined]
|
||||
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
|
||||
rank: int,
|
||||
) -> list[torch.Tensor]:
|
||||
ag_ins = [
|
||||
torch._prims.convert_element_type(_ag_in, out_dtype)
|
||||
if _ag_in.dtype != out_dtype
|
||||
else _ag_in
|
||||
for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
|
||||
]
|
||||
ins_sizes = [ag_in.shape for ag_in in ag_ins]
|
||||
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
|
||||
ins_split_sizes_bytes = [
|
||||
ag_in.numel() * out_dtype.itemsize
|
||||
for ag_in, out_dtype in zip(ag_ins, out_dtypes)
|
||||
]
|
||||
bucket_dtype_size_bytes = dtype.itemsize
|
||||
ins_split_sizes = [
|
||||
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
|
||||
]
|
||||
ag_input_numel = sum(ins_split_sizes)
|
||||
new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
|
||||
ag_ins, group_size, group_name, dtype, rank
|
||||
@ -487,14 +535,14 @@ def all_gather_merge_fn_to_trace_custom_ops(
|
||||
)
|
||||
)
|
||||
new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
|
||||
outs = torch.split_with_sizes(
|
||||
outs_bucket_dtype = torch.split_with_sizes(
|
||||
new_ag_out_reshaped,
|
||||
ins_split_sizes,
|
||||
dim=1,
|
||||
)
|
||||
outs_reshaped = [
|
||||
o.reshape((shape[0] * group_size,) + shape[1:])
|
||||
for o, shape in zip(outs, ins_sizes)
|
||||
o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
|
||||
for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
|
||||
]
|
||||
return outs_reshaped
|
||||
|
||||
@ -504,6 +552,7 @@ def all_gather_merge_fn_to_trace(
|
||||
group_size: int,
|
||||
group_name: str,
|
||||
dtype: torch.dtype, # type: ignore[name-defined]
|
||||
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
|
||||
rank: int,
|
||||
) -> list[torch.Tensor]:
|
||||
ins_sizes = [ag_in.shape for ag_in in ag_ins]
|
||||
@ -538,6 +587,7 @@ def all_gather_merge_fn_to_trace_functional(
|
||||
group_size: int,
|
||||
group_name: str,
|
||||
dtype: torch.dtype, # type: ignore[name-defined]
|
||||
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
|
||||
rank: int,
|
||||
use_fsdp_ag_copy_in: bool = False,
|
||||
) -> list[torch.Tensor]:
|
||||
@ -733,7 +783,7 @@ def process_collective_bucket(
|
||||
def merge_reduce_scatter_bucket(
|
||||
g: torch.fx.Graph,
|
||||
rs_nodes: list[torch.fx.Node],
|
||||
mode: str | None = None,
|
||||
mode: BucketMode = "default",
|
||||
insert_before: torch.fx.Node | None = None,
|
||||
wait_insertion_point: torch.fx.Node | None = None,
|
||||
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
||||
@ -826,29 +876,27 @@ def merge_all_reduce_bucket(
|
||||
def merge_all_gather_bucket(
|
||||
g: torch.fx.Graph,
|
||||
ag_nodes: list[torch.fx.Node],
|
||||
mode: str | None = None,
|
||||
mode: BucketMode = "default",
|
||||
insert_before: torch.fx.Node | None = None,
|
||||
wait_insertion_point: torch.fx.Node | None = None,
|
||||
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
||||
from torch.distributed.distributed_c10d import _resolve_process_group
|
||||
|
||||
ag0 = ag_nodes[0]
|
||||
ag0_val = ag0.meta["val"]
|
||||
_, group_size, group_name = ag0.args
|
||||
dtype = ag0_val.dtype
|
||||
assert isinstance(group_name, str)
|
||||
_ag_dtypes: list[torch.dtype] = [] # type: ignore[name-defined]
|
||||
|
||||
for n in ag_nodes:
|
||||
assert (
|
||||
n.args[1] == group_size
|
||||
and n.args[2] == group_name
|
||||
and n.meta["val"].dtype == dtype
|
||||
)
|
||||
assert n.args[1] == group_size and n.args[2] == group_name
|
||||
_ag_dtypes.append(n.meta["val"].dtype)
|
||||
|
||||
bucket_dtype = pick_bucket_dtype(_ag_dtypes)
|
||||
|
||||
# Choose merge function based on mode
|
||||
ag_merge_fn = all_gather_merge_fn_to_trace
|
||||
if mode and "custom_ops" in mode:
|
||||
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
|
||||
if mode is not None and "custom_ops" in mode:
|
||||
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops # type: ignore[assignment]
|
||||
|
||||
# Process bucket with lazy input collection
|
||||
rank: int = dist.get_rank(_resolve_process_group(group_name))
|
||||
@ -858,7 +906,8 @@ def merge_all_gather_bucket(
|
||||
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
|
||||
group_size,
|
||||
group_name,
|
||||
dtype,
|
||||
bucket_dtype,
|
||||
_ag_dtypes,
|
||||
rank,
|
||||
)
|
||||
|
||||
@ -874,7 +923,7 @@ def merge_all_gather_bucket(
|
||||
def merge_reduce_scatter(
|
||||
gm: torch.fx.GraphModule,
|
||||
rs_buckets: list[list[torch.fx.Node]],
|
||||
mode: str | None = None,
|
||||
mode: BucketMode = "default",
|
||||
) -> None:
|
||||
"""
|
||||
Merges specified buckets of reduce_scatter to joint reduce_scatter.
|
||||
@ -898,7 +947,7 @@ def merge_reduce_scatter(
|
||||
def merge_all_gather(
|
||||
gm: torch.fx.GraphModule,
|
||||
ag_buckets: list[list[torch.fx.Node]],
|
||||
mode: str | None = None,
|
||||
mode: BucketMode = "default",
|
||||
) -> None:
|
||||
"""
|
||||
Merges specified buckets of all_gather to joint all_gather.
|
||||
|
@ -209,6 +209,7 @@ def addmm_patterns_init():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
int8_woq_fusion_replacement,
|
||||
[val(), val(), val(), val(), scale(), scale(), scale()],
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fwd_only,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
pass_patterns[0],
|
||||
@ -230,6 +231,7 @@ def addmm_patterns_init():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
matmul_replacement,
|
||||
[val(), val(), val(), val()],
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fwd_only,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
pass_patterns[0],
|
||||
@ -251,6 +253,7 @@ def addmm_patterns_init():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
matmul_replacement_two,
|
||||
[val(), val(), val()],
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fwd_only,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
pass_patterns[0],
|
||||
@ -276,6 +279,7 @@ def addmm_patterns_init():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
addmm_fuse_replacement_second,
|
||||
[val() for _ in range(7)],
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fwd_only,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
pass_patterns[0],
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
from torch._inductor.fx_passes.bucketing import (
|
||||
bucket_all_gather_by_mb,
|
||||
bucket_reduce_scatter_by_mb,
|
||||
BucketMode,
|
||||
merge_all_gather,
|
||||
merge_reduce_scatter,
|
||||
)
|
||||
@ -56,7 +57,7 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
|
||||
def bucket_fsdp_all_gather(
|
||||
gm: torch.fx.GraphModule,
|
||||
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
|
||||
mode: str | None = None,
|
||||
mode: BucketMode = "default",
|
||||
) -> None:
|
||||
"""
|
||||
Bucketing pass for SimpleFSDP all_gather ops.
|
||||
@ -86,7 +87,7 @@ def bucket_fsdp_all_gather(
|
||||
def bucket_fsdp_reduce_scatter(
|
||||
gm: torch.fx.GraphModule,
|
||||
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
|
||||
mode: str | None = None,
|
||||
mode: BucketMode = "default",
|
||||
) -> None:
|
||||
"""
|
||||
Bucketing pass for SimpleFSDP reduce_scatter ops.
|
||||
|
@ -27,6 +27,10 @@ aten = torch.ops.aten
|
||||
patterns = PatternMatcherPass()
|
||||
|
||||
|
||||
def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
|
||||
return dim == t.ndim - 1 or dim == -1
|
||||
|
||||
|
||||
def _is_backward(graph: torch.fx.Graph) -> bool:
|
||||
placeholders = []
|
||||
for node in graph.nodes:
|
||||
@ -645,9 +649,17 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
|
||||
if not is_symm_mem_enabled_for_group(group_name):
|
||||
return
|
||||
|
||||
if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
|
||||
# Decomposing the matmul on the K dimension is not supported
|
||||
return
|
||||
filter_matmul = None
|
||||
if _is_last_dim(_get_tensor(shard_node), gather_dim):
|
||||
# Decomposed mms should not be too small
|
||||
if _get_tensor(shard_node).shape[-1] < 1024:
|
||||
return
|
||||
|
||||
# scaled_mm is not supported yet for last dim
|
||||
def _filter_out_scaled_matmul(matmul: _Matmul):
|
||||
return not isinstance(matmul, _ScaledMatmul)
|
||||
|
||||
filter_matmul = _filter_out_scaled_matmul
|
||||
|
||||
# Find consumer matmuls
|
||||
matmuls = _find_consumer_matmuls(ag_res_node)
|
||||
@ -663,18 +675,29 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
|
||||
if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
|
||||
return
|
||||
|
||||
if _is_last_dim(_get_tensor(shard_node), gather_dim) and len(
|
||||
all_gather.res_node.users
|
||||
) > len(matmuls):
|
||||
# The result of ag-split-cat is used not only in matmuls.
|
||||
# Then it has to be materialized, which can have overhead.
|
||||
return
|
||||
|
||||
if filter_matmul and not filter_matmul(matmuls[0]):
|
||||
return
|
||||
|
||||
# Fuse the all_gather_tensor with the eligible matmuls
|
||||
graph = ag_node.graph
|
||||
with graph.inserting_before(ag_node):
|
||||
if "val" in shard_node.meta:
|
||||
restrided = restride_A_shard_for_fused_all_gather_matmul(
|
||||
_get_tensor(shard_node),
|
||||
gather_dim,
|
||||
)
|
||||
shard_node = graph.call_function(
|
||||
inductor_prims.force_stride_order,
|
||||
args=(shard_node, restrided.stride()),
|
||||
)
|
||||
if not _is_last_dim(_get_tensor(shard_node), gather_dim):
|
||||
if "val" in shard_node.meta:
|
||||
restrided = restride_A_shard_for_fused_all_gather_matmul(
|
||||
_get_tensor(shard_node),
|
||||
gather_dim,
|
||||
)
|
||||
shard_node = graph.call_function(
|
||||
inductor_prims.force_stride_order,
|
||||
args=(shard_node, restrided.stride()),
|
||||
)
|
||||
|
||||
fused_node = _insert_fused_all_gather_matmul(
|
||||
graph, matmuls, shard_node, gather_dim, group_name
|
||||
@ -881,7 +904,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
|
||||
return
|
||||
|
||||
filter_matmul = None
|
||||
if orig_scatter_dim == _get_tensor(input_node).ndim - 1:
|
||||
if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
|
||||
# scaled_mm is not supported yet for last dim mm+rs
|
||||
def _filter_out_scaled_matmul(matmul: _Matmul):
|
||||
return not isinstance(matmul, _ScaledMatmul)
|
||||
|
@ -49,6 +49,7 @@ def _misc_patterns_init():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
randperm_index_add_replacement,
|
||||
[torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fwd_only,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
[post_grad_patterns, joint_graph_patterns],
|
||||
@ -68,6 +69,7 @@ def _misc_patterns_init():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
randperm_index_replacement,
|
||||
[torch.empty(4, 8, device=device)],
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fwd_only,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
[post_grad_patterns, joint_graph_patterns],
|
||||
|
@ -919,6 +919,7 @@ def _pad_mm_init() -> None:
|
||||
pattern,
|
||||
replacement,
|
||||
args,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
joint_fwd_bwd,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
patterns,
|
||||
@ -931,6 +932,7 @@ def _pad_mm_init() -> None:
|
||||
pattern,
|
||||
replacement,
|
||||
args,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fwd_only,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
patterns,
|
||||
|
@ -216,7 +216,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||
lambda graph: p(
|
||||
graph.owning_module,
|
||||
config.bucket_reduce_scatters_fx_bucket_size_determinator,
|
||||
config.bucket_reduce_scatters_fx,
|
||||
config.bucket_reduce_scatters_fx, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
collectives_bucketing = True
|
||||
@ -236,7 +236,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||
lambda graph: p(
|
||||
graph.owning_module,
|
||||
config.bucket_all_gathers_fx_bucket_size_determinator,
|
||||
config.bucket_all_gathers_fx,
|
||||
config.bucket_all_gathers_fx, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
collectives_bucketing = True
|
||||
@ -666,6 +666,7 @@ def lazy_init():
|
||||
prepare_softmax_replacement,
|
||||
[torch.empty(4, 8)],
|
||||
scalar_workaround=dict(dim=-1),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
trace_fn=fwd_only,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
pass_dicts=pass_patterns[1],
|
||||
@ -730,6 +731,7 @@ def register_lowering_pattern(
|
||||
return pattern_matcher.register_lowering_pattern(
|
||||
pattern,
|
||||
extra_check,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
pass_dict=pass_patterns[pass_number],
|
||||
)
|
||||
|
||||
@ -1573,6 +1575,7 @@ def register_partial_reduction_pattern():
|
||||
|
||||
@register_graph_pattern(
|
||||
MultiOutputPattern([partial_reduc, full_reduc]),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
pass_dict=pass_patterns[2],
|
||||
)
|
||||
def reuse_partial(match, input, reduced_dims, keepdim):
|
||||
|
@ -27,7 +27,7 @@ from torch._dynamo.utils import counters
|
||||
from torch._higher_order_ops.associative_scan import associative_scan_op
|
||||
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
|
||||
from torch._library.utils import get_layout_constraint_tag
|
||||
from torch._prims_common import (
|
||||
from torch._prims_common import ( # pyrefly: ignore # deprecated
|
||||
canonicalize_dim,
|
||||
canonicalize_dims,
|
||||
check,
|
||||
|
@ -3994,6 +3994,12 @@ class Scheduler:
        ):
            return -1

        # in some rare case, a template can be passed in.
        # Check test_interaction_with_multi_template in test_loop_ordering.py
        # and https://github.com/pytorch/pytorch/issues/165579
        if node1.is_template() or node2.is_template():
            return -1

        node1_buffer_names = node1.read_writes.buffer_names()
        node2_buffer_names = node2.read_writes.buffer_names()
        # Fast path: no common buffers.
@ -173,6 +173,7 @@ def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
            f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
        )
    _OPAQUE_TYPES[cls] = name
    # pyrefly: ignore # missing-attribute
    torch._C._register_opaque_type(name)


@ -182,4 +183,5 @@ def is_opaque_type(cls: Any) -> bool:
    """
    if cls not in _OPAQUE_TYPES:
        return False
    # pyrefly: ignore # missing-attribute
    return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])
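A short usage sketch, not from the diff, assuming the two helpers above are imported from the same module; MyPayload is a hypothetical user-defined class.

    class MyPayload:
        pass

    register_opaque_type(MyPayload, "MyPayload")  # name must not contain '.'
    assert is_opaque_type(MyPayload)
    assert not is_opaque_type(int)  # never registered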
@ -135,7 +135,7 @@ if is_available():
|
||||
# this.
|
||||
# pyrefly: ignore # deprecated
|
||||
from .distributed_c10d import * # noqa: F403
|
||||
from .distributed_c10d import (
|
||||
from .distributed_c10d import ( # pyrefly: ignore # deprecated
|
||||
_all_gather_base,
|
||||
_coalescing_manager,
|
||||
_CoalescingManager,
|
||||
|
@ -1009,8 +1009,8 @@ lib_impl.impl("broadcast", _broadcast_meta, "Meta")
|
||||
lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
|
||||
|
||||
# mark these ops has side effect so that they won't be removed by DCE
|
||||
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
|
||||
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
|
||||
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) # type: ignore[has-type]
|
||||
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) # type: ignore[has-type]
|
||||
|
||||
# Register legacy ops for backward compatibility
|
||||
# TODO(yifu): remove these in functional collective beta release
|
||||
@ -1176,7 +1176,7 @@ def all_gather_inplace(
|
||||
return tensor_list
|
||||
|
||||
|
||||
from torch.distributed.distributed_c10d import (
|
||||
from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated
|
||||
_all_gather_base as legacy_all_gather_base,
|
||||
_reduce_scatter_base as legacy_reduce_scatter_base,
|
||||
all_gather as legacy_all_gather,
|
||||
@ -1190,11 +1190,11 @@ from torch.distributed.distributed_c10d import (
|
||||
# This dict should contain sets of functions that dynamo is allowed to remap.
|
||||
# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
|
||||
traceable_collective_remaps = {
|
||||
legacy_allgather: all_gather_tensor_inplace,
|
||||
legacy_reducescatter: reduce_scatter_tensor_inplace,
|
||||
legacy_allreduce: all_reduce_inplace,
|
||||
legacy_all_to_all_single: all_to_all_inplace,
|
||||
legacy_all_gather: all_gather_inplace,
|
||||
legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
|
||||
legacy_all_gather_base: all_gather_tensor_inplace,
|
||||
legacy_allgather: all_gather_tensor_inplace, # type: ignore[has-type]
|
||||
legacy_reducescatter: reduce_scatter_tensor_inplace, # type: ignore[has-type]
|
||||
legacy_allreduce: all_reduce_inplace, # type: ignore[has-type]
|
||||
legacy_all_to_all_single: all_to_all_inplace, # type: ignore[has-type]
|
||||
legacy_all_gather: all_gather_inplace, # type: ignore[has-type]
|
||||
legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, # type: ignore[has-type]
|
||||
legacy_all_gather_base: all_gather_tensor_inplace, # type: ignore[has-type]
|
||||
}
|
||||
|
@ -393,6 +393,7 @@ class LocalTensor(torch.Tensor):
|
||||
def __repr__(self) -> str: # type: ignore[override]
|
||||
parts = []
|
||||
for k, v in self._local_tensors.items():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
parts.append(f" {k}: {v}")
|
||||
tensors_str = ",\n".join(parts)
|
||||
return f"LocalTensor(\n{tensors_str}\n)"
|
||||
@ -680,6 +681,7 @@ class LocalTensorMode(TorchDispatchMode):
|
||||
def _unpatch_device_mesh(self) -> None:
|
||||
assert self._old_get_coordinate is not None
|
||||
DeviceMesh.get_coordinate = self._old_get_coordinate
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._old_get_coordinate = None
|
||||
|
||||
|
||||
|
@ -316,6 +316,7 @@ def _local_all_gather_(
|
||||
assert len(input_tensors) == 1
|
||||
|
||||
input_tensor = input_tensors[0]
|
||||
# pyrefly: ignore # bad-assignment
|
||||
output_tensors = output_tensors[0]
|
||||
|
||||
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
|
||||
@ -336,10 +337,12 @@ def _local_all_gather_(
|
||||
source_tensor = input_tensor
|
||||
if isinstance(input_tensor, LocalTensor):
|
||||
source_tensor = input_tensor._local_tensors[rank_i]
|
||||
# pyrefly: ignore # missing-attribute
|
||||
output_tensors[i].copy_(source_tensor)
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
# pyrefly: ignore # bad-return
|
||||
return ([output_tensors], work_so)
|
||||
|
||||
|
||||
@ -426,6 +429,7 @@ def _local_scatter_(
|
||||
assert len(output_tensors) == 1
|
||||
assert len(input_tensors) == 1
|
||||
output_tensor = output_tensors[0]
|
||||
# pyrefly: ignore # bad-assignment
|
||||
input_tensors = input_tensors[0]
|
||||
|
||||
ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)
|
||||
|
@ -9,6 +9,7 @@ from itertools import product
|
||||
|
||||
import torch
|
||||
from torch.distributed._pycute import (
|
||||
as_tuple,
|
||||
coalesce,
|
||||
complement,
|
||||
composition,
|
||||
@ -17,7 +18,6 @@ from torch.distributed._pycute import (
|
||||
is_int,
|
||||
is_tuple,
|
||||
Layout,
|
||||
suffix_product,
|
||||
)
|
||||
|
||||
|
||||
@ -79,6 +79,11 @@ class _MeshLayout(Layout):
|
||||
|
||||
# # operator [] (get-i like tuples)
|
||||
def __getitem__(self, i: int) -> "_MeshLayout":
|
||||
if i < -len(self) or i >= len(self):
|
||||
raise IndexError(
|
||||
f"Dim {i} is out of range for layout with {len(self)} dimensions. "
|
||||
f"Expected dim to be in range [{-len(self)}, {len(self) - 1}]."
|
||||
)
|
||||
layout = super().__getitem__(i)
|
||||
return _MeshLayout(layout.shape, layout.stride)
|
||||
|
||||
@ -156,50 +161,11 @@ class _MeshLayout(Layout):
|
||||
layout = complement(self, world_size)
|
||||
return _MeshLayout(layout.shape, layout.stride)
|
||||
|
||||
def unflatten(self, dim: int, unflatten_sizes: tuple[int, ...]) -> "_MeshLayout":
|
||||
"""
|
||||
Unflatten a single dimension in the layout by splitting it into multiple dimensions.
|
||||
It takes a dimension at position `dim` and splits it into multiple new dimensions
|
||||
with the specified sizes.
|
||||
|
||||
Args:
|
||||
dim (int): The index of the dimension to unflatten. Must be a valid dimension index.
|
||||
unflatten_sizes (tuple[int, ...]): The new sizes for the dimensions that will replace
|
||||
the original dimension at `dim`. The product of these sizes must equal the size
|
||||
of the original dimension at `dim`.
|
||||
|
||||
Returns:
|
||||
_MeshLayout: A new layout with the specified dimension unflattened.
|
||||
|
||||
Example:
|
||||
Original: sizes=(8,), strides=(1,) # 8 ranks in 1D
|
||||
Call: unflatten(0, (2, 2, 2)) # Create 3D topology
|
||||
Result: sizes=(2, 2, 2), strides=(4, 2, 1) # 2*2*2 unflattened topology
|
||||
"""
|
||||
# Check that dim is within valid range
|
||||
if dim < 0 or dim >= len(self):
|
||||
raise ValueError(
|
||||
f"dim {dim} is out of range for layout with {len(self)} dimensions. "
|
||||
f"Expected dim to be in range [0, {len(self) - 1}]."
|
||||
)
|
||||
|
||||
# Check that the product of unflatten_sizes equals the original dimension size
|
||||
original_size = self[dim].numel()
|
||||
unflatten_product = math.prod(unflatten_sizes)
|
||||
if unflatten_product != original_size:
|
||||
raise ValueError(
|
||||
f"The product of unflatten_sizes {unflatten_sizes} is {unflatten_product}, "
|
||||
f"but the original dimension at dim={dim} has size {original_size}. "
|
||||
f"These must be equal for unflatten to work correctly."
|
||||
)
|
||||
|
||||
sizes = list(self.sizes) # type: ignore[arg-type]
|
||||
strides = list(self.strides) # type: ignore[arg-type]
|
||||
unflatten_layout = self[dim].composition(
|
||||
_MeshLayout(tuple(unflatten_sizes), suffix_product(unflatten_sizes))
|
||||
)
|
||||
sizes[dim : dim + 1] = list(unflatten_layout.sizes) # type: ignore[arg-type]
|
||||
strides[dim : dim + 1] = list(unflatten_layout.strides) # type: ignore[arg-type]
|
||||
def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout":
|
||||
sizes = list(as_tuple(self.sizes))
|
||||
strides = list(as_tuple(self.strides))
|
||||
sizes[start:end] = list(as_tuple(layout.sizes))
|
||||
strides[start:end] = list(as_tuple(layout.strides))
|
||||
return _MeshLayout(tuple(sizes), tuple(strides))
|
||||
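A brief sketch, not part of the diff, of how splice can reproduce the removed unflatten(0, (2, 2, 2)) example above, assuming _MeshLayout is constructed from (sizes, strides) as in the surrounding code:

    base = _MeshLayout((8,), (1,))
    sub = _MeshLayout((2, 2, 2), suffix_product((2, 2, 2)))  # strides (4, 2, 1)
    three_d = base.splice(0, 1, sub)
    # three_d.sizes == (2, 2, 2), three_d.strides == (4, 2, 1)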
|
||||
def all_ranks_from_zero(self) -> list[int]:
|
||||
@ -301,10 +267,7 @@ class _MeshLayout(Layout):
|
||||
ranks = self.all_ranks_from_zero()
|
||||
return len(ranks) == len(set(ranks))
|
||||
|
||||
def remap_to_tensor(
|
||||
self,
|
||||
mesh_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Leverage layout as an index for mesh tensor that re-maps the indexes after layout
|
||||
transformation to actual device ranks.
|
||||
@ -316,10 +279,7 @@ class _MeshLayout(Layout):
|
||||
can be treated as a view or subset of mesh tensor, we do need to use the actual view or
|
||||
sub-tensor for DeviceMesh and its backend creation.
|
||||
|
||||
The shape of the `mesh_tensor` can be any size because users can define a device mesh with any
|
||||
shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
|
||||
and reconstruct the mesh tensor with the shape of the layout when accessed by users.
|
||||
#TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.
|
||||
The shape of the `rank_map` must be 1D and contiguous.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -336,18 +296,18 @@ class _MeshLayout(Layout):
|
||||
Return: [[[10,30],[20,40]]]
|
||||
|
||||
Args:
|
||||
mesh_tensor: The concrete mesh tensor with actual device ranks
|
||||
rank_map: The concrete mesh tensor with actual device ranks
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
|
||||
torch.Tensor: A tensor representing the actual device allocation from rank_map
|
||||
"""
|
||||
complement_layout = self.complement(mesh_tensor.numel())
|
||||
assert rank_map.ndim == 1
|
||||
assert rank_map.is_contiguous()
|
||||
assert rank_map.numel() >= self.cosize()
|
||||
|
||||
return (
|
||||
mesh_tensor.flatten()
|
||||
.as_strided(
|
||||
flatten(complement_layout.sizes) + flatten(self.sizes),
|
||||
flatten(complement_layout.strides) + flatten(self.strides),
|
||||
)
|
||||
.reshape(-1, *(self[i].numel() for i in range(len(self))))
|
||||
)
|
||||
complement_layout = self.complement(rank_map.numel())
|
||||
|
||||
return rank_map.as_strided(
|
||||
flatten(complement_layout.sizes) + flatten(self.sizes),
|
||||
flatten(complement_layout.strides) + flatten(self.strides),
|
||||
).reshape(-1, *self.top_level_sizes)
|
||||
|
@ -31,6 +31,7 @@
#################################################################################################

from .int_tuple import (
    as_tuple,
    crd2crd,
    crd2idx,
    elem_scale,
@ -54,6 +54,12 @@ def is_tuple(x: object) -> TypeIs[tuple]:
    return isinstance(x, tuple)


def as_tuple(x: IntTuple) -> tuple[IntTuple, ...]:
    if is_int(x):
        return (x,)
    return x


def flatten(t: IntTuple) -> tuple[int, ...]:
    if is_tuple(t):
        if len(t) == 0:
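A tiny illustration, not part of the diff, of the new as_tuple helper above: ints are wrapped into a one-element tuple, tuples pass through unchanged.

    assert as_tuple(3) == (3,)
    assert as_tuple((2, 4)) == (2, 4)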
@@ -524,6 +524,19 @@ def _fused_all_gather_matmul_impl(

    group = c10d._resolve_process_group(group_name)

    if gather_dim == A_shard.ndim - 1 or gather_dim == -1:
        return _fused_all_gather_matmul_last_gather_dim_impl(
            mm_out_op,
            A_shard,
            Bs,
            A_scale,
            kwargs_list,
            out_dtypes,
            gather_dim,
            group_name,
            return_A,
        )

    # Move the gather_dim to the front and flatten the tensor into a 2D matrix.
    # The flattened tensor doesn't need to be contiguous (for computation
    # efficiency), as _pipelined_all_gather_and_consume guarantees that shards
@@ -624,6 +637,140 @@ def _fused_all_gather_matmul_impl(
    return A, [unflatten(output) for output in outputs]


def _pipelined_all_gather_and_consume_last_dim(
    shard: torch.Tensor,
    shard_consumer: Callable[[torch.Tensor, int], None],
    ag_out: torch.Tensor,
    group_name: str,
    ag_out_needed: bool = True,
) -> None:
    p2p_workspace_size_req = 0
    p2p_workspace_size_req = shard.numel() * shard.element_size()
    symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req)
    group_size = symm_mem.world_size
    rank = symm_mem.rank

    symm_mem.barrier(channel=0)
    backend_stream = _get_backend_stream()
    backend_stream.wait_stream(torch.cuda.current_stream())

    def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None:
        dst.copy_(src)

    def get_p2p_buf(remote_rank: int) -> torch.Tensor:
        buf = symm_mem.get_buffer(
            remote_rank,
            shard.shape,
            shard.dtype,
        )
        return buf

    local_p2p_buf = get_p2p_buf(rank)

    shards = ag_out.chunk(group_size)

    copy_shard(dst=local_p2p_buf, src=shard)
    symm_mem.barrier(channel=1)
    backend_stream.wait_stream(torch.cuda.current_stream())

    # At this point, all ranks have copied their local shard to
    # their local p2p buffer. Each rank can now copy and consume
    # remote shards.
    shard_consumer(shard, rank)

    for step in range(1, group_size):
        if step % 2 == 0:
            stream = torch.cuda.current_stream()
        else:
            stream = backend_stream
        remote_rank = (step + rank) % group_size
        remote_p2p_buf = get_p2p_buf(remote_rank)
        with stream:
            copy_shard(dst=shards[remote_rank], src=remote_p2p_buf)
            shard_consumer(shards[remote_rank], remote_rank)

    if ag_out_needed:
        # Copy from input to the all-gather output. Opportunistically overlap
        # it with the last shard_consumer.
        if group_size % 2 == 0:
            stream = torch.cuda.current_stream()
        else:
            stream = backend_stream
        with stream:
            copy_shard(dst=shards[rank], src=shard)

    torch.cuda.current_stream().wait_stream(backend_stream)
    symm_mem.barrier(channel=0)


def _fused_all_gather_matmul_last_gather_dim_impl(
    mm_out_op: torch._ops.OpOverload,
    A_shard: torch.Tensor,
    Bs: list[torch.Tensor],
    A_scale: torch.Tensor | None,
    kwargs_list: list[dict[str, Any]],
    out_dtypes: list[torch.dtype | None],
    gather_dim: int,
    group_name: str,
    return_A: bool,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
    group = c10d._resolve_process_group(group_name)
    group_size = group.size()

    B_shards = [B.chunk(group.size()) for B in Bs]

    leading_dims = list(A_shard.shape[:-1])
    A_shard_flat = A_shard.flatten(0, -2)

    def unflatten(t: torch.Tensor) -> torch.Tensor:
        return t.view(*leading_dims, -1)

    A_flat_out = A_shard_flat.new_empty(
        A_shard_flat.shape[0] * group.size(),
        A_shard_flat.shape[1],
    )

    outputs = [
        torch.empty(
            (A_shard_flat.shape[0], B.shape[1]),
            dtype=out_dtype or B.dtype,
            device=A_shard.device,
        )
        for B, out_dtype in zip(Bs, out_dtypes)
    ]

    first = True
    events = [torch.cuda.Event() for _ in outputs]

    def default_consumer(shard: torch.Tensor, rank: int) -> None:
        nonlocal first
        for out, event, B_shard, kwargs in zip(outputs, events, B_shards, kwargs_list):
            event.wait()
            if first:
                torch.ops.aten.mm.out(shard, B_shard[rank], **kwargs, out=out)
            else:
                out.addmm_(shard, B_shard[rank])
            event.record()

        first = False

    _pipelined_all_gather_and_consume_last_dim(
        A_shard_flat,
        default_consumer,
        A_flat_out,
        group_name,
        return_A,
    )
    ret_A = None
    if return_A:
        # This path is inefficient and will be filtered out at passes stage
        # Added only for completeness.
        A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
        ret_A = unflatten(A_split_cat_out_flat)

    return ret_A, [unflatten(output) for output in outputs]


@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
def _fused_all_gather_matmul_fallback(
    A_shard: torch.Tensor,
@@ -638,6 +785,15 @@ def _fused_all_gather_matmul_fallback(
        A_shard.contiguous(), group_size, group_name
    )
    A = torch.ops._c10d_functional.wait_tensor(A)
    if gather_dim == A.ndim - 1 or gather_dim == -1:
        A_splits = A.chunk(group_size)
        A_mm = torch.cat(A_splits, dim=-1)
        res = [torch.matmul(A_mm, B) for B in Bs]
        if return_A:
            return A_mm, res
        else:
            return None, res

    A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
    res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs]
    if return_A:
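For reference, a single-process sketch (not part of the diff; shapes and names are made up) of the block-matmul identity the last-gather-dim path relies on: concatenating the gathered A shards along the last dim and doing one matmul equals accumulating per-shard products, which is what default_consumer does with mm.out followed by addmm_.

    import torch

    group_size, m, k_shard, n = 4, 8, 16, 32
    shards = [torch.randn(m, k_shard) for _ in range(group_size)]   # one A shard per rank
    B = torch.randn(group_size * k_shard, n)
    B_shards = B.chunk(group_size)                                  # mirrors B.chunk(group.size())

    # Reference: concatenate the gathered shards along the last dim, then a single matmul.
    ref = torch.cat(shards, dim=-1) @ B

    # Pipelined form: consume one shard at a time, accumulating into the same output.
    out = torch.zeros(m, n)
    for rank, shard in enumerate(shards):
        out.addmm_(shard, B_shards[rank])

    assert torch.allclose(ref, out, atol=1e-4, rtol=1e-4)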
@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import math
import os
import threading
import warnings
@@ -12,7 +11,7 @@ from typing import Optional, TYPE_CHECKING, Union

import torch
from torch.distributed import is_available
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed._pycute import is_int
from torch.distributed._pycute import is_int, suffix_product
from torch.utils._typing_utils import not_none


@@ -173,7 +172,7 @@ else:
        """

        _device_type: str
        _mesh: torch.Tensor
        _rank_map: torch.Tensor
        _mesh_dim_names: Optional[tuple[str, ...]]
        _layout: _MeshLayout
        _root_mesh: Optional["DeviceMesh"] = None
@@ -183,60 +182,75 @@ else:
        def __init__(
            self,
            device_type: str,
            mesh: Union[torch.Tensor, "ArrayLike"],
            mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
            *,
            mesh_dim_names: Optional[tuple[str, ...]] = None,
            backend_override: Optional[tuple[BackendConfig, ...]] = None,
            _init_backend: bool = True,
            _rank: Optional[int] = None,
            _layout: Optional[_MeshLayout] = None,
            _rank_map: Optional[torch.Tensor] = None,
            _root_mesh: Optional["DeviceMesh"] = None,
        ) -> None:
            self._device_type = device_type
            if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
                raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
            self._mesh = (
                mesh.detach().to(dtype=torch.int).contiguous()
                if isinstance(mesh, torch.Tensor)
                else torch.tensor(mesh, device="cpu", dtype=torch.int)
            )
            self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
            if backend_override is None:
                backend_override = ((None, None),) * self.mesh.ndim
            elif len(backend_override) != self.mesh.ndim:
                raise ValueError(
                    f"backend_override should have the same length as the number of mesh dimensions, "
                    f"but got {len(backend_override)} and {self.mesh.ndim}."
            if mesh is not None:
                if _layout is not None or _rank_map is not None:
                    raise TypeError(
                        "Cannot provide _layout and/or _rank_map if passing explicit mesh"
                    )
                if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
                    raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
                mesh_tensor = (
                    mesh.detach().to(dtype=torch.int).contiguous()
                    if isinstance(mesh, torch.Tensor)
                    else torch.tensor(mesh, device="cpu", dtype=torch.int)
                )
                # Internal bookkeeping for the device mesh.
                self._layout = (
                    _layout
                    if _layout
                    else _MeshLayout(self.mesh.size(), self.mesh.stride())
                )
                assert self._layout.check_non_overlap(), (
                _layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
                _rank_map = mesh_tensor.flatten()
            else:
                if _layout is None or _rank_map is None:
                    raise TypeError(
                        "The mesh argument is required except for PRIVATE USAGE ONLY!"
                    )

            assert _layout.check_non_overlap(), (
                "Please use a non-overlapping layout when creating a DeviceMesh."
            )
            # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
            assert self._layout.top_level_sizes == self.mesh.size(), (
                "Please use a valid layout when creating a DeviceMesh."
                f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
            assert _rank_map.ndim == 1, "The rank map must be 1-dimensional"
            assert _rank_map.is_contiguous(), "The rank map must be contiguous"
            assert _rank_map.numel() >= _layout.cosize(), (
                f"The rank map contains {_rank_map.numel()} element, "
                f"which isn't large enough for layout {_layout}"
            )

            # private field to pre-generate DeviceMesh's hash
            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
            self._thread_id = None
            # Initialize instance-specific flatten mapping
            self._flatten_mapping = {}
            self._device_type = device_type
            self._layout = _layout
            self._rank_map = _rank_map
            self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
            self._root_mesh = _root_mesh

            if backend_override is None:
                backend_override = ((None, None),) * len(self._layout)
            elif len(backend_override) != len(self._layout):
                raise ValueError(
                    f"backend_override should have the same length as the number of mesh dimensions, "
                    f"but got {len(backend_override)} and {len(self._layout)}."
                )

            # Skip process group initialization if xla device or init backend is False
            # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
            self._thread_id = None
            if device_type != "xla":
                # always try to create default (world) pg, even if it is not initialized
                # already. The world pg is used for device mesh identity (rank) on each
                # process (we need to know if the current global rank is in the mesh or not).
                if _init_backend:
                    self._setup_world_group_and_device()
                    self._init_process_groups(backend_override)
                    self._dim_group_names = self._init_process_groups(
                        self._layout,
                        self._rank_map,
                        self._mesh_dim_names,
                        backend_override,
                    )

                if is_initialized() and get_backend() == "threaded":
                    # pyrefly: ignore # bad-assignment
@@ -252,6 +266,11 @@ else:
                rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
            )

            # private field to pre-generate DeviceMesh's hash
            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
            # Initialize instance-specific flatten mapping
            self._flatten_mapping = {}

        @property
        def device_type(self) -> str:
            """Returns the device type of the mesh."""
@@ -260,7 +279,17 @@ else:
        @property
        def mesh(self) -> torch.Tensor:
            """Returns the tensor representing the layout of devices."""
            return self._mesh
            full_mesh = self._layout.remap_to_tensor(self._rank_map)
            if full_mesh.size(0) == 1:
                return full_mesh[0]
            my_coords = (full_mesh == get_rank()).nonzero()
            if my_coords.size(0) > 0:
                return full_mesh[my_coords[0, 0]]
            raise RuntimeError(
                "In order to get the mesh Tensor of a DeviceMesh it needs to "
                "either have all its original dimensions (e.g., no slicing) "
                "or it needs to contain the local rank"
            )

        @property
        def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
@@ -275,9 +304,9 @@ else:
                init_process_group()

            world_size = get_world_size()
            if self.mesh.numel() > world_size:
            if self._layout.numel() > world_size:
                raise RuntimeError(
                    f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
                    f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
                )

            # ONLY set the device if the current device is not initialized, if user already
@@ -317,10 +346,13 @@ else:

            return _get_default_group()

        @staticmethod
        def _init_process_groups(
            self,
            layout: _MeshLayout,
            rank_map: torch.Tensor,
            mesh_dim_names: Optional[tuple[str, ...]],
            backend_override: tuple[BackendConfig, ...],
        ):
        ) -> list[str]:
            # group_name associated with each mesh dimension, each
            # mesh dimension should have one sub-group per rank
            #
@@ -328,8 +360,8 @@ else:
            default_group = _get_default_group()

            if (
                self.mesh.ndim == 1
                and self.mesh.numel() == get_world_size()
                len(layout) == 1
                and layout.numel() == get_world_size()
                and backend_override[0] == (None, None)
            ):
                # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@@ -348,12 +380,10 @@ else:
                dim_group_names.append(dim_group.group_name)
            else:
                # create sub pgs base on the mesh argument specified
                for dim in range(self.mesh.ndim):
                for dim in range(len(layout)):
                    # swap the current dim to the last dim
                    # then reshape to flatten out other dims
                    pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
                        -1, self.mesh.size(dim)
                    )
                    pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
                    backend, pg_options = backend_override[dim]
                    # We need to explicitly pass in timeout when specified in option, otherwise
                    # the default timeout will be used to override the timeout set in option.
@@ -365,8 +395,8 @@ else:
                    # If the mesh doesn't not have a mesh_dim_names, then the group description of the
                    # subgroup would be `mesh_dim_0` and `mesh_dim_1`.
                    group_desc = (
                        f"mesh_{self._mesh_dim_names[dim]}"
                        if self._mesh_dim_names
                        f"mesh_{mesh_dim_names[dim]}"
                        if mesh_dim_names
                        else f"mesh_dim_{dim}"
                    )

@@ -424,14 +454,14 @@ else:
                    )

                    # only add to dim_groups if the current rank in the subgroup
                    if self.get_rank() in subgroup_ranks:
                    if get_rank() in subgroup_ranks:
                        if len(dim_group_names) > dim:
                            raise RuntimeError(
                                f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
                                f"Each device mesh dimension should get only one process group, but got {get_rank()} "
                                f"in {subgroup_ranks}!"
                            )
                        dim_group_names.append(dim_group.group_name)  # type: ignore[union-attr]
            self._dim_group_names = dim_group_names
            return dim_group_names

        def _get_root_mesh(self) -> "DeviceMesh":
            return self._root_mesh if self._root_mesh else self
@@ -448,14 +478,14 @@ else:

        def __repr__(self) -> str:
            device_mesh_repr = (
                f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._mesh.shape))})"
                f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
                if self._mesh_dim_names
                else f"{tuple(self._mesh.shape)}"
                else f"{self._layout.top_level_sizes}"
            )
            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}"
            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
            # We only print the mesh tensor if the debug mode is turned on.
            if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
                device_mesh_repr += f", Mesh: {self._mesh.tolist()}"
                device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
            return f"{device_mesh_repr})"

        def __hash__(self):
@@ -465,7 +495,7 @@ else:
                self._hash = hash(
                    (
                        self._flatten_mesh_list,
                        self._mesh.shape,
                        self._layout,
                        self._device_type,
                        self._mesh_dim_names,
                        self._thread_id,
@@ -481,7 +511,7 @@ else:
                return False
            return (
                self._flatten_mesh_list == other._flatten_mesh_list
                and self._mesh.shape == other._mesh.shape
                and self._layout == other._layout
                and self._device_type == other._device_type
                and self._mesh_dim_names == other._mesh_dim_names
                and self._thread_id == other._thread_id
@@ -573,16 +603,16 @@ else:
            if not hasattr(self, "_dim_group_names"):
                raise RuntimeError("DeviceMesh process groups not initialized!")

            if self.mesh.ndim > 1 and mesh_dim is None:
            if len(self._layout) > 1 and mesh_dim is None:
                raise RuntimeError(
                    f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
                    f"Found the DeviceMesh have {len(self._layout)} dimensions",
                    "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
                    "If you want to get the list of all the ProcessGroups in the DeviceMesh,"
                    "please use `get_all_groups()` instead.",
                )

            # Quick return if the current device_mesh is a 1D mesh.
            if self.mesh.ndim == 1 and mesh_dim is None:
            if len(self._layout) == 1 and mesh_dim is None:
                return not_none(_resolve_process_group(self._dim_group_names[0]))

            root_mesh = self._get_root_mesh()
@@ -608,7 +638,7 @@ else:
            Returns:
                A list of :class:`ProcessGroup` object.
            """
            return [self.get_group(i) for i in range(self.mesh.ndim)]
            return [self.get_group(i) for i in range(len(self._layout))]

        def _create_sub_mesh(
            self,
@@ -634,18 +664,13 @@ else:
                    not_none(flatten_mesh._mesh_dim_names).index(name)
                ]
            )
            cur_rank = self.get_rank()
            pg_ranks_by_dim = layout.remap_to_tensor(
                root_mesh.mesh,
            )
            res_submesh = DeviceMesh._create_mesh_from_ranks(
            res_submesh = DeviceMesh(
                self._device_type,
                pg_ranks_by_dim,
                cur_rank,
                submesh_dim_names,
                _init_backend=False,
                _layout=layout,
                _rank_map=root_mesh._rank_map,
                mesh_dim_names=submesh_dim_names,
                _root_mesh=root_mesh,
                _init_backend=False,
            )
            res_submesh._dim_group_names = slice_dim_group_name
            return res_submesh
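As an illustration (not part of the diff), the per-dimension subgroup enumeration that `_init_process_groups` performs: for every mesh dimension it needs the rows of ranks that differ only along that dimension. The old `swapdims`/`reshape` form shown above computes exactly that, and the new `layout[dim].nest().remap_to_tensor(rank_map)` call is its layout-based counterpart. A plain-tensor sketch for a 2x4 mesh:

    import torch

    mesh = torch.arange(8).view(2, 4)      # ranks of a hypothetical 2x4 DeviceMesh
    for dim in range(mesh.ndim):
        # move `dim` to the end, then flatten the remaining dims into rows
        pg_ranks_by_dim = mesh.swapdims(-1, dim).reshape(-1, mesh.size(dim))
        print(dim, pg_ranks_by_dim.tolist())
    # dim 0 -> [[0, 4], [1, 5], [2, 6], [3, 7]]
    # dim 1 -> [[0, 1, 2, 3], [4, 5, 6, 7]]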
@@ -689,22 +714,13 @@ else:
                    f"Please specify another valid mesh_dim_name."
                )

            cur_rank = root_mesh.get_rank()
            # Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
            # new_group api to avoid potential hang.
            pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
                root_mesh.mesh,
            )
            res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
            res_flattened_mesh = DeviceMesh(
                root_mesh._device_type,
                pg_ranks_by_dim.flatten(
                    start_dim=1
                ),  # this is needed for flatten non-contiguous mesh dims.
                cur_rank,
                (mesh_dim_name,),
                (backend_override,),
                _layout=flattened_mesh_layout,
                _rank_map=root_mesh._rank_map,
                mesh_dim_names=(mesh_dim_name,),
                _root_mesh=root_mesh,
                backend_override=(backend_override,),
            )
            root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh

@@ -833,9 +849,7 @@ else:
            """
            mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
            layout = self._layout[mesh_dim]
            pg_ranks_by_dim = layout.remap_to_tensor(
                self.mesh,
            )
            pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
            cur_rank = self.get_rank()
            res_submeshes = []
            for mesh_1d in pg_ranks_by_dim:
@@ -854,60 +868,6 @@ else:

            return res_submeshes

        @staticmethod
        def _create_mesh_from_ranks(
            device_type: str,
            pg_ranks_by_dim: torch.Tensor,
            cur_rank: int,
            mesh_dim_names: tuple[str, ...],
            backend_override: Optional[tuple[BackendConfig, ...]] = None,
            _init_backend: bool = True,
            _layout: Optional[_MeshLayout] = None,
            _root_mesh: Optional["DeviceMesh"] = None,
        ) -> "DeviceMesh":
            """
            Helper method to create a DeviceMesh from tensor `pg_ranks_by_dim`. This is due to
            the constraint of ProcessGroup API that all ranks have to call the PG creation API
            even if the rank is not in that PG.
            We will create a potentially very large number of DeviceMesh objects
            (e.g., on 1024 GPUs with TP=2, this could be up to 512 DeviceMeshes), only to throw
            them all away except when the mesh contains the current rank.

            #TODO: Further refactor this method once we relax the ProcessGroup API constraint.

            Args:
                device_type: The device type of the mesh.
                pg_ranks_by_dim: all ranks within the worlds organized by dimensions.
                cur_rank: The current global rank in the mesh.
                mesh_dim_names: Mesh dimension names.
                backend_override: Optional backend override for the mesh.
                _init_backend: Whether to initialize the backend of the mesh.
                _layout: Optional layout for the mesh.

            Returns:
                The DeviceMesh containing the current rank.
            """
            res_mesh = None
            for mesh_nd in pg_ranks_by_dim:
                mesh = DeviceMesh(
                    device_type,
                    mesh_nd,
                    mesh_dim_names=mesh_dim_names,
                    backend_override=backend_override,
                    _init_backend=_init_backend,
                    _layout=_layout,
                )
                if cur_rank in mesh_nd:
                    res_mesh = mesh
            if res_mesh is None:
                raise RuntimeError(
                    f"Current rank {cur_rank} not found in any mesh, "
                    f"input {pg_ranks_by_dim} does not contain all ranks in the world"
                )
            if _root_mesh is not None:
                res_mesh._root_mesh = _root_mesh
            return res_mesh

        @staticmethod
        def from_group(
            group: Union[ProcessGroup, list[ProcessGroup]],
@@ -1004,15 +964,17 @@ else:
            return device_mesh

        def size(self, mesh_dim: Optional[int] = None) -> int:
            return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
            if mesh_dim is not None:
                return self._layout[mesh_dim].numel()
            return self._layout.numel()

        @property
        def ndim(self) -> int:
            return self.mesh.ndim
            return len(self._layout)

        @property
        def shape(self) -> tuple[int, ...]:
            return tuple(self.mesh.shape)
            return self._layout.top_level_sizes

        def get_rank(self) -> int:
            """
@@ -1051,7 +1013,7 @@ else:
            """
            if self.ndim > 1 and mesh_dim is None:
                raise RuntimeError(
                    f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
                    f"Found the DeviceMesh have {len(self._layout)} dimensions",
                    "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
                )
            elif mesh_dim is None:
@@ -1112,22 +1074,28 @@ else:
                tuple[Optional[str], Optional[C10dBackend.Options]], ...
            ] = ((None, None),),
        ) -> "DeviceMesh":
            root_mesh = self._get_root_mesh()
            cur_rank = self.get_rank()
            unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
            pg_ranks_by_dim = unflattened_layout.remap_to_tensor(
                root_mesh.mesh,
            )
            inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes))

            if inner_layout.numel() != self._layout[dim].numel():
                raise ValueError(
                    f"The product of {mesh_sizes=} is {inner_layout.numel()}, "
                    f"but the original dimension at dim={dim} has size {self._layout[dim].numel()}. "
                    f"These must be equal for unflatten to work correctly."
                )

            partial_layout = self._layout[dim].composition(inner_layout)
            unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout)
            unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
            unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
            res_mesh = DeviceMesh._create_mesh_from_ranks(

            root_mesh = self._get_root_mesh()
            res_mesh = DeviceMesh(
                self.device_type,
                pg_ranks_by_dim,
                cur_rank,
                tuple(unflattened_mesh_dim_names),
                _init_backend=False,
                _layout=unflattened_layout,
                _rank_map=root_mesh._rank_map,
                mesh_dim_names=tuple(unflattened_mesh_dim_names),
                _root_mesh=root_mesh,
                _init_backend=False,
            )

            # If original mesh has initiated its backend, we need to initialize the backend
@@ -1135,33 +1103,13 @@ else:
            # TODO: To make backend init more efficient with cute layout representation and support
            # per dim backend init.
            if hasattr(self, "_dim_group_names"):
                unflatten_length = len(mesh_sizes)
                unflatten_layout = _MeshLayout(
                    tuple(unflattened_layout.sizes[dim : dim + unflatten_length]),  # type: ignore[index]
                    tuple(unflattened_layout.strides[dim : dim + unflatten_length]),  # type: ignore[index]
                )
                unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
                    root_mesh.mesh,
                )
                unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
                    self.device_type,
                    unflatten_pg_ranks_by_dim,
                    cur_rank,
                dim_group_names = self._dim_group_names.copy()
                dim_group_names[dim : dim + 1] = self._init_process_groups(
                    partial_layout,
                    root_mesh._rank_map,
                    mesh_dim_names,
                    backend_override=backend_override,
                    backend_override,
                )
                dim_group_names = []
                for idx in range(0, res_mesh.ndim):
                    if idx < dim:
                        dim_group_names.append(self._dim_group_names[idx])
                    elif idx >= dim + unflatten_length:
                        dim_group_names.append(
                            self._dim_group_names[idx - unflatten_length + 1]
                        )
                    else:
                        dim_group_names.append(
                            unflatten_submesh._dim_group_names[idx - dim]
                        )
                res_mesh._dim_group_names = dim_group_names

            return res_mesh
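A small arithmetic sketch (plain-Python stand-ins, not the private _MeshLayout API) of the check the unflatten path above performs: the product of `mesh_sizes` must equal the size of the dimension being unflattened, and `suffix_product` is assumed here to yield the row-major strides of the new inner dims.

    import math

    def suffix_product(sizes):
        # exclusive suffix product: the row-major stride of each dim (assumed semantics)
        strides, acc = [], 1
        for s in reversed(sizes):
            strides.append(acc)
            acc *= s
        return tuple(reversed(strides))

    mesh_sizes = (2, 4)                      # unflatten one dim of size 8 into two dims
    assert suffix_product(mesh_sizes) == (4, 1)
    assert math.prod(mesh_sizes) == 8        # must equal self._layout[dim].numel()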
@@ -1349,13 +1297,15 @@ else:
            "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
        )

    # Always initialize the mesh's tensor on CPU, regardless of what the
    layout = _MeshLayout(tuple(mesh_shape), suffix_product(mesh_shape))
    # Always initialize the (identity) rank map on CPU, regardless of what the
    # external device type has been set to be (e.g. meta)
    with torch.device("cpu"):
        mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
        rank_map = torch.arange(layout.numel(), dtype=torch.int)
    device_mesh = DeviceMesh(
        device_type=device_type,
        mesh=mesh,
        _layout=layout,
        _rank_map=rank_map,
        mesh_dim_names=mesh_dim_names,
        backend_override=backend_override_tuple,
    )
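Sketch (not part of the diff) of what the new init path boils down to for `mesh_shape=(2, 4)`: the rank map is simply `arange(world_size)`, and viewing it through the row-major layout should recover the familiar mesh tensor of the root mesh.

    import torch

    mesh_shape = (2, 4)
    rank_map = torch.arange(torch.Size(mesh_shape).numel(), dtype=torch.int)
    mesh_tensor = rank_map.view(mesh_shape)   # expected value of the `mesh` property on the root mesh
    print(mesh_tensor.tolist())               # [[0, 1, 2, 3], [4, 5, 6, 7]]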
@@ -90,6 +90,7 @@ class DTensorSpec:
        if not isinstance(self.placements, tuple):
            self.placements = tuple(self.placements)
        if self.shard_order is None:
            # pyrefly: ignore # bad-assignment
            self.shard_order = DTensorSpec.compute_default_shard_order(self.placements)
        self._hash: int | None = None

@@ -701,6 +701,7 @@ def _restore_state_dict(
    for name, _ in list(
        chain(
            original_module.named_parameters(remove_duplicate=False),
            # pyrefly: ignore # bad-argument-type
            original_module.named_buffers(remove_duplicate=False),
        )
    ):
@@ -18,6 +18,7 @@ log = logging.getLogger(__name__)

__all__ = [
    "annotate",
    "annotate_fn",
    "preserve_node_meta",
    "has_preserved_node_meta",
    "set_stack_trace",
@@ -266,9 +267,10 @@ def annotate(annotation_dict: dict):
        into the FX trace metadata.

    Example:
        After exiting the context, custom annotations are removed.

        >>> with annotate({"source": "custom_pass", "tag": 42}):
        ...     # compute here
        # After exiting the context, custom annotations are removed.
        ...     pass  # Your computation here
    """

    global current_meta
@@ -291,6 +293,43 @@ def annotate(annotation_dict: dict):
            del current_meta["custom"]


@compatibility(is_backward_compatible=False)
def annotate_fn(annotation_dict: dict):
    """
    A decorator that wraps a function with the annotate context manager.
    Use this when you want to annotate an entire function instead of a specific code block.

    Note:
        This API is **not backward compatible** and may evolve in future releases.

    Note:
        This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
        to be used with PT2 family of tracers, e.g. torch.export and dynamo.

    Args:
        annotation_dict (dict): A dictionary of custom key-value pairs to inject
            into the FX trace metadata for all operations in the function.

    Example:
        All operations in my_function will have {"pp_stage": 1} in their metadata.

        >>> @annotate_fn({"pp_stage": 1})
        ... def my_function(x):
        ...     return x + 1
    """
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            with annotate(annotation_dict):
                return func(*args, **kwargs)

        return wrapper

    return decorator


@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    global current_meta
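Usage sketch for the new annotate_fn decorator (the helper names below are made up): decorating a function is equivalent to wrapping its body in `with annotate(...)`, so every operation traced inside it picks up the custom metadata.

    from torch.fx.traceback import annotate, annotate_fn

    @annotate_fn({"pp_stage": 0})
    def my_helper(x):
        return x + 1

    # Equivalent to:
    def my_helper_explicit(x):
        with annotate({"pp_stage": 0}):
            return x + 1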
@@ -218,6 +218,7 @@ class FlexKernelOptions(TypedDict, total=False):
    waves_per_eu: NotRequired[int]
    """ROCm-specific waves per execution unit."""

    # pyrefly: ignore # invalid-annotation
    force_flash: NotRequired[bool]
    """ If True, forces use of the cute-dsl flash attention kernel.

@@ -1,5 +1,5 @@
from . import parametrizations, parametrize, rnn, stateless
from .clip_grad import (
from .clip_grad import (  # pyrefly: ignore # deprecated
    _clip_grads_with_norm_ as clip_grads_with_norm_,
    _get_total_norm as get_total_norm,
    clip_grad_norm,
@@ -283,6 +283,7 @@ def clip_grad_value_(
    clip_value = float(clip_value)

    grads = [p.grad for p in parameters if p.grad is not None]
    # pyrefly: ignore # bad-argument-type
    grouped_grads = _group_tensors_by_device_and_dtype([grads])

    for (device, _), ([grads], _) in grouped_grads.items():
@@ -111,8 +111,10 @@ class _Orthogonal(Module):
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

        if hasattr(self, "base"):
            # pyrefly: ignore # unbound-name
            Q = self.base @ Q
        if transposed:
            # pyrefly: ignore # unbound-name
            Q = Q.mT
        return Q  # type: ignore[possibly-undefined]
@@ -170,6 +170,7 @@ class TorchTensor(ir.Tensor):

        if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
            raise TypeError(
                # pyrefly: ignore # missing-attribute
                f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
                "with a tensor backed by real data using ONNXProgram.apply_weights() "
                "or save the model without initializers by setting include_initializers=False."
@@ -297,6 +297,7 @@ class AveragedModel(Module):
            avg_fn = get_swa_avg_fn()
            n_averaged = self.n_averaged.to(device)
            for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
                # pyrefly: ignore # missing-attribute
                p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
        else:
            for p_averaged, p_model in zip(  # type: ignore[assignment]
@@ -71,6 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                nrows // 16, 16
            )
        ).view(-1)
        # pyrefly: ignore # unbound-name
        outp = outp.index_copy(1, cols_permuted, outp)

    # interleave_column_major_tensor
@@ -67,6 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    # Because we cannot go from the compressed representation back to the dense representation currently,
    # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
    # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
    # pyrefly: ignore # no-matching-overload
    return self.__class__(
        torch.Size([self.shape[-1], self.shape[0]]),
        packed=self.packed_t,
@@ -184,6 +184,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
        outer_stride,
    ) -> torch.Tensor:
        shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
        # pyrefly: ignore # no-matching-overload
        return cls(
            shape=shape,
            packed=inner_tensors.get("packed", None),
@@ -413,6 +414,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
            sparse_tensor_cutlass,
            meta_tensor_cutlass,
        ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
        # pyrefly: ignore # no-matching-overload
        return cls(
            original_tensor.shape,
            packed=sparse_tensor_cutlass,
@@ -499,6 +501,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
            original_tensor, algorithm=algorithm, use_cutlass=True
        )

        # pyrefly: ignore # no-matching-overload
        return cls(
            original_tensor.shape,
            packed=packed,
@@ -560,6 +563,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
        cls, original_tensor: torch.Tensor
    ) -> "SparseSemiStructuredTensorCUSPARSELT":
        cls._validate_device_dim_dtype_shape(original_tensor)
        # pyrefly: ignore # no-matching-overload
        return cls(
            shape=original_tensor.shape,
            packed=torch._cslt_compress(original_tensor),
@@ -626,6 +630,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
        packed = packed.view(original_tensor.shape[0], -1)
        packed_t = packed_t.view(original_tensor.shape[1], -1)

        # pyrefly: ignore # no-matching-overload
        return cls(
            original_tensor.shape,
            packed=packed,
@@ -1336,6 +1336,7 @@ class Identity(sympy.Function):

    def _sympystr(self, printer):
        """Controls how sympy's StrPrinter prints this"""
        # pyrefly: ignore # missing-attribute
        return f"({printer.doprint(self.args[0])})"

    def _eval_is_real(self):