Compare commits

...

18 Commits

Author SHA1 Message Date
98826fd37b [annotate] add annotate_fn function decorator 2025-10-17 09:23:55 -07:00
585b9dbb5e [async_tp] Support ag+mm with gather_dim lastdim of mat_A (#163068)
Adding ag+mm support for the case when gather_dim is the last dim of the matmul (the reduction dim).

When we decompose the matmul by the reduction dimension, the result is partial products that need an additional reduction, so we allocate memory for an accumulator.

Decomposition should not produce small (thin) matmuls that cannot efficiently load the GPU, so the minimal shard size is limited to 1024 (found empirically by testing in torchtitan).

scaled_mm is not supported yet for this case.
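
A minimal sketch (not the PR's actual implementation) of what decomposing a matmul along the reduction (K) dimension looks like: each shard of A contributes a partial product that is summed into a preallocated accumulator, which is why extra accumulator memory is needed in the ag+mm case described above. The function and parameter names here are illustrative only.

```python
import torch

def decomposed_mm(a_shards, b, min_shard_k=1024):
    # a_shards: list of [M, K_i] slices of A along K; b: [K, N]
    assert all(s.shape[1] >= min_shard_k for s in a_shards), "shard too thin to load the GPU efficiently"
    m, n = a_shards[0].shape[0], b.shape[1]
    acc = torch.zeros(m, n, dtype=b.dtype)  # accumulator for the partial results
    k_off = 0
    for shard in a_shards:
        k = shard.shape[1]
        acc += shard @ b[k_off:k_off + k]  # partial product over this K slice
        k_off += k
    return acc

a, b = torch.randn(64, 4096), torch.randn(4096, 32)
out = decomposed_mm(list(a.split(1024, dim=1)), b)
torch.testing.assert_close(out, a @ b, rtol=1e-4, atol=1e-4)
```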

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163068
Approved by: https://github.com/ngimel
2025-10-16 20:14:39 +00:00
d795fb225a [RFC] Add pyrefly to lintrunner (#165179)
This adds pyrefly to lintrunner as a warning only, allowing us to collect feedback about the tool before switching to pyrefly as the main type checker.

References the steps outlined here: https://github.com/pytorch/pytorch/issues/163283

test plan:
`lintrunner init`
`lintrunner`
Confirm that when pyrefly errors are present, results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
2025-10-16 20:07:09 +00:00
7df9aca529 [ROCm][Windows] Enable AOTriton runtime compile on Windows (#165538)
AOTriton uses prebuilt runtime binaries if the user's ROCm version matches the ones used to generate the prebuilt runtime. However, since there is no prebuilt runtime available for Windows, this check needs to be bypassed there. This PR does so by changing the condition to always build the AOTriton runtime from source on Windows.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165538
Approved by: https://github.com/xinyazhang, https://github.com/jeffdaily
2025-10-16 19:51:43 +00:00
d4a713cd9c Change forkserver test to only run below 3.13.8 (#165667)
A multiprocessing bug is fixed in 3.13.8, see [https://docs.python.org/3.13/whatsnew/changelog.html](https://docs.python.org/3.13/whatsnew/changelog.html), [gh-126631](https://github.com/python/cpython/issues/126631)

So this test will fail when we update to Python 3.13.8.
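
As a hypothetical sketch of the kind of version gate described here (the actual test name and decorator placement in the PR may differ):

```python
import sys
import unittest

# Hypothetical sketch: skip once the interpreter includes the upstream fix.
@unittest.skipIf(
    sys.version_info >= (3, 13, 8),
    "forkserver bug fixed upstream in 3.13.8 (gh-126631); only run below that version",
)
class ForkserverRegressionTest(unittest.TestCase):
    def test_forkserver_behavior(self):
        self.assertTrue(True)  # placeholder body

if __name__ == "__main__":
    unittest.main()
```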
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165667
Approved by: https://github.com/malfet
2025-10-16 19:34:10 +00:00
5daef30b26 158232 Fix autocast cache incorrectly retaining no_grad state (#165068)
Fixes #158232
The autocast caching heuristic in `aten/src/ATen/autocast_mode.cpp:139` did not account for gradient mode state when deciding whether to cache. FSDP2 is not directly related.

~~This PR adds a `GradMode::is_enabled()` check to the caching condition. Caching is now disabled in `no_grad()` contexts to prevent storing tensors with incorrect gradient state. Ensures correctness at the cost of not using the cache.~~
This PR proposes separate caches for gradient-enabled and gradient-disabled modes.
Adds tests.
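
For context, a minimal repro of the failure pattern, mirroring the regression test added later in this diff:

```python
import torch

# A no_grad() forward inside autocast must not poison the cast cache
# for the later grad-enabled forward.
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
    with torch.no_grad():
        model(inp)      # casts cached here must not be reused below
    out = model(inp)    # grad-enabled pass; previously picked up the no-grad cache
out.mean().backward()   # failed before the fix: "does not require grad"
```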

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165068
Approved by: https://github.com/ngimel, https://github.com/janeyx99
2025-10-16 19:32:01 +00:00
6dedd34c31 [CD] Skip 12.9 build on Windows (#165665)
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165665
Approved by: https://github.com/Camyll, https://github.com/malfet
2025-10-16 19:11:27 +00:00
a303d6dda9 [inductor] don't try to reorder loops for template (#165601)
fix https://github.com/pytorch/pytorch/issues/165579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165601
Approved by: https://github.com/yushangdi
2025-10-16 19:05:21 +00:00
7669ac9402 [ROCm] Add scaled_mm v2 support. (#165528)
Add mx fp4 support in Blas.cpp.
Updated the scale_kernel_dispatch array and ScaledGemmImplementation enum to include MXFP4 support.
Modify the tests under test_scaled_matmul_cuda accordingly.

PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise
115 tests passed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165528
Approved by: https://github.com/jeffdaily
2025-10-16 18:36:41 +00:00
86fd4fc23e [DeviceMesh] Simplify unflatten method (#165556)
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method and improve readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
2025-10-16 18:36:16 +00:00
99097b6d89 [DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)
The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, a public API that contains some "legacy" concepts we'd love to get rid of, such as an explicit/materialized `mesh` Tensor.

In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it.

This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`.

With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary.
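
As an illustration only (the class and attribute names below echo the PR description but are hypothetical, not the real DeviceMesh code), one way to approximate a private constructor overload in Python is an alternate classmethod that bypasses `__init__`:

```python
class DeviceMeshLike:
    """Toy stand-in to illustrate the constructor pattern."""

    def __init__(self, device_type, mesh, mesh_dim_names=None):
        # Public path: takes an explicit, materialized mesh tensor.
        self.device_type = device_type
        self.mesh = mesh
        self.mesh_dim_names = mesh_dim_names

    @classmethod
    def _from_layout(cls, layout, global_rank_permutation, mesh_dim_names=None):
        # "Private overload": build the instance via __new__ and populate the
        # internal fields directly, skipping the public __init__.
        self = object.__new__(cls)
        self._layout = layout
        self._global_rank_permutation = global_rank_permutation
        self.mesh_dim_names = mesh_dim_names
        return self

m = DeviceMeshLike._from_layout(layout=((2, 2), (2, 1)), global_rank_permutation=[0, 1, 2, 3])
```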

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
2025-10-16 18:36:16 +00:00
eqy
a214371008 [FP8] Add other Blackwell compute-capabilities to expected fail test_honor_sm_carveout (#165159)
The CUTLASS SM hint also isn't working for other Blackwells; green context is needed for carveout.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165159
Approved by: https://github.com/Skylion007
2025-10-16 18:35:06 +00:00
7d87d7052e [inductor][bucketing] Fx collectives bucketing of multiple dtypes (#162470)
Bucketing of multiple dtypes so they can be processed in one bucketed collective.

The first target is to bucket bf16 and f32, but it can already be used with other dtypes.

For now, multi-dtype bucketing is only supported in "custom_ops" mode; non-custom_ops mode needs additional work on the Inductor side.
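
For reference, the test added later in this diff enables the new mode via an Inductor config patch; a condensed, standalone form looks like the following (the `func` body here is a placeholder; in the real test it issues `all_gather_into_tensor` on bf16 and f32 tensors):

```python
import torch

def func(x):
    # Placeholder; the real test issues all_gather_into_tensor on bf16 and f32
    # tensors so the bucketing pass has collectives to merge.
    return x + 1

# Enabling the multi-dtype bucketing mode, mirroring the test's config patch.
with torch._inductor.config.patch(
    {
        "bucket_all_gathers_fx": "all_custom_ops_multidtype",
        "reorder_for_compute_comm_overlap": False,
    }
):
    compiled = torch.compile(func)
    out = compiled(torch.ones(4))
```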

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162470
Approved by: https://github.com/eellison
2025-10-16 18:31:43 +00:00
1a34ff4e04 Fixing get_local_rank() variable missing when compiled (#165432)
Fixes #165215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165432
Approved by: https://github.com/bdhirsh
2025-10-16 18:20:34 +00:00
fe5ccb1a74 bf16 support for per tensor backward (#165362)
Adding bf16 for the backward pass of `torch._fake_quantize_learnable_per_tensor_affine()`.

Note that for testing, we modified the seed to avoid increasing tolerance due to cases where differences in Python vs. C++ downcasting cause tensor mismatches (e.g. 27.87704 vs 27.8408 before downcasting, 27.7500 vs 27.8750 after downcasting for the Python vs. C++ op).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165362
Approved by: https://github.com/andrewor14
2025-10-16 17:47:01 +00:00
85586d7efc Make c7i the default for _linux-build.yml (#164747)
Use linux.c7i.2xlarge as the default runner for the _linux-build.yml workflow. In testing we found that switching from c5 to c7i gives 15-20% faster build times despite c7i costing 5% more. This should reduce the costs of jobs using _linux-build.yml.

Relates to pytorch/test-infra#7175.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164747
Approved by: https://github.com/atalman
2025-10-16 17:37:51 +00:00
e1d71a6b35 Revert "12/n : Remove fbandroid_compiler_flags (#165558)"
This reverts commit d7ffa8b8a29ba6071c51499c1df3d702d0a26f72.

Reverted https://github.com/pytorch/pytorch/pull/165558 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/165558#issuecomment-3411879769))
2025-10-16 17:18:56 +00:00
d61a9b88cf [DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
2025-10-16 17:01:44 +00:00
60 changed files with 1385 additions and 2546 deletions

View File

@ -241,7 +241,11 @@ def generate_libtorch_matrix(
arches += CUDA_ARCHES
arches += ROCM_ARCHES
elif os == "windows":
arches += CUDA_ARCHES
# TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
# in 2.10
windows_cuda_arches = CUDA_ARCHES.copy()
windows_cuda_arches.remove("12.9")
arches += windows_cuda_arches
if libtorch_variants is None:
libtorch_variants = [
"shared-with-deps",
@ -305,7 +309,11 @@ def generate_wheels_matrix(
if os == "linux":
arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES
elif os == "windows":
arches += CUDA_ARCHES + XPU_ARCHES
# TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
# in 2.10
windows_cuda_arches = CUDA_ARCHES.copy()
windows_cuda_arches.remove("12.9")
arches += windows_cuda_arches + XPU_ARCHES
elif os == "linux-aarch64":
# Separate new if as the CPU type is different and
# uses different build/test scripts

View File

@ -37,7 +37,7 @@ on:
runner:
required: false
type: string
default: "linux.2xlarge"
default: "linux.c7i.2xlarge"
description: |
Label of the runner this job should run on.
test-matrix:

View File

@ -788,256 +788,6 @@ jobs:
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda12_9-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v4.4.0
if: always()
with:
name: libtorch-cuda12_9-shared-with-deps-debug
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_9-shared-with-deps-debug-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cuda12_9-shared-with-deps-debug-build
- get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-cuda12_9-shared-with-deps-debug
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_9-shared-with-deps-debug-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-cuda12_9-shared-with-deps-debug-test
with:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
build_name: libtorch-cuda12_9-shared-with-deps-debug
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda13_0-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type

View File

@ -788,256 +788,6 @@ jobs:
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda12_9-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v4.4.0
if: always()
with:
name: libtorch-cuda12_9-shared-with-deps-release
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_9-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cuda12_9-shared-with-deps-release-build
- get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-cuda12_9-shared-with-deps-release
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-cuda12_9-shared-with-deps-release-test
with:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
build_name: libtorch-cuda12_9-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda13_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type

File diff suppressed because it is too large

View File

@ -118,9 +118,9 @@ jobs:
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
echo "Running all other linters"
if [ "$CHANGED_FILES" = '*' ]; then
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
else
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
fi
quick-checks:

View File

@ -209,6 +209,46 @@ command = [
'@{{PATHSFILE}}'
]
[[linter]]
code = 'PYREFLY'
include_patterns = [
'torch/**/*.py',
'torch/**/*.pyi',
'torchgen/**/*.py',
'torchgen/**/*.pyi',
'functorch/**/*.py',
'functorch/**/*.pyi',
]
exclude_patterns = []
command = [
'python3',
'tools/linter/adapters/pyrefly_linter.py',
'--config=pyrefly.toml',
]
init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'pyrefly==0.36.2',
'sympy==1.13.3',
'types-requests==2.27.25',
'types-pyyaml==6.0.2',
'types-tabulate==0.8.8',
'types-protobuf==5.29.1.20250403',
'types-setuptools==79.0.0.20250422',
'types-jinja2==2.11.9',
'types-colorama==0.4.6',
'filelock==3.18.0',
'junitparser==2.1.1',
'rich==14.1.0',
'optree==0.17.0',
'types-openpyxl==3.1.5.20250919',
'types-python-dateutil==2.9.0.20251008'
]
[[linter]]
code = 'CLANGTIDY'
include_patterns = [

View File

@ -2,6 +2,7 @@
#include <mutex>
#include <ATen/CachedTensorUtils.h>
#include <c10/core/GradMode.h>
#include <c10/util/flat_hash_map.h>
namespace at::autocast {
@ -36,10 +37,29 @@ namespace {
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;
ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
return cached_casts;
// We maintain separate caches for gradient-enabled and gradient-disabled modes.
// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
// are not incorrectly reused in gradient-enabled contexts.
// This fixes issue #158232 while maintaining optimal performance for both modes.
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
return cached_casts_grad_enabled;
}
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
return cached_casts_grad_disabled;
}
// Helper function to get the appropriate cache based on current gradient mode.
// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
// preventing incorrect cache hits when gradient mode changes.
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
return at::GradMode::is_enabled() ?
get_cached_casts_grad_enabled() :
get_cached_casts_grad_disabled();
}
std::mutex cached_casts_mutex;
@ -86,7 +106,9 @@ thread_local bool cache_enabled = true;
void clear_cache() {
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
get_cached_casts().clear();
// Clear both caches to ensure consistent behavior regardless of current gradient mode
get_cached_casts_grad_enabled().clear();
get_cached_casts_grad_disabled().clear();
}
int increment_nesting() {
@ -121,6 +143,11 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
// Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
// See cached_casts declaration above for detailed strategy.
//
// We maintain separate caches for gradient-enabled and gradient-disabled modes
// (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
// with torch.autocast(), while maintaining optimal performance for both training and inference.
// This fixes issue #158232 without any performance regression.
bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
arg.scalar_type() == at::kFloat && arg.requires_grad() &&
arg.is_leaf() && !arg.is_view() && cache_enabled &&

View File

@ -1759,6 +1759,7 @@ enum class ScaledGemmImplementation {
MXFP8_MXFP8 = 6,
NVFP4_NVFP4 = 7,
NVFP4_NVFP4_SINGLE_SCALE = 8,
MXFP4_MXFP4 = 9,
};
/**
@ -1955,10 +1956,39 @@ bool check_mxfp8_recipe(c10::ScalarType type_a,
return true;
}
/**
* Both inputs must be fp4
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
*/
bool check_mxfp4_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 1 scales, 1 recipes for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
return true;
}
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
using namespace std::placeholders;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> scale_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> scale_kernel_dispatch = {{
{ "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
@ -1969,7 +1999,8 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
ScaledGemmImplementation::BLOCK_1x128_1x128},
{ "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "mxfp4_mxfp4", check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}};
Tensor&
_scaled_tensorwise_tensorwise(
@ -2187,15 +2218,22 @@ _scaled_mxfp8_mxfp8(
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
#ifdef USE_ROCM
auto scale_a_elems = ceil_div<int64_t>(mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(mat_b.size(1), 32) * mat_b.size(0);
#else
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0), 32), 4);
#endif
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifndef USE_ROCM
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
@ -2225,6 +2263,56 @@ _scaled_mxfp8_mxfp8(
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
Tensor&
_scaled_mxfp4_mxfp4(
const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const SwizzleType swizzle_a,
const Tensor& scale_b, const SwizzleType swizzle_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;
#if ROCM_VERSION >= 70000
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 &&
mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0,
"Matrix dimensions must be multiples of 32 for block-wise scaling");
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
Tensor&
_scaled_nvfp4_nvfp4(
const Tensor& mat_a, const Tensor& mat_b,
@ -2468,6 +2556,8 @@ _scaled_mm_cuda_v2_out(
TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
} else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
} else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
} else {
TORCH_CHECK_VALUE(false, "Invalid state - found an implementation, but not really");
}

View File

@ -184,15 +184,23 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
0 & \text{ else }
\end{cases}
*/
float scale_val = scale[0].item<float>();
float inv_scale_val = 1.0f / scale_val;
int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false);
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size");
bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;
float scale_val = scale_[0].item<float>();
float inv_scale_val = 1.0f / scale_val;
int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point_, quant_min, quant_max, false);
TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);
TORCH_CHECK(X_.numel() == dY_.numel(), "`X` and `dY` are not the same size");
TORCH_CHECK(
quant_min <= 0 && quant_max >= 0,
"`quant_min` should be less than or \
@ -200,28 +208,28 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
TORCH_CHECK(
zero_point_val >= quant_min && zero_point_val <= quant_max,
"`zero_point` must be between `quant_min` and `quant_max`.");
if (X.numel() <= 0) {
if (X_.numel() <= 0) {
return std::make_tuple(X, scale, zero_point);
}
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto iter = TensorIteratorConfig()
.add_output(dX)
.add_output(dScale_vec)
.add_output(dZeroPoint_vec)
.add_input(X)
.add_input(dY)
.add_input(X_)
.add_input(dY_)
.build();
fake_quant_grad_learnable_tensor_stub(
X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);
X_.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);
// The total sums over the scale and zero point gradient vectors are what will be returned in the end.
auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device());
auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device());
auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device());
auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device());
return std::make_tuple(dX, dScale, dZeroPoint);
}

View File

@ -1729,10 +1729,8 @@ def define_buck_targets(
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/backends/backend_interface.cpp",
],
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
}),
compiler_flags = get_pt_compiler_flags(),
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = get_no_as_needed_linker_flag(),
@ -2025,9 +2023,6 @@ def define_buck_targets(
"ovr_config//os:android-x86_64": [
"-mssse3",
],
}) + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
}),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [

View File

@ -244,7 +244,8 @@ if(NOT __AOTRITON_INCLUDED)
else()
set(__AOTRITON_SYSTEM_ROCM "${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}")
list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_SYSTEM_ROCM}" __AOTRITON_RUNTIME_INDEX)
if(${__AOTRITON_RUNTIME_INDEX} LESS 0)
# Always build aotriton runtime from source on Windows due to lack of pre-built binaries
if(${__AOTRITON_RUNTIME_INDEX} LESS 0 OR WIN32)
message(STATUS "Cannot find AOTriton runtime for ROCM ${__AOTRITON_SYSTEM_ROCM}. \
Build runtime from source")
aotriton_build_from_source(ON aotriton_runtime)

View File

@ -1,5 +1,7 @@
# A Pyrefly configuration for PyTorch
# Based on https://github.com/pytorch/pytorch/blob/main/mypy.ini
python-version = "3.12"
project-includes = [
"torch",
"caffe2",
@ -36,6 +38,7 @@ project-excludes = [
"torch/nn/modules/rnn.py", # only remove when parsing errors are fixed
"torch/_inductor/codecache.py",
"torch/distributed/elastic/metrics/__init__.py",
"torch/_inductor/fx_passes/bucketing.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

View File

@ -266,6 +266,41 @@ def forward(self, b_parametrizations_buffer_original0, x):
compiled_out = compiled_fn(mesh)
self.assertEqual(opt_fn, compiled_out)
def test_get_local_rank_compile(self):
mesh = init_device_mesh(
self.device_type, (self.world_size,), mesh_dim_names=("dp",)
)
def fn_with_str_arg(x):
local_rank = x.device_mesh.get_local_rank("dp")
return x * local_rank
x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
ref = fn_with_str_arg(x)
opt_fn = torch.compile(fn_with_str_arg, backend="aot_eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(res, ref)
def fn_with_int_arg(x):
local_rank = x.device_mesh.get_local_rank(0)
return x * local_rank
ref2 = fn_with_int_arg(x)
opt_fn2 = torch.compile(fn_with_int_arg, backend="aot_eager", fullgraph=True)
res2 = opt_fn2(x)
self.assertEqual(res2, ref2)
def fn_without_arg(x):
# will fail if device_mesh.ndim > 1
local_rank = x.device_mesh.get_local_rank()
return x + local_rank
ref3 = fn_without_arg(x)
opt_fn3 = torch.compile(fn_without_arg, backend="aot_eager", fullgraph=True)
res3 = opt_fn3(x)
self.assertEqual(res3, ref3)
def test_fakify_dtensor(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

View File

@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
def test_remap_to_tensor(self):
"""Test the remap_to_tensor method for various scenarios."""
# Test 1: Consecutive ranks, full world - should return logical groups directly
original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
result1 = layout1.remap_to_tensor(original_mesh)
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result1, expected1)
# Test 2: Non-consecutive ranks - should map to actual ranks
original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
layout2 = _Layout((2, 2), (2, 1))
result2 = layout2.remap_to_tensor(original_mesh)
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
self.assertEqual(result5, expected5)
# Test 6: Tensor Cute representation of a 2D mesh
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
layout6 = _Layout((2, 2), (1, 2)) # column-major style
result6 = layout6.remap_to_tensor(original_mesh)
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)

View File

@ -1804,6 +1804,63 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = f(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
@parametrize("bucket_mode", ["all_custom_ops_multidtype"])
def test_all_gather_bucket_multidtype(self, bucket_mode):
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
# do some unrelated matmuls
y = torch.mm(x, w)
group_name = (
torch.distributed.distributed_c10d._get_default_group().group_name
)
ag_0_w = torch.ops._c10d_functional.all_gather_into_tensor(
ag_0, group_size, group_name
)
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_w)
ag_0_out = ag_0_out * 2
ag_1_w = torch.ops._c10d_functional.all_gather_into_tensor(
ag_1, group_size, group_name
)
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_w)
return y, ag_0_out, ag_1_out
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.bfloat16)
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
inputs = [x, w, ag_0, ag_1]
correct = func(*inputs, **self.get_world_trs())
with torch._inductor.config.patch(
{
"bucket_all_gathers_fx": bucket_mode,
"reorder_for_compute_comm_overlap": False,
}
):
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
(
FileCheck()
.check_count(
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
count=1,
exactly=True,
)
.run(code)
)
out = compiled(*inputs, **self.get_world_trs())
_, y_ag0, y_ag1 = out
assert y_ag0.dtype == ag_0.dtype
assert y_ag1.dtype == ag_1.dtype
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
@parametrize("bucket_mode", ["all", "all_custom_ops"])

View File

@ -294,7 +294,7 @@ class AsyncTPTest(MultiProcContinuousTest):
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("gather_dim", [0, 1])
@parametrize("gather_dim", [0, 1, 2])
def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
self._init_process()
@ -306,7 +306,10 @@ class AsyncTPTest(MultiProcContinuousTest):
rank = self.rank
torch.manual_seed(42 + rank)
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
A_shard_shape = [BATCH, M, K]
A_shard_shape[gather_dim] //= self.world_size
A_shard = torch.rand(A_shard_shape, device="cuda")
Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]
ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
@ -523,7 +526,7 @@ class AsyncTPTest(MultiProcContinuousTest):
BATCH = 8
M = 64
N = 16
K = 32
K = 1024
group = dist.group.WORLD
rank = self.rank

View File

@ -57,7 +57,7 @@ def graph_capture(model, inputs, with_export):
with ExitStack() as stack:
joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
model,
gm,
inputs,
)
return joint_with_descriptors.graph_module
@ -922,6 +922,46 @@ class inner_f(torch.nn.Module):
in custom_metadata
)
def test_preserve_annotate_function(self):
"""Test basic annotate_fn usage"""
@fx_traceback.annotate_fn({"pp_stage": 1})
def example_function(x):
return x * x
class SimpleLinear(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
y = self.linear(x)
y = example_function(y)
return y - 1
inputs = (torch.randn(4, 3),)
model = SimpleLinear()
for with_export in [True, False]:
graph_module = graph_capture(model, inputs, with_export)
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 't', {'pp_stage': 0})
('call_function', 'addmm', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 1})
('call_function', 'mul_1', {'pp_stage': 1})
('call_function', 'mul_2', {'pp_stage': 1})
('call_function', 't_1', {'pp_stage': 0})
('call_function', 'mm', {'pp_stage': 0})
('call_function', 't_2', {'pp_stage': 0})
('call_function', 'sum_1', {'pp_stage': 0})
('call_function', 'view', {'pp_stage': 0})
('call_function', 't_3', {'pp_stage': 0})""",
)
if __name__ == "__main__":
run_tests()

View File

@ -589,6 +589,31 @@ class LoopOrderingTest(TestCase):
".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
).run(code[0])
@inductor_config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
"test_configs.max_mm_configs": 4,
}
)
@skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune")
def test_interaction_with_multi_template(self):
"""
Skip MultiTemplateBuffer during loop reordering
"""
@torch.compile
def f(x, y):
return (x @ y), x + 1
N = 2
x = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
y = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
out, code = run_and_get_code(f, x, y)
# didn't fuse due to small savings
FileCheck().check_count("@triton.jit", 2, exactly=True).run(code[0])
def test_fuse_with_scalar_shared_memory(self):
"""
Make sure if we can fuse two nodes sharing a scalar before,

View File

@ -51,11 +51,18 @@ def _fake_quantize_per_tensor_affine_grad_reference(dY, X, scale, zero_point, qu
return res.to(dtype)
# Reference method for the gradients of the fake quantize operator
def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max, device):
def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max, device, dtype):
r"""This method references the following literatures for back propagation on scale and zero point.
- https://arxiv.org/pdf/1902.08153.pdf
- https://arxiv.org/pdf/1903.08066.pdf
"""
if dtype is torch.bfloat16:
dY = dY.to(dtype=torch.float32)
X = X.to(dtype=torch.float32)
scale = scale.to(dtype=torch.float32)
zero_point = zero_point.to(dtype=torch.float32)
zero_point_rounded = int((zero_point + 0.5).clamp(quant_min, quant_max).item())
Xq = torch.round(X * (1.0 / scale) + zero_point_rounded)
@ -87,6 +94,12 @@ def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero
grad_scale = (grad_scale * dY).sum().unsqueeze(dim=0)
grad_zp = (grad_zp * dY).sum().unsqueeze(dim=0)
if dtype is torch.bfloat16:
grad_X = grad_X.to(torch.bfloat16)
grad_scale = grad_scale.to(torch.bfloat16)
grad_zp = grad_zp.to(torch.bfloat16)
return grad_X, grad_scale, grad_zp
@ -467,7 +480,7 @@ class TestFakeQuantizeOps(TestCase):
self._test_learnable_forward_per_tensor(
X, 'cuda', scale_base, zero_point_base)
def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base):
def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base, dtype=torch.float32):
r"""Tests the backward method with additional backprop support for scale and zero point.
"""
X_base = torch.tensor(X).to(device)
@ -475,7 +488,7 @@ class TestFakeQuantizeOps(TestCase):
for n_bits in (4, 8):
quant_min, quant_max = 0, 2 ** n_bits - 1
X = X_base.clone().float().to(device)
X = X_base.clone().to(device)
X.requires_grad_()
scale_base = scale_base.to(device)
zero_point_base = zero_point_base.to(device)
@ -488,7 +501,7 @@ class TestFakeQuantizeOps(TestCase):
X, scale, zero_point, quant_min, quant_max, grad_factor).to(device)
dout = torch.rand_like(X, dtype=torch.float).to(device)
dX, dScale, dZeroPoint = _fake_quantize_learnable_per_tensor_affine_grad_reference(
dout, X, scale, zero_point, quant_min, quant_max, device)
dout, X, scale, zero_point, quant_min, quant_max, device, dtype)
Y_prime.backward(dout)
expected_dX = dX.to(device).detach()
@ -525,17 +538,20 @@ class TestFakeQuantizeOps(TestCase):
self._test_learnable_backward_per_tensor(
X, 'cpu', scale_base, zero_point_base)
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
qparams=hu.qparams(dtypes=torch.quint8)))
@unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
def test_learnable_backward_per_tensor_cuda(self, X):
torch.random.manual_seed(NP_RANDOM_SEED)
X, (_, _, _) = X
scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
zero_point_base = torch.normal(mean=0, std=128, size=(1,))
self._test_learnable_backward_per_tensor(
X, 'cuda', scale_base, zero_point_base)
def test_learnable_backward_per_tensor_cuda(self):
# setting seed to avoid increasing tolerance due to cases where
# difference in Python vs CPP downcasting causes tensor mismatches
# e.g. 27.87704 vs 27.8408 before downcasting, 27.7500 vs 27.8750 after downcasting for Python vs CPP op
torch.random.manual_seed(12)
x_shape = (2, 1)
for dtype in [torch.bfloat16, torch.float32]:
X_base = torch.randn(x_shape, dtype=dtype, device='cuda')
scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100).to(dtype=dtype)
zero_point_base = torch.normal(mean=0, std=128, size=(1,)).to(dtype=dtype)
self._test_learnable_backward_per_tensor(
X_base, 'cuda', scale_base, zero_point_base, dtype)
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
X=hu.tensor(shapes=hu.array_shapes(1, 5,),

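The reference gradient above upcasts bfloat16 inputs to float32, does the math in full precision, and downcasts only the final gradients, which is also why the test pins a seed (near-tied values can round differently across implementations). A standalone sketch of that upcast/compute/downcast pattern, using a made-up elementwise gradient rather than the fake-quantize op itself:

```python
# Sketch of the upcast-compute-downcast pattern used by the bf16 reference path.
# `toy_grad` is a stand-in function, not the fake-quantize gradient itself.
import torch


def toy_grad(dY: torch.Tensor, X: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    dtype = X.dtype
    if dtype is torch.bfloat16:
        # do the arithmetic in fp32, matching what the kernel computes internally
        dY, X, scale = (t.to(torch.float32) for t in (dY, X, scale))
    grad = dY * (X / scale)
    if dtype is torch.bfloat16:
        grad = grad.to(torch.bfloat16)  # downcast only at the very end
    return grad


X = torch.randn(4, dtype=torch.bfloat16)
dY = torch.ones_like(X)
scale = torch.tensor([0.5], dtype=torch.bfloat16)
print(toy_grad(dY, X, scale))
```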
View File

@ -384,6 +384,143 @@ class TestTorchAutocast(TestCase):
with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
torch.autocast(device_type=dev)
@skipIfTorchDynamo()
def test_autocast_nograd_caching_issue_158232(self):
"""
Regression test for issue #158232: autocast + no_grad incompatibility
When torch.no_grad() is nested inside torch.autocast(), the autocast cache
must not cache tensors created in the no_grad context, because they lack
gradient tracking. If cached, subsequent operations in gradient-enabled mode
would incorrectly use the no-gradient cached version.
Before fix: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
After fix: Should work correctly
"""
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
# First forward pass in no_grad context (e.g., shape inference)
with torch.no_grad():
out1 = model(inp)
self.assertFalse(
out1.requires_grad, "Output in no_grad should not require grad"
)
# Second forward pass with gradients enabled (e.g., training)
out2 = model(inp)
self.assertTrue(
out2.requires_grad,
"Output should require gradients after exiting no_grad",
)
self.assertIsNotNone(
out2.grad_fn, "Output should have grad_fn after exiting no_grad"
)
# Backward pass should work
loss = out2.mean()
loss.backward()
# Verify gradients were computed
self.assertIsNotNone(model.weight.grad)
self.assertIsNotNone(model.bias.grad)
@skipIfTorchDynamo()
def test_autocast_inference_mode_interaction(self):
"""
Test that autocast works correctly with torch.inference_mode()
InferenceMode is a stricter version of no_grad that provides additional
performance optimizations. Verify it doesn't break with autocast.
"""
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
# Test 1: inference_mode inside autocast
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
with torch.inference_mode():
out1 = model(inp)
self.assertFalse(out1.requires_grad)
self.assertEqual(out1.dtype, torch.bfloat16)
# After exiting inference_mode, gradients should work
out2 = model(inp)
self.assertTrue(out2.requires_grad)
out2.mean().backward()
# Test 2: autocast inside inference_mode
with torch.inference_mode():
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
out = model(inp)
self.assertFalse(out.requires_grad)
self.assertEqual(out.dtype, torch.bfloat16)
def test_autocast_caching_still_works_with_gradients(self):
"""
Verify that autocast caching still functions correctly when gradients ARE enabled.
This test ensures the fix for #158232 didn't break normal caching behavior.
We can't directly observe cache hits, but we verify that repeated operations
with gradients enabled work correctly.
"""
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
# Multiple forward passes with gradients enabled
out1 = model(inp)
out2 = model(inp)
out3 = model(inp)
# All should have gradients
self.assertTrue(out1.requires_grad)
self.assertTrue(out2.requires_grad)
self.assertTrue(out3.requires_grad)
# All should have grad_fn
self.assertIsNotNone(out1.grad_fn)
self.assertIsNotNone(out2.grad_fn)
self.assertIsNotNone(out3.grad_fn)
# Backward should work on all
out1.mean().backward(retain_graph=True)
out2.mean().backward(retain_graph=True)
out3.mean().backward()
@skipIfTorchDynamo()
def test_autocast_mixed_grad_contexts(self):
"""
Test complex nesting of gradient contexts within autocast.
This ensures the gradient mode check works correctly across
multiple transitions between gradient-enabled and disabled states.
"""
model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)
with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
# Pass 1: no_grad
with torch.no_grad():
out1 = model(inp)
self.assertFalse(out1.requires_grad)
# Pass 2: gradients enabled
out2 = model(inp)
self.assertTrue(out2.requires_grad)
# Pass 3: no_grad again
with torch.no_grad():
out3 = model(inp)
self.assertFalse(out3.requires_grad)
# Pass 4: gradients enabled again
out4 = model(inp)
self.assertTrue(out4.requires_grad)
# Backward on gradient-enabled outputs
(out2.mean() + out4.mean()).backward()
if __name__ == "__main__":
run_tests()
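Per the PR description, the underlying fix keeps separate autocast cast caches for gradient-enabled and gradient-disabled modes. A toy Python sketch of that keying idea; the real cache lives in C++ (`aten/src/ATen/autocast_mode.cpp`) and is keyed differently:

```python
# Toy sketch of caching casted weights keyed by (tensor, dtype, grad mode),
# illustrating the idea behind the fix; not the actual autocast cache.
import torch

_cast_cache: dict[tuple[int, torch.dtype, bool], torch.Tensor] = {}


def cached_cast(w: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    key = (id(w), dtype, torch.is_grad_enabled())
    if key not in _cast_cache:
        _cast_cache[key] = w.to(dtype)  # cast recorded under the current grad mode
    return _cast_cache[key]


w = torch.randn(2, 2, requires_grad=True)
with torch.no_grad():
    a = cached_cast(w, torch.bfloat16)  # cached without a grad_fn
b = cached_cast(w, torch.bfloat16)      # separate entry: tracks gradients
print(a.requires_grad, b.requires_grad)  # False True
```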

View File

@ -265,6 +265,12 @@ class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
)
class ParallelForkServerPerfTest(TestCase):
@unittest.skipIf(
sys.version_info >= (3, 13, 8),
"Python 3.13.8+ changed forkserver module caching behavior",
# https://docs.python.org/3.13/whatsnew/changelog.html
# gh-126631
)
def test_forkserver_perf(self):
start_method = 'forkserver'

View File

@ -152,15 +152,34 @@ def infer_scale_swizzle(mat, scale):
):
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
# MX
# MXFP4 w/o swizzle
if (
scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
if not torch.version.hip:
# MXFP8 w/ swizzle
if (
scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
else:
# MXFP8 w/o swizzle
if (
scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
return None, None
@ -1211,7 +1230,7 @@ class TestFP8Matmul(TestCase):
self.assertEqual(no_carveout, no_carveout_again)
capability = torch.cuda.get_device_capability()
if capability == (10, 0):
if capability in {(10, 0), (10, 3), (12, 0), (12, 1)}:
# expected failure
# CUTLASS only supports SM carveout via green contexts on SM100
self.assertEqual(no_carveout, carveout_66)
@ -1489,7 +1508,7 @@ class TestFP8Matmul(TestCase):
assert sqnr.item() > approx_match_sqnr_target
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
@parametrize("recipe", ["mxfp8", "nvfp4"])
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
M, K, N = (1024, 512, 2048)
BLOCK_SIZE_K = 16 if recipe == "nvfp4" else 32
@ -1503,7 +1522,7 @@ class TestFP8Matmul(TestCase):
if recipe == "mxfp8":
x_lowp = x.to(e4m3_type)
y_lowp = y.to(e4m3_type).t()
else: # nvfp4
else: # nvfp4 #mxfp4
x_lowp = _bfloat16_to_float4_e2m1fn_x2(x.bfloat16())
y_lowp = _bfloat16_to_float4_e2m1fn_x2(y.bfloat16()).t()
@ -1517,7 +1536,10 @@ class TestFP8Matmul(TestCase):
if recipe == "nvfp4"
else ScalingType.BlockWise1x32
)
swizzle = SwizzleType.SWIZZLE_32_4_4
if torch.version.hip:
swizzle = SwizzleType.NO_SWIZZLE
else:
swizzle = SwizzleType.SWIZZLE_32_4_4
# Test wrong scale tensor size for scale_a with correct dtype
with self.assertRaisesRegex(

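The branches above tell swizzled from unswizzled block scales apart purely by element count. A small sketch of the two counts for a 1x32 blockwise recipe, assuming round-up-to-(128, 4) padding on the swizzled path and using true ceiling division (the code above writes `math.ceil(shape // 32)`, where `//` already floors):

```python
# Sketch: expected number of scale elements for 1x32 blockwise scaling of an
# (M, K) operand, with and without the 32x4x4 swizzle padding.
import math


def round_up(x: int, m: int) -> int:
    return math.ceil(x / m) * m


def scale_numel_unswizzled(M: int, K: int, block: int = 32) -> int:
    # one scale per 1x32 block along K
    return M * math.ceil(K / block)


def scale_numel_swizzled(M: int, K: int, block: int = 32) -> int:
    # rows padded to 128, K-blocks padded to a multiple of 4
    return round_up(M, 128) * round_up(math.ceil(K / block), 4)


M, K = 1024, 512
print(scale_numel_unswizzled(M, K))   # 1024 * 16 = 16384
print(scale_numel_swizzled(M, K))     # 1024 * 16 = 16384 (already aligned here)
print(scale_numel_swizzled(100, 40))  # 128 * 4 = 512 vs 100 * 2 = 200 unswizzled
```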
View File

@ -0,0 +1,258 @@
from __future__ import annotations
import argparse
import json
import logging
import os
import re
import subprocess
import sys
import time
from enum import Enum
from typing import NamedTuple
class LintSeverity(str, Enum):
ERROR = "error"
WARNING = "warning"
ADVICE = "advice"
DISABLED = "disabled"
class LintMessage(NamedTuple):
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: str | None
replacement: str | None
description: str | None
# Note: This regex pattern is kept for reference but not used for pyrefly JSON parsing
RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
(?P<line>\d+):
(?:(?P<column>-?\d+):)?
\s(?P<severity>\S+?):?
\s(?P<message>.*)
\s(?P<code>\[.*\])
$
"""
)
# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
(?P<line>\d+):
\s(?P<severity>\S+?):?
\s(?P<message>INTERNAL\sERROR.*)
$
"""
)
def run_command(
args: list[str],
*,
extra_env: dict[str, str] | None,
retries: int,
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
return subprocess.run(
args,
capture_output=True,
)
finally:
end_time = time.monotonic()
logging.debug("took %dms", (end_time - start_time) * 1000)
# Severity mapping (currently only used for stderr internal errors)
# Pyrefly JSON output doesn't include severity, so all errors default to ERROR
severities = {
"error": LintSeverity.ERROR,
"note": LintSeverity.ADVICE,
}
def check_pyrefly_installed(code: str) -> list[LintMessage]:
cmd = ["pyrefly", "--version"]
try:
subprocess.run(cmd, check=True, capture_output=True)
return []
except subprocess.CalledProcessError as e:
msg = e.stderr.decode(errors="replace")
return [
LintMessage(
path=None,
line=None,
char=None,
code=code,
severity=LintSeverity.ERROR,
name="command-failed",
original=None,
replacement=None,
description=f"Could not run '{' '.join(cmd)}': {msg}",
)
]
def in_github_actions() -> bool:
return bool(os.getenv("GITHUB_ACTIONS"))
def check_files(
code: str,
config: str,
) -> list[LintMessage]:
try:
pyrefly_commands = [
"pyrefly",
"check",
"--config",
config,
"--output-format=json",
]
proc = run_command(
[*pyrefly_commands],
extra_env={},
retries=0,
)
except OSError as err:
return [
LintMessage(
path=None,
line=None,
char=None,
code=code,
severity=LintSeverity.ERROR,
name="command-failed",
original=None,
replacement=None,
description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
)
]
stdout = str(proc.stdout, "utf-8").strip()
stderr = str(proc.stderr, "utf-8").strip()
if proc.returncode not in (0, 1):
return [
LintMessage(
path=None,
line=None,
char=None,
code=code,
severity=LintSeverity.ERROR,
name="command-failed",
original=None,
replacement=None,
description=stderr,
)
]
# Parse JSON output from pyrefly
try:
if stdout:
result = json.loads(stdout)
errors = result.get("errors", [])
else:
errors = []
# For now filter out deprecated warnings and only report type errors as warnings
# until we remove mypy
errors = [error for error in errors if error["name"] != "deprecated"]
rc = [
LintMessage(
path=error["path"],
name=error["name"],
description=error.get(
"description", error.get("concise_description", "")
),
line=error["line"],
char=error["column"],
code=code,
severity=LintSeverity.ADVICE,
# uncomment and replace when we switch to pyrefly
# severity=LintSeverity.ADVICE if error["name"] == "deprecated" else LintSeverity.ERROR,
original=None,
replacement=None,
)
for error in errors
]
except (json.JSONDecodeError, KeyError, TypeError) as e:
return [
LintMessage(
path=None,
line=None,
char=None,
code=code,
severity=LintSeverity.ERROR,
name="json-parse-error",
original=None,
replacement=None,
description=f"Failed to parse pyrefly JSON output: {e}",
)
]
# Still check stderr for internal errors
rc += [
LintMessage(
path=match["file"],
name="INTERNAL ERROR",
description=match["message"],
line=int(match["line"]),
char=None,
code=code,
severity=severities.get(match["severity"], LintSeverity.ERROR),
original=None,
replacement=None,
)
for match in INTERNAL_ERROR_RE.finditer(stderr)
]
return rc
def main() -> None:
parser = argparse.ArgumentParser(
description="pyrefly wrapper linter.",
fromfile_prefix_chars="@",
)
parser.add_argument(
"--code",
default="PYREFLY",
help="the code this lint should report as",
)
parser.add_argument(
"--verbose",
action="store_true",
help="verbose logging",
)
parser.add_argument(
"--config",
required=True,
help="path to the pyrefly config file",
)
args = parser.parse_args()
logging.basicConfig(
format="<%(threadName)s:%(levelname)s> %(message)s",
level=logging.INFO,
stream=sys.stderr,
)
lint_messages = check_pyrefly_installed(args.code) + check_files(
args.code, args.config
)
for lint_message in lint_messages:
print(json.dumps(lint_message._asdict()), flush=True)
if __name__ == "__main__":
main()
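For context, a hedged sketch of what a pyrefly JSON error record and the resulting lintrunner message line might look like; the keys mirror the fields the adapter reads above, but the concrete values are illustrative only.

```python
# Illustrative only: a hypothetical pyrefly error record and the LintMessage
# line the adapter above would print for it (all values are made up).
import json

pyrefly_error = {
    "path": "torch/foo.py",
    "line": 42,
    "column": 9,
    "name": "bad-argument-type",
    "description": "Argument `x` has type `str`, expected `int`",
}

lint_message = {
    "path": pyrefly_error["path"],
    "line": pyrefly_error["line"],
    "char": pyrefly_error["column"],
    "code": "PYREFLY",
    "severity": "advice",  # everything is reported as advice until pyrefly replaces mypy
    "name": pyrefly_error["name"],
    "original": None,
    "replacement": None,
    "description": pyrefly_error["description"],
}
print(json.dumps(lint_message))
```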

View File

@ -299,7 +299,11 @@ class DeviceMeshVariable(DistributedVariable):
if name == "get_rank":
return ConstantVariable.create(self.value.get_rank())
if name == "get_local_rank":
return ConstantVariable.create(self.value.get_local_rank())
const_args = [x.as_python_constant() for x in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
return ConstantVariable.create(
self.value.get_local_rank(*const_args, **const_kwargs)
)
if name == "get_group":
const_args = [x.as_python_constant() for x in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}

View File

@ -76,6 +76,7 @@ class StreamVariable(VariableTracker):
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
# pyrefly: ignore # read-only
self.device = device
def python_type(self) -> type:

View File

@ -1492,6 +1492,7 @@ def _aot_stage2a_partition(
# apply joint_gm callback here
if callable(torch._functorch.config.joint_custom_pass):
# pyrefly: ignore # bad-assignment
fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
static_lifetime_input_indices = fw_metadata.static_input_indices
@ -1761,6 +1762,7 @@ def _aot_stage2b_bw_compile(
# tensor which is wrong.
ph_size = ph_arg.size()
# pyrefly: ignore # bad-argument-type
if len(ph_size) == 0 and len(real_stride) > 0:
# Fix for 0-dimensional tensors: When a tensor becomes 0-d
# (e.g., via squeeze), its stride should be () not (1,).

View File

@ -628,6 +628,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
]
# add the scale nodes to the output find the first sym_node in the output
# pyrefly: ignore # bad-argument-type
idx = find_first_sym_node(output_updated_args)
scale_nodes = tensor_scale_nodes + sym_scale_nodes
if scale_nodes:

View File

@ -1,7 +1,8 @@
import collections
import logging
import operator
from collections import defaultdict
from typing import Any, Callable
from typing import Any, Callable, Literal, TypeAlias
import torch
import torch.distributed as dist
@ -17,16 +18,24 @@ from torch.utils._ordered_set import OrderedSet
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]
# Helper functions moved to top for better organization
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]: # type: ignore[name-defined]
_, group_size, group_name = node.args
dtype = node.meta["val"].dtype
assert isinstance(group_name, str)
return (group_name, dtype)
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
_, group_size, group_name = node.args
assert isinstance(group_name, str)
return (group_name,)
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: # type: ignore[name-defined]
_, reduce_op, group_size, group_name = node.args
dtype = node.meta["val"].dtype
assert isinstance(group_name, str)
@ -53,6 +62,11 @@ def bucket_key(node: torch.fx.Node) -> object | None:
return None
def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype: # type: ignore[name-defined]
assert len(dtypes) > 0
return min(dtypes, key=operator.attrgetter("itemsize"))
def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
"""
Determine the size of a bucket based on its ID.
@ -69,15 +83,15 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
def bucket_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
bucket_cap_mb_by_bucket_idx_default,
)
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx, None, mode)
if len(ag_buckets) == 0:
return
merge_all_gather(gm, ag_buckets, mode)
@ -86,15 +100,17 @@ def bucket_all_gather(
def bucket_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
bucket_cap_mb_by_bucket_idx_default,
)
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx)
rs_buckets = bucket_reduce_scatter_by_mb(
gm, bucket_cap_mb_by_bucket_idx, None, mode
)
if len(rs_buckets) == 0:
return
merge_reduce_scatter(gm, rs_buckets, mode)
@ -252,6 +268,7 @@ def bucket_all_gather_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
mode: BucketMode = "default",
) -> list[list[torch.fx.Node]]:
"""
Identifies all all_gather nodes and groups them into buckets,
@ -271,11 +288,15 @@ def bucket_all_gather_by_mb(
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
"""
group_key_fn = (
_ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key
)
return greedy_bucket_collective_by_mb(
gm,
bucket_cap_mb_by_bucket_idx,
is_all_gather_into_tensor,
_ag_group_key,
group_key_fn,
filter_wait_node,
)
@ -284,6 +305,7 @@ def bucket_reduce_scatter_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
mode: BucketMode = "default",
) -> list[list[torch.fx.Node]]:
"""
Identifies all reduce_scatter nodes and groups them into buckets,
@ -301,6 +323,10 @@ def bucket_reduce_scatter_by_mb(
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
"""
assert "multidtype" not in mode, (
"reduce scatter bucketing does not support multidtype"
)
return greedy_bucket_collective_by_mb(
gm,
bucket_cap_mb_by_bucket_idx,
@ -439,13 +465,17 @@ def _pre_bucket_all_gather(
dtype: torch.dtype, # type: ignore[name-defined]
rank: int,
) -> torch.Tensor:
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
bucket_dtype_size_bytes = dtype.itemsize
ins_split_sizes = [
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
]
ag_input_numel = sum(ins_split_sizes)
device = ag_ins[0].device
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins]
torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
return new_ag_out
@ -457,7 +487,11 @@ def _pre_bucket_all_gather_fake(
dtype: torch.dtype, # type: ignore[name-defined]
rank: int,
) -> torch.Tensor:
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
bucket_dtype_size_bytes = dtype.itemsize
ins_split_sizes = [
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
]
ag_input_numel = sum(ins_split_sizes)
device = ag_ins[0].device
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
@ -468,14 +502,28 @@ _pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)
def all_gather_merge_fn_to_trace_custom_ops(
ag_ins: list[torch.Tensor],
_ag_ins: list[torch.Tensor],
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
rank: int,
) -> list[torch.Tensor]:
ag_ins = [
torch._prims.convert_element_type(_ag_in, out_dtype)
if _ag_in.dtype != out_dtype
else _ag_in
for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
]
ins_sizes = [ag_in.shape for ag_in in ag_ins]
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ins_split_sizes_bytes = [
ag_in.numel() * out_dtype.itemsize
for ag_in, out_dtype in zip(ag_ins, out_dtypes)
]
bucket_dtype_size_bytes = dtype.itemsize
ins_split_sizes = [
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
]
ag_input_numel = sum(ins_split_sizes)
new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
ag_ins, group_size, group_name, dtype, rank
@ -487,14 +535,14 @@ def all_gather_merge_fn_to_trace_custom_ops(
)
)
new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
outs = torch.split_with_sizes(
outs_bucket_dtype = torch.split_with_sizes(
new_ag_out_reshaped,
ins_split_sizes,
dim=1,
)
outs_reshaped = [
o.reshape((shape[0] * group_size,) + shape[1:])
for o, shape in zip(outs, ins_sizes)
o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
]
return outs_reshaped
@ -504,6 +552,7 @@ def all_gather_merge_fn_to_trace(
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
rank: int,
) -> list[torch.Tensor]:
ins_sizes = [ag_in.shape for ag_in in ag_ins]
@ -538,6 +587,7 @@ def all_gather_merge_fn_to_trace_functional(
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
rank: int,
use_fsdp_ag_copy_in: bool = False,
) -> list[torch.Tensor]:
@ -733,7 +783,7 @@ def process_collective_bucket(
def merge_reduce_scatter_bucket(
g: torch.fx.Graph,
rs_nodes: list[torch.fx.Node],
mode: str | None = None,
mode: BucketMode = "default",
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
@ -826,29 +876,27 @@ def merge_all_reduce_bucket(
def merge_all_gather_bucket(
g: torch.fx.Graph,
ag_nodes: list[torch.fx.Node],
mode: str | None = None,
mode: BucketMode = "default",
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
from torch.distributed.distributed_c10d import _resolve_process_group
ag0 = ag_nodes[0]
ag0_val = ag0.meta["val"]
_, group_size, group_name = ag0.args
dtype = ag0_val.dtype
assert isinstance(group_name, str)
_ag_dtypes: list[torch.dtype] = [] # type: ignore[name-defined]
for n in ag_nodes:
assert (
n.args[1] == group_size
and n.args[2] == group_name
and n.meta["val"].dtype == dtype
)
assert n.args[1] == group_size and n.args[2] == group_name
_ag_dtypes.append(n.meta["val"].dtype)
bucket_dtype = pick_bucket_dtype(_ag_dtypes)
# Choose merge function based on mode
ag_merge_fn = all_gather_merge_fn_to_trace
if mode and "custom_ops" in mode:
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
if mode is not None and "custom_ops" in mode:
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops # type: ignore[assignment]
# Process bucket with lazy input collection
rank: int = dist.get_rank(_resolve_process_group(group_name))
@ -858,7 +906,8 @@ def merge_all_gather_bucket(
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
group_size,
group_name,
dtype,
bucket_dtype,
_ag_dtypes,
rank,
)
@ -874,7 +923,7 @@ def merge_all_gather_bucket(
def merge_reduce_scatter(
gm: torch.fx.GraphModule,
rs_buckets: list[list[torch.fx.Node]],
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
"""
Merges specified buckets of reduce_scatter to joint reduce_scatter.
@ -898,7 +947,7 @@ def merge_reduce_scatter(
def merge_all_gather(
gm: torch.fx.GraphModule,
ag_buckets: list[list[torch.fx.Node]],
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
"""
Merges specified buckets of all_gather to joint all_gather.

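The multidtype path above sizes each input in bytes, converts that to elements of the smallest-itemsize bucket dtype, and reinterprets the flattened inputs with `view(dtype)` so tensors of different dtypes can share one flat buffer. A self-contained sketch of that pack/unpack round trip, with no collectives involved:

```python
# Sketch of the byte-based packing used for multi-dtype all-gather bucketing:
# pack a bf16 and an fp32 tensor into one bf16 buffer, then recover them exactly.
import operator

import torch

ins = [torch.randn(4, 3, dtype=torch.bfloat16), torch.randn(2, 5, dtype=torch.float32)]
bucket_dtype = min((t.dtype for t in ins), key=operator.attrgetter("itemsize"))  # bf16

split_sizes = [t.numel() * t.element_size() // bucket_dtype.itemsize for t in ins]
bucket = torch.empty(sum(split_sizes), dtype=bucket_dtype)

# pack: reinterpret each flattened input as bucket_dtype and copy it into the buffer
torch._foreach_copy_(
    list(torch.split(bucket, split_sizes)),
    [t.reshape(-1).view(bucket_dtype) for t in ins],
)

# unpack: split by the same sizes and reinterpret back to the original dtypes
outs = [
    chunk.view(t.dtype).reshape(t.shape)
    for chunk, t in zip(torch.split(bucket, split_sizes), ins)
]
for a, b in zip(ins, outs):
    torch.testing.assert_close(a, b)
```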
View File

@ -209,6 +209,7 @@ def addmm_patterns_init():
# pyrefly: ignore # bad-argument-type
int8_woq_fusion_replacement,
[val(), val(), val(), val(), scale(), scale(), scale()],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
pass_patterns[0],
@ -230,6 +231,7 @@ def addmm_patterns_init():
# pyrefly: ignore # bad-argument-type
matmul_replacement,
[val(), val(), val(), val()],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
pass_patterns[0],
@ -251,6 +253,7 @@ def addmm_patterns_init():
# pyrefly: ignore # bad-argument-type
matmul_replacement_two,
[val(), val(), val()],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
pass_patterns[0],
@ -276,6 +279,7 @@ def addmm_patterns_init():
# pyrefly: ignore # bad-argument-type
addmm_fuse_replacement_second,
[val() for _ in range(7)],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
pass_patterns[0],

View File

@ -5,6 +5,7 @@ import torch
from torch._inductor.fx_passes.bucketing import (
bucket_all_gather_by_mb,
bucket_reduce_scatter_by_mb,
BucketMode,
merge_all_gather,
merge_reduce_scatter,
)
@ -56,7 +57,7 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
def bucket_fsdp_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
"""
Bucketing pass for SimpleFSDP all_gather ops.
@ -86,7 +87,7 @@ def bucket_fsdp_all_gather(
def bucket_fsdp_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
"""
Bucketing pass for SimpleFSDP reduce_scatter ops.

View File

@ -27,6 +27,10 @@ aten = torch.ops.aten
patterns = PatternMatcherPass()
def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
return dim == t.ndim - 1 or dim == -1
def _is_backward(graph: torch.fx.Graph) -> bool:
placeholders = []
for node in graph.nodes:
@ -645,9 +649,17 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
if not is_symm_mem_enabled_for_group(group_name):
return
if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
# Decomposing the matmul on the K dimension is not supported
return
filter_matmul = None
if _is_last_dim(_get_tensor(shard_node), gather_dim):
# Decomposed mms should not be too small
if _get_tensor(shard_node).shape[-1] < 1024:
return
# scaled_mm is not supported yet for last dim
def _filter_out_scaled_matmul(matmul: _Matmul):
return not isinstance(matmul, _ScaledMatmul)
filter_matmul = _filter_out_scaled_matmul
# Find consumer matmuls
matmuls = _find_consumer_matmuls(ag_res_node)
@ -663,18 +675,29 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
return
if _is_last_dim(_get_tensor(shard_node), gather_dim) and len(
all_gather.res_node.users
) > len(matmuls):
# The result of ag-split-cat is used by consumers other than the matmuls,
# so it has to be materialized, which can add overhead.
return
if filter_matmul and not filter_matmul(matmuls[0]):
return
# Fuse the all_gather_tensor with the eligible matmuls
graph = ag_node.graph
with graph.inserting_before(ag_node):
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
if not _is_last_dim(_get_tensor(shard_node), gather_dim):
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
fused_node = _insert_fused_all_gather_matmul(
graph, matmuls, shard_node, gather_dim, group_name
@ -881,7 +904,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
return
filter_matmul = None
if orig_scatter_dim == _get_tensor(input_node).ndim - 1:
if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
# scaled_mm is not supported yet for last dim mm+rs
def _filter_out_scaled_matmul(matmul: _Matmul):
return not isinstance(matmul, _ScaledMatmul)
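The last-gather-dim support relies on the fact that a matmul can be decomposed along its reduction (K) dimension into partial products that are summed, which is also why the pass rejects shards that would make those partial matmuls too thin. A short sketch of the identity:

```python
# Decomposing C = A @ B along K: split A by columns and B by rows, then
# accumulate the partial products. This is the identity the
# gather_dim == last-dim fusion relies on.
import torch

world_size, M, K, N = 4, 8, 64, 16
A = torch.randn(M, K)
B = torch.randn(K, N)

A_shards = A.chunk(world_size, dim=1)  # what each rank holds locally
B_shards = B.chunk(world_size, dim=0)

C = torch.zeros(M, N)
for A_k, B_k in zip(A_shards, B_shards):
    C.addmm_(A_k, B_k)  # partial product accumulated in place

torch.testing.assert_close(C, A @ B)
```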

View File

@ -49,6 +49,7 @@ def _misc_patterns_init():
# pyrefly: ignore # bad-argument-type
randperm_index_add_replacement,
[torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
[post_grad_patterns, joint_graph_patterns],
@ -68,6 +69,7 @@ def _misc_patterns_init():
# pyrefly: ignore # bad-argument-type
randperm_index_replacement,
[torch.empty(4, 8, device=device)],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
[post_grad_patterns, joint_graph_patterns],

View File

@ -919,6 +919,7 @@ def _pad_mm_init() -> None:
pattern,
replacement,
args,
# pyrefly: ignore # bad-argument-type
joint_fwd_bwd,
# pyrefly: ignore # bad-argument-type
patterns,
@ -931,6 +932,7 @@ def _pad_mm_init() -> None:
pattern,
replacement,
args,
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
patterns,

View File

@ -216,7 +216,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
lambda graph: p(
graph.owning_module,
config.bucket_reduce_scatters_fx_bucket_size_determinator,
config.bucket_reduce_scatters_fx,
config.bucket_reduce_scatters_fx, # type: ignore[arg-type]
)
)
collectives_bucketing = True
@ -236,7 +236,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
lambda graph: p(
graph.owning_module,
config.bucket_all_gathers_fx_bucket_size_determinator,
config.bucket_all_gathers_fx,
config.bucket_all_gathers_fx, # type: ignore[arg-type]
)
)
collectives_bucketing = True
@ -666,6 +666,7 @@ def lazy_init():
prepare_softmax_replacement,
[torch.empty(4, 8)],
scalar_workaround=dict(dim=-1),
# pyrefly: ignore # bad-argument-type
trace_fn=fwd_only,
# pyrefly: ignore # bad-argument-type
pass_dicts=pass_patterns[1],
@ -730,6 +731,7 @@ def register_lowering_pattern(
return pattern_matcher.register_lowering_pattern(
pattern,
extra_check,
# pyrefly: ignore # bad-argument-type
pass_dict=pass_patterns[pass_number],
)
@ -1573,6 +1575,7 @@ def register_partial_reduction_pattern():
@register_graph_pattern(
MultiOutputPattern([partial_reduc, full_reduc]),
# pyrefly: ignore # bad-argument-type
pass_dict=pass_patterns[2],
)
def reuse_partial(match, input, reduced_dims, keepdim):

View File

@ -27,7 +27,7 @@ from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import associative_scan_op
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._library.utils import get_layout_constraint_tag
from torch._prims_common import (
from torch._prims_common import ( # pyrefly: ignore # deprecated
canonicalize_dim,
canonicalize_dims,
check,

View File

@ -3994,6 +3994,12 @@ class Scheduler:
):
return -1
# in some rare case, a template can be passed in.
# Check test_interaction_with_multi_template in test_loop_ordering.py
# and https://github.com/pytorch/pytorch/issues/165579
if node1.is_template() or node2.is_template():
return -1
node1_buffer_names = node1.read_writes.buffer_names()
node2_buffer_names = node2.read_writes.buffer_names()
# Fast path: no common buffers.

View File

@ -173,6 +173,7 @@ def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
)
_OPAQUE_TYPES[cls] = name
# pyrefly: ignore # missing-attribute
torch._C._register_opaque_type(name)
@ -182,4 +183,5 @@ def is_opaque_type(cls: Any) -> bool:
"""
if cls not in _OPAQUE_TYPES:
return False
# pyrefly: ignore # missing-attribute
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])

View File

@ -135,7 +135,7 @@ if is_available():
# this.
# pyrefly: ignore # deprecated
from .distributed_c10d import * # noqa: F403
from .distributed_c10d import (
from .distributed_c10d import ( # pyrefly: ignore # deprecated
_all_gather_base,
_coalescing_manager,
_CoalescingManager,

View File

@ -1009,8 +1009,8 @@ lib_impl.impl("broadcast", _broadcast_meta, "Meta")
lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
# mark these ops has side effect so that they won't be removed by DCE
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) # type: ignore[has-type]
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) # type: ignore[has-type]
# Register legacy ops for backward compatibility
# TODO(yifu): remove these in functional collective beta release
@ -1176,7 +1176,7 @@ def all_gather_inplace(
return tensor_list
from torch.distributed.distributed_c10d import (
from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated
_all_gather_base as legacy_all_gather_base,
_reduce_scatter_base as legacy_reduce_scatter_base,
all_gather as legacy_all_gather,
@ -1190,11 +1190,11 @@ from torch.distributed.distributed_c10d import (
# This dict should contain sets of functions that dynamo is allowed to remap.
# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {
legacy_allgather: all_gather_tensor_inplace,
legacy_reducescatter: reduce_scatter_tensor_inplace,
legacy_allreduce: all_reduce_inplace,
legacy_all_to_all_single: all_to_all_inplace,
legacy_all_gather: all_gather_inplace,
legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
legacy_all_gather_base: all_gather_tensor_inplace,
legacy_allgather: all_gather_tensor_inplace, # type: ignore[has-type]
legacy_reducescatter: reduce_scatter_tensor_inplace, # type: ignore[has-type]
legacy_allreduce: all_reduce_inplace, # type: ignore[has-type]
legacy_all_to_all_single: all_to_all_inplace, # type: ignore[has-type]
legacy_all_gather: all_gather_inplace, # type: ignore[has-type]
legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, # type: ignore[has-type]
legacy_all_gather_base: all_gather_tensor_inplace, # type: ignore[has-type]
}

View File

@ -393,6 +393,7 @@ class LocalTensor(torch.Tensor):
def __repr__(self) -> str: # type: ignore[override]
parts = []
for k, v in self._local_tensors.items():
# pyrefly: ignore # bad-argument-type
parts.append(f" {k}: {v}")
tensors_str = ",\n".join(parts)
return f"LocalTensor(\n{tensors_str}\n)"
@ -680,6 +681,7 @@ class LocalTensorMode(TorchDispatchMode):
def _unpatch_device_mesh(self) -> None:
assert self._old_get_coordinate is not None
DeviceMesh.get_coordinate = self._old_get_coordinate
# pyrefly: ignore # bad-assignment
self._old_get_coordinate = None

View File

@ -316,6 +316,7 @@ def _local_all_gather_(
assert len(input_tensors) == 1
input_tensor = input_tensors[0]
# pyrefly: ignore # bad-assignment
output_tensors = output_tensors[0]
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
@ -336,10 +337,12 @@ def _local_all_gather_(
source_tensor = input_tensor
if isinstance(input_tensor, LocalTensor):
source_tensor = input_tensor._local_tensors[rank_i]
# pyrefly: ignore # missing-attribute
output_tensors[i].copy_(source_tensor)
work = FakeWork()
work_so = Work.boxed(work)
# pyrefly: ignore # bad-return
return ([output_tensors], work_so)
@ -426,6 +429,7 @@ def _local_scatter_(
assert len(output_tensors) == 1
assert len(input_tensors) == 1
output_tensor = output_tensors[0]
# pyrefly: ignore # bad-assignment
input_tensors = input_tensors[0]
ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)

View File

@ -9,6 +9,7 @@ from itertools import product
import torch
from torch.distributed._pycute import (
as_tuple,
coalesce,
complement,
composition,
@ -17,7 +18,6 @@ from torch.distributed._pycute import (
is_int,
is_tuple,
Layout,
suffix_product,
)
@ -79,6 +79,11 @@ class _MeshLayout(Layout):
# # operator [] (get-i like tuples)
def __getitem__(self, i: int) -> "_MeshLayout":
if i < -len(self) or i >= len(self):
raise IndexError(
f"Dim {i} is out of range for layout with {len(self)} dimensions. "
f"Expected dim to be in range [{-len(self)}, {len(self) - 1}]."
)
layout = super().__getitem__(i)
return _MeshLayout(layout.shape, layout.stride)
@ -156,50 +161,11 @@ class _MeshLayout(Layout):
layout = complement(self, world_size)
return _MeshLayout(layout.shape, layout.stride)
def unflatten(self, dim: int, unflatten_sizes: tuple[int, ...]) -> "_MeshLayout":
"""
Unflatten a single dimension in the layout by splitting it into multiple dimensions.
It takes a dimension at position `dim` and splits it into multiple new dimensions
with the specified sizes.
Args:
dim (int): The index of the dimension to unflatten. Must be a valid dimension index.
unflatten_sizes (tuple[int, ...]): The new sizes for the dimensions that will replace
the original dimension at `dim`. The product of these sizes must equal the size
of the original dimension at `dim`.
Returns:
_MeshLayout: A new layout with the specified dimension unflattened.
Example:
Original: sizes=(8,), strides=(1,) # 8 ranks in 1D
Call: unflatten(0, (2, 2, 2)) # Create 3D topology
Result: sizes=(2, 2, 2), strides=(4, 2, 1) # 2*2*2 unflattened topology
"""
# Check that dim is within valid range
if dim < 0 or dim >= len(self):
raise ValueError(
f"dim {dim} is out of range for layout with {len(self)} dimensions. "
f"Expected dim to be in range [0, {len(self) - 1}]."
)
# Check that the product of unflatten_sizes equals the original dimension size
original_size = self[dim].numel()
unflatten_product = math.prod(unflatten_sizes)
if unflatten_product != original_size:
raise ValueError(
f"The product of unflatten_sizes {unflatten_sizes} is {unflatten_product}, "
f"but the original dimension at dim={dim} has size {original_size}. "
f"These must be equal for unflatten to work correctly."
)
sizes = list(self.sizes) # type: ignore[arg-type]
strides = list(self.strides) # type: ignore[arg-type]
unflatten_layout = self[dim].composition(
_MeshLayout(tuple(unflatten_sizes), suffix_product(unflatten_sizes))
)
sizes[dim : dim + 1] = list(unflatten_layout.sizes) # type: ignore[arg-type]
strides[dim : dim + 1] = list(unflatten_layout.strides) # type: ignore[arg-type]
def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout":
sizes = list(as_tuple(self.sizes))
strides = list(as_tuple(self.strides))
sizes[start:end] = list(as_tuple(layout.sizes))
strides[start:end] = list(as_tuple(layout.strides))
return _MeshLayout(tuple(sizes), tuple(strides))
def all_ranks_from_zero(self) -> list[int]:
@ -301,10 +267,7 @@ class _MeshLayout(Layout):
ranks = self.all_ranks_from_zero()
return len(ranks) == len(set(ranks))
def remap_to_tensor(
self,
mesh_tensor: torch.Tensor,
) -> torch.Tensor:
def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor:
"""
Leverage layout as an index for mesh tensor that re-maps the indexes after layout
transformation to actual device ranks.
@ -316,10 +279,7 @@ class _MeshLayout(Layout):
can be treated as a view or subset of mesh tensor, we do need to use the actual view or
sub-tensor for DeviceMesh and its backend creation.
The shape of the `mesh_tensor` can be any size because users can define a device mesh with any
shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
and reconstruct the mesh tensor with the shape of the layout when accessed by users.
#TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.
`rank_map` must be a 1-D, contiguous tensor.
Examples:
@ -336,18 +296,18 @@ class _MeshLayout(Layout):
Return: [[[10,30],[20,40]]]
Args:
mesh_tensor: The concrete mesh tensor with actual device ranks
rank_map: The concrete mesh tensor with actual device ranks
Returns:
torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
torch.Tensor: A tensor representing the actual device allocation from rank_map
"""
complement_layout = self.complement(mesh_tensor.numel())
assert rank_map.ndim == 1
assert rank_map.is_contiguous()
assert rank_map.numel() >= self.cosize()
return (
mesh_tensor.flatten()
.as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides),
)
.reshape(-1, *(self[i].numel() for i in range(len(self))))
)
complement_layout = self.complement(rank_map.numel())
return rank_map.as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides),
).reshape(-1, *self.top_level_sizes)
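`remap_to_tensor` now takes a flat, contiguous 1-D rank map and materializes the layout (plus its complement) as a strided view of it. A plain-tensor sketch of the strided remapping behind it, matching the column-major example in the docstring; the real method additionally prepends the complement dimensions and reshapes:

```python
# Sketch of remapping a flat rank map through a layout's (sizes, strides),
# mirroring the docstring example: sizes=(2, 2), strides=(1, 2) over [10, 20, 30, 40].
import torch

rank_map = torch.tensor([10, 20, 30, 40])
sizes, strides = (2, 2), (1, 2)  # column-major 2x2 layout over 4 ranks

remapped = rank_map.as_strided(sizes, strides)
print(remapped)
# tensor([[10, 30],
#         [20, 40]])
```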

View File

@ -31,6 +31,7 @@
#################################################################################################
from .int_tuple import (
as_tuple,
crd2crd,
crd2idx,
elem_scale,

View File

@ -54,6 +54,12 @@ def is_tuple(x: object) -> TypeIs[tuple]:
return isinstance(x, tuple)
def as_tuple(x: IntTuple) -> tuple[IntTuple, ...]:
if is_int(x):
return (x,)
return x
def flatten(t: IntTuple) -> tuple[int, ...]:
if is_tuple(t):
if len(t) == 0:

View File

@ -524,6 +524,19 @@ def _fused_all_gather_matmul_impl(
group = c10d._resolve_process_group(group_name)
if gather_dim == A_shard.ndim - 1 or gather_dim == -1:
return _fused_all_gather_matmul_last_gather_dim_impl(
mm_out_op,
A_shard,
Bs,
A_scale,
kwargs_list,
out_dtypes,
gather_dim,
group_name,
return_A,
)
# Move the gather_dim to the front and flatten the tensor into a 2D matrix.
# The flattened tensor doesn't need to be contiguous (for computation
# efficiency), as _pipelined_all_gather_and_consume guarantees that shards
@ -624,6 +637,140 @@ def _fused_all_gather_matmul_impl(
return A, [unflatten(output) for output in outputs]
def _pipelined_all_gather_and_consume_last_dim(
shard: torch.Tensor,
shard_consumer: Callable[[torch.Tensor, int], None],
ag_out: torch.Tensor,
group_name: str,
ag_out_needed: bool = True,
) -> None:
p2p_workspace_size_req = shard.numel() * shard.element_size()
symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req)
group_size = symm_mem.world_size
rank = symm_mem.rank
symm_mem.barrier(channel=0)
backend_stream = _get_backend_stream()
backend_stream.wait_stream(torch.cuda.current_stream())
def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None:
dst.copy_(src)
def get_p2p_buf(remote_rank: int) -> torch.Tensor:
buf = symm_mem.get_buffer(
remote_rank,
shard.shape,
shard.dtype,
)
return buf
local_p2p_buf = get_p2p_buf(rank)
shards = ag_out.chunk(group_size)
copy_shard(dst=local_p2p_buf, src=shard)
symm_mem.barrier(channel=1)
backend_stream.wait_stream(torch.cuda.current_stream())
# At this point, all ranks have copied their local shard to
# their local p2p buffer. Each rank can now copy and consume
# remote shards.
shard_consumer(shard, rank)
for step in range(1, group_size):
if step % 2 == 0:
stream = torch.cuda.current_stream()
else:
stream = backend_stream
remote_rank = (step + rank) % group_size
remote_p2p_buf = get_p2p_buf(remote_rank)
with stream:
copy_shard(dst=shards[remote_rank], src=remote_p2p_buf)
shard_consumer(shards[remote_rank], remote_rank)
if ag_out_needed:
# Copy from input to the all-gather output. Opportunistically overlap
# it with the last shard_consumer.
if group_size % 2 == 0:
stream = torch.cuda.current_stream()
else:
stream = backend_stream
with stream:
copy_shard(dst=shards[rank], src=shard)
torch.cuda.current_stream().wait_stream(backend_stream)
symm_mem.barrier(channel=0)
def _fused_all_gather_matmul_last_gather_dim_impl(
mm_out_op: torch._ops.OpOverload,
A_shard: torch.Tensor,
Bs: list[torch.Tensor],
A_scale: torch.Tensor | None,
kwargs_list: list[dict[str, Any]],
out_dtypes: list[torch.dtype | None],
gather_dim: int,
group_name: str,
return_A: bool,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
group = c10d._resolve_process_group(group_name)
group_size = group.size()
B_shards = [B.chunk(group.size()) for B in Bs]
leading_dims = list(A_shard.shape[:-1])
A_shard_flat = A_shard.flatten(0, -2)
def unflatten(t: torch.Tensor) -> torch.Tensor:
return t.view(*leading_dims, -1)
A_flat_out = A_shard_flat.new_empty(
A_shard_flat.shape[0] * group.size(),
A_shard_flat.shape[1],
)
outputs = [
torch.empty(
(A_shard_flat.shape[0], B.shape[1]),
dtype=out_dtype or B.dtype,
device=A_shard.device,
)
for B, out_dtype in zip(Bs, out_dtypes)
]
first = True
events = [torch.cuda.Event() for _ in outputs]
def default_consumer(shard: torch.Tensor, rank: int) -> None:
nonlocal first
for out, event, B_shard, kwargs in zip(outputs, events, B_shards, kwargs_list):
event.wait()
if first:
torch.ops.aten.mm.out(shard, B_shard[rank], **kwargs, out=out)
else:
out.addmm_(shard, B_shard[rank])
event.record()
first = False
_pipelined_all_gather_and_consume_last_dim(
A_shard_flat,
default_consumer,
A_flat_out,
group_name,
return_A,
)
ret_A = None
if return_A:
# This path is inefficient and will be filtered out at passes stage
# Added only for completeness.
A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
ret_A = unflatten(A_split_cat_out_flat)
return ret_A, [unflatten(output) for output in outputs]
@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
def _fused_all_gather_matmul_fallback(
A_shard: torch.Tensor,
@ -638,6 +785,15 @@ def _fused_all_gather_matmul_fallback(
A_shard.contiguous(), group_size, group_name
)
A = torch.ops._c10d_functional.wait_tensor(A)
if gather_dim == A.ndim - 1 or gather_dim == -1:
A_splits = A.chunk(group_size)
A_mm = torch.cat(A_splits, dim=-1)
res = [torch.matmul(A_mm, B) for B in Bs]
if return_A:
return A_mm, res
else:
return None, res
A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs]
if return_A:

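The fallback added above handles the last-dim gather by re-chunking the all-gather output (which is concatenated along dim 0) and re-concatenating it along the last dim before the matmul. A small sketch of that reshuffle and why it recovers the full-K operand:

```python
# Sketch of the fallback for gather_dim == last dim: an all_gather_into_tensor
# stacks shards along dim 0, so the gathered buffer has to be re-chunked and
# re-concatenated along the last dim before the matmul.
import torch

world_size, M, K, N = 4, 8, 64, 16
A_full = torch.randn(M, K)
B = torch.randn(K, N)

A_shards = list(A_full.chunk(world_size, dim=-1))  # per-rank shards, (M, K/ws)
ag_out = torch.cat(A_shards, dim=0)                # what all_gather produces, (ws*M, K/ws)

A_mm = torch.cat(ag_out.chunk(world_size, dim=0), dim=-1)  # back to (M, K)
torch.testing.assert_close(A_mm @ B, A_full @ B)
```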
View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import math
import os
import threading
import warnings
@ -12,7 +11,7 @@ from typing import Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed._pycute import is_int
from torch.distributed._pycute import is_int, suffix_product
from torch.utils._typing_utils import not_none
@ -173,7 +172,7 @@ else:
"""
_device_type: str
_mesh: torch.Tensor
_rank_map: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout
_root_mesh: Optional["DeviceMesh"] = None
@ -183,60 +182,75 @@ else:
def __init__(
self,
device_type: str,
mesh: Union[torch.Tensor, "ArrayLike"],
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
*,
mesh_dim_names: Optional[tuple[str, ...]] = None,
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_rank_map: Optional[torch.Tensor] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
self._device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
self._mesh = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
if backend_override is None:
backend_override = ((None, None),) * self.mesh.ndim
elif len(backend_override) != self.mesh.ndim:
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {self.mesh.ndim}."
if mesh is not None:
if _layout is not None or _rank_map is not None:
raise TypeError(
"Cannot provide _layout and/or _rank_map if passing explicit mesh"
)
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
mesh_tensor = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
# Internal bookkeeping for the device mesh.
self._layout = (
_layout
if _layout
else _MeshLayout(self.mesh.size(), self.mesh.stride())
)
assert self._layout.check_non_overlap(), (
_layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
_rank_map = mesh_tensor.flatten()
else:
if _layout is None or _rank_map is None:
raise TypeError(
"The mesh argument is required except for PRIVATE USAGE ONLY!"
)
assert _layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
assert self._layout.top_level_sizes == self.mesh.size(), (
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
assert _rank_map.ndim == 1, "The rank map must be 1-dimensional"
assert _rank_map.is_contiguous(), "The rank map must be contiguous"
assert _rank_map.numel() >= _layout.cosize(), (
f"The rank map contains {_rank_map.numel()} element, "
f"which isn't large enough for layout {_layout}"
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._thread_id = None
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
self._device_type = device_type
self._layout = _layout
self._rank_map = _rank_map
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
self._root_mesh = _root_mesh
if backend_override is None:
backend_override = ((None, None),) * len(self._layout)
elif len(backend_override) != len(self._layout):
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {len(self._layout)}."
)
# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
# process (we need to know if the current global rank is in the mesh or not).
if _init_backend:
self._setup_world_group_and_device()
self._init_process_groups(backend_override)
self._dim_group_names = self._init_process_groups(
self._layout,
self._rank_map,
self._mesh_dim_names,
backend_override,
)
if is_initialized() and get_backend() == "threaded":
# pyrefly: ignore # bad-assignment
@ -252,6 +266,11 @@ else:
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property
def device_type(self) -> str:
"""Returns the device type of the mesh."""
@ -260,7 +279,17 @@ else:
@property
def mesh(self) -> torch.Tensor:
"""Returns the tensor representing the layout of devices."""
return self._mesh
full_mesh = self._layout.remap_to_tensor(self._rank_map)
if full_mesh.size(0) == 1:
return full_mesh[0]
my_coords = (full_mesh == get_rank()).nonzero()
if my_coords.size(0) > 0:
return full_mesh[my_coords[0, 0]]
raise RuntimeError(
"In order to get the mesh Tensor of a DeviceMesh it needs to "
"either have all its original dimensions (e.g., no slicing) "
"or it needs to contain the local rank"
)
@property
def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
@ -275,9 +304,9 @@ else:
init_process_group()
world_size = get_world_size()
if self.mesh.numel() > world_size:
if self._layout.numel() > world_size:
raise RuntimeError(
f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
)
# ONLY set the device if the current device is not initialized, if user already
@ -317,10 +346,13 @@ else:
return _get_default_group()
@staticmethod
def _init_process_groups(
self,
layout: _MeshLayout,
rank_map: torch.Tensor,
mesh_dim_names: Optional[tuple[str, ...]],
backend_override: tuple[BackendConfig, ...],
):
) -> list[str]:
# group_name associated with each mesh dimension, each
# mesh dimension should have one sub-group per rank
#
@ -328,8 +360,8 @@ else:
default_group = _get_default_group()
if (
self.mesh.ndim == 1
and self.mesh.numel() == get_world_size()
len(layout) == 1
and layout.numel() == get_world_size()
and backend_override[0] == (None, None)
):
# Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@ -348,12 +380,10 @@ else:
dim_group_names.append(dim_group.group_name)
else:
# create sub pgs base on the mesh argument specified
for dim in range(self.mesh.ndim):
for dim in range(len(layout)):
# swap the current dim to the last dim
# then reshape to flatten out other dims
pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
-1, self.mesh.size(dim)
)
pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
backend, pg_options = backend_override[dim]
# We need to explicitly pass in timeout when specified in option, otherwise
# the default timeout will be used to override the timeout set in option.
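The per-dimension group construction used to be spelled directly on the mesh tensor as `swapdims(-1, dim).reshape(-1, size(dim))`; the layout-based `layout[dim].nest().remap_to_tensor(rank_map)` is meant to produce the same rank grouping from the flat rank map. A plain-tensor sketch of what that grouping computes for a 2x4 mesh, using the old formulation:

```python
# Sketch: rank groups along each dim of a 2x4 device mesh, computed with the
# old swapdims + reshape formulation; each row of the result is one process group.
import torch

mesh = torch.arange(8).reshape(2, 4)  # ranks laid out as a 2x4 mesh

for dim in range(mesh.ndim):
    pg_ranks_by_dim = mesh.swapdims(-1, dim).reshape(-1, mesh.size(dim))
    print(f"dim {dim} groups:\n{pg_ranks_by_dim}")
# dim 0 groups: rows [0, 4], [1, 5], [2, 6], [3, 7]
# dim 1 groups: rows [0, 1, 2, 3], [4, 5, 6, 7]
```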
@ -365,8 +395,8 @@ else:
# If the mesh doesn't not have a mesh_dim_names, then the group description of the
# subgroup would be `mesh_dim_0` and `mesh_dim_1`.
group_desc = (
f"mesh_{self._mesh_dim_names[dim]}"
if self._mesh_dim_names
f"mesh_{mesh_dim_names[dim]}"
if mesh_dim_names
else f"mesh_dim_{dim}"
)
@ -424,14 +454,14 @@ else:
)
# only add to dim_groups if the current rank in the subgroup
if self.get_rank() in subgroup_ranks:
if get_rank() in subgroup_ranks:
if len(dim_group_names) > dim:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
f"Each device mesh dimension should get only one process group, but got {get_rank()} "
f"in {subgroup_ranks}!"
)
dim_group_names.append(dim_group.group_name) # type: ignore[union-attr]
self._dim_group_names = dim_group_names
return dim_group_names
def _get_root_mesh(self) -> "DeviceMesh":
return self._root_mesh if self._root_mesh else self
@ -448,14 +478,14 @@ else:
def __repr__(self) -> str:
device_mesh_repr = (
f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._mesh.shape))})"
f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
if self._mesh_dim_names
else f"{tuple(self._mesh.shape)}"
else f"{self._layout.top_level_sizes}"
)
device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}"
device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
# We only print the mesh tensor if the debug mode is turned on.
if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
device_mesh_repr += f", Mesh: {self._mesh.tolist()}"
device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
return f"{device_mesh_repr})"
def __hash__(self):
@@ -465,7 +495,7 @@ else:
self._hash = hash(
(
self._flatten_mesh_list,
self._mesh.shape,
self._layout,
self._device_type,
self._mesh_dim_names,
self._thread_id,
@@ -481,7 +511,7 @@ else:
return False
return (
self._flatten_mesh_list == other._flatten_mesh_list
and self._mesh.shape == other._mesh.shape
and self._layout == other._layout
and self._device_type == other._device_type
and self._mesh_dim_names == other._mesh_dim_names
and self._thread_id == other._thread_id
@@ -573,16 +603,16 @@ else:
if not hasattr(self, "_dim_group_names"):
raise RuntimeError("DeviceMesh process groups not initialized!")
if self.mesh.ndim > 1 and mesh_dim is None:
if len(self._layout) > 1 and mesh_dim is None:
raise RuntimeError(
f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
f"Found the DeviceMesh have {len(self._layout)} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
"If you want to get the list of all the ProcessGroups in the DeviceMesh,"
"please use `get_all_groups()` instead.",
)
# Quick return if the current device_mesh is a 1D mesh.
if self.mesh.ndim == 1 and mesh_dim is None:
if len(self._layout) == 1 and mesh_dim is None:
return not_none(_resolve_process_group(self._dim_group_names[0]))
root_mesh = self._get_root_mesh()
@@ -608,7 +638,7 @@ else:
Returns:
A list of :class:`ProcessGroup` objects.
"""
return [self.get_group(i) for i in range(self.mesh.ndim)]
return [self.get_group(i) for i in range(len(self._layout))]
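# A hedged usage sketch for the two accessors above, assuming the caller belongs to an
# already-initialized 2x4 mesh (the dim names and device type are made up):
from torch.distributed.device_mesh import init_device_mesh
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_group = mesh.get_group("tp")      # the caller's ProcessGroup along the "tp" dimension
all_groups = mesh.get_all_groups()   # one ProcessGroup per mesh dimension, here two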
def _create_sub_mesh(
self,
@@ -634,18 +664,13 @@ else:
not_none(flatten_mesh._mesh_dim_names).index(name)
]
)
cur_rank = self.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor(
root_mesh.mesh,
)
res_submesh = DeviceMesh._create_mesh_from_ranks(
res_submesh = DeviceMesh(
self._device_type,
pg_ranks_by_dim,
cur_rank,
submesh_dim_names,
_init_backend=False,
_layout=layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=submesh_dim_names,
_root_mesh=root_mesh,
_init_backend=False,
)
res_submesh._dim_group_names = slice_dim_group_name
return res_submesh
@@ -689,22 +714,13 @@ else:
f"Please specify another valid mesh_dim_name."
)
cur_rank = root_mesh.get_rank()
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
# new_group api to avoid potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
root_mesh.mesh,
)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
res_flattened_mesh = DeviceMesh(
root_mesh._device_type,
pg_ranks_by_dim.flatten(
start_dim=1
), # this is needed for flatten non-contiguous mesh dims.
cur_rank,
(mesh_dim_name,),
(backend_override,),
_layout=flattened_mesh_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=(mesh_dim_name,),
_root_mesh=root_mesh,
backend_override=(backend_override,),
)
root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh
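# A hedged sketch of how this flatten path is reached from user code (dim names are made
# up; _flatten is the private entry point whose result is cached in _flatten_mapping above):
from torch.distributed.device_mesh import init_device_mesh
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_tp = mesh._flatten("dp_tp")       # a 1-D mesh over all 8 ranks, sharing the root rank map
assert dp_tp.ndim == 1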
@@ -833,9 +849,7 @@ else:
"""
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
layout = self._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor(
self.mesh,
)
pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
cur_rank = self.get_rank()
res_submeshes = []
for mesh_1d in pg_ranks_by_dim:
@@ -854,60 +868,6 @@ else:
return res_submeshes
@staticmethod
def _create_mesh_from_ranks(
device_type: str,
pg_ranks_by_dim: torch.Tensor,
cur_rank: int,
mesh_dim_names: tuple[str, ...],
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_layout: Optional[_MeshLayout] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> "DeviceMesh":
"""
Helper method to create a DeviceMesh from tensor `pg_ranks_by_dim`. This is due to
the constraint of ProcessGroup API that all ranks have to call the PG creation API
even if the rank is not in that PG.
We will create a potentially very large number of DeviceMesh objects
(e.g., on 1024 GPUs with TP=2, this could be up to 512 DeviceMeshes), only to throw
them all away except when the mesh contains the current rank.
#TODO: Further refactor this method once we relax the ProcessGroup API constraint.
Args:
device_type: The device type of the mesh.
pg_ranks_by_dim: all ranks within the worlds organized by dimensions.
cur_rank: The current global rank in the mesh.
mesh_dim_names: Mesh dimension names.
backend_override: Optional backend override for the mesh.
_init_backend: Whether to initialize the backend of the mesh.
_layout: Optional layout for the mesh.
Returns:
The DeviceMesh containing the current rank.
"""
res_mesh = None
for mesh_nd in pg_ranks_by_dim:
mesh = DeviceMesh(
device_type,
mesh_nd,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override,
_init_backend=_init_backend,
_layout=_layout,
)
if cur_rank in mesh_nd:
res_mesh = mesh
if res_mesh is None:
raise RuntimeError(
f"Current rank {cur_rank} not found in any mesh, "
f"input {pg_ranks_by_dim} does not contain all ranks in the world"
)
if _root_mesh is not None:
res_mesh._root_mesh = _root_mesh
return res_mesh
@staticmethod
def from_group(
group: Union[ProcessGroup, list[ProcessGroup]],
@@ -1004,15 +964,17 @@ else:
return device_mesh
def size(self, mesh_dim: Optional[int] = None) -> int:
return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
if mesh_dim is not None:
return self._layout[mesh_dim].numel()
return self._layout.numel()
@property
def ndim(self) -> int:
return self.mesh.ndim
return len(self._layout)
@property
def shape(self) -> tuple[int, ...]:
return tuple(self.mesh.shape)
return self._layout.top_level_sizes
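# A hedged example of what these layout-backed accessors report for a hypothetical
# 2x4 mesh (dim names and device type are made up):
from torch.distributed.device_mesh import init_device_mesh
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
assert mesh.size() == 8       # product over all mesh dimensions
assert mesh.size(1) == 4      # size of a single dimension
assert mesh.ndim == 2
assert mesh.shape == (2, 4)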
def get_rank(self) -> int:
"""
@@ -1051,7 +1013,7 @@ else:
"""
if self.ndim > 1 and mesh_dim is None:
raise RuntimeError(
f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
f"Found the DeviceMesh have {len(self._layout)} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
)
elif mesh_dim is None:
@@ -1112,22 +1074,28 @@ else:
tuple[Optional[str], Optional[C10dBackend.Options]], ...
] = ((None, None),),
) -> "DeviceMesh":
root_mesh = self._get_root_mesh()
cur_rank = self.get_rank()
unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor(
root_mesh.mesh,
)
inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes))
if inner_layout.numel() != self._layout[dim].numel():
raise ValueError(
f"The product of {mesh_sizes=} is {inner_layout.numel()}, "
f"but the original dimension at dim={dim} has size {self._layout[dim].numel()}. "
f"These must be equal for unflatten to work correctly."
)
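# Worked example of the check above (hedged; the numbers are made up): for a dimension
# of size 8, mesh_sizes=(2, 4) passes since 2 * 4 == 8, while mesh_sizes=(3, 3) raises
# this ValueError since 3 * 3 == 9 != 8.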
partial_layout = self._layout[dim].composition(inner_layout)
unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout)
unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
res_mesh = DeviceMesh._create_mesh_from_ranks(
root_mesh = self._get_root_mesh()
res_mesh = DeviceMesh(
self.device_type,
pg_ranks_by_dim,
cur_rank,
tuple(unflattened_mesh_dim_names),
_init_backend=False,
_layout=unflattened_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=tuple(unflattened_mesh_dim_names),
_root_mesh=root_mesh,
_init_backend=False,
)
# If the original mesh has initialized its backend, we need to initialize the backend
@@ -1135,33 +1103,13 @@ else:
# TODO: To make backend init more efficient with cute layout representation and support
# per dim backend init.
if hasattr(self, "_dim_group_names"):
unflatten_length = len(mesh_sizes)
unflatten_layout = _MeshLayout(
tuple(unflattened_layout.sizes[dim : dim + unflatten_length]), # type: ignore[index]
tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index]
)
unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
root_mesh.mesh,
)
unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
self.device_type,
unflatten_pg_ranks_by_dim,
cur_rank,
dim_group_names = self._dim_group_names.copy()
dim_group_names[dim : dim + 1] = self._init_process_groups(
partial_layout,
root_mesh._rank_map,
mesh_dim_names,
backend_override=backend_override,
backend_override,
)
dim_group_names = []
for idx in range(0, res_mesh.ndim):
if idx < dim:
dim_group_names.append(self._dim_group_names[idx])
elif idx >= dim + unflatten_length:
dim_group_names.append(
self._dim_group_names[idx - unflatten_length + 1]
)
else:
dim_group_names.append(
unflatten_submesh._dim_group_names[idx - dim]
)
res_mesh._dim_group_names = dim_group_names
return res_mesh
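# A hedged usage sketch of the unflatten path above (sizes and names are made up, and the
# positional signature (dim, mesh_sizes, mesh_dim_names) is assumed from the code above):
from torch.distributed.device_mesh import init_device_mesh
world = init_device_mesh("cuda", (8,), mesh_dim_names=("world",))
mesh_2d = world._unflatten(0, (2, 4), ("dp", "tp"))   # split dim 0 into a 2 x 4 layout
assert mesh_2d.shape == (2, 4)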
@@ -1349,13 +1297,15 @@ else:
"If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
)
# Always initialize the mesh's tensor on CPU, regardless of what the
layout = _MeshLayout(tuple(mesh_shape), suffix_product(mesh_shape))
# Always initialize the (identity) rank map on CPU, regardless of what the
# external device type has been set to be (e.g. meta)
with torch.device("cpu"):
mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
rank_map = torch.arange(layout.numel(), dtype=torch.int)
device_mesh = DeviceMesh(
device_type=device_type,
mesh=mesh,
_layout=layout,
_rank_map=rank_map,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override_tuple,
)

View File

@@ -90,6 +90,7 @@ class DTensorSpec:
if not isinstance(self.placements, tuple):
self.placements = tuple(self.placements)
if self.shard_order is None:
# pyrefly: ignore # bad-assignment
self.shard_order = DTensorSpec.compute_default_shard_order(self.placements)
self._hash: int | None = None

View File

@@ -701,6 +701,7 @@ def _restore_state_dict(
for name, _ in list(
chain(
original_module.named_parameters(remove_duplicate=False),
# pyrefly: ignore # bad-argument-type
original_module.named_buffers(remove_duplicate=False),
)
):

View File

@@ -18,6 +18,7 @@ log = logging.getLogger(__name__)
__all__ = [
"annotate",
"annotate_fn",
"preserve_node_meta",
"has_preserved_node_meta",
"set_stack_trace",
@@ -266,9 +267,10 @@ def annotate(annotation_dict: dict):
into the FX trace metadata.
Example:
After exiting the context, custom annotations are removed.
>>> with annotate({"source": "custom_pass", "tag": 42}):
... # compute here
# After exiting the context, custom annotations are removed.
... pass # Your computation here
"""
global current_meta
@@ -291,6 +293,43 @@ def annotate(annotation_dict: dict):
del current_meta["custom"]
@compatibility(is_backward_compatible=False)
def annotate_fn(annotation_dict: dict):
"""
A decorator that wraps a function with the annotate context manager.
Use this when you want to annotate an entire function instead of a specific code block.
Note:
This API is **not backward compatible** and may evolve in future releases.
Note:
This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
to be used with PT2 family of tracers, e.g. torch.export and dynamo.
Args:
annotation_dict (dict): A dictionary of custom key-value pairs to inject
into the FX trace metadata for all operations in the function.
Example:
All operations in my_function will have {"pp_stage": 1} in their metadata.
>>> @annotate_fn({"pp_stage": 1})
... def my_function(x):
... return x + 1
"""
from functools import wraps
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
with annotate(annotation_dict):
return func(*args, **kwargs)
return wrapper
return decorator
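# A hedged sketch of what the decorator expands to (the function name is made up); it
# simply runs the wrapped call inside the annotate() context manager defined above:
@annotate_fn({"pp_stage": 1})
def stage_forward(x):
    return x + 1

# ...which behaves like:
def stage_forward_explicit(x):
    with annotate({"pp_stage": 1}):
        return x + 1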
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
global current_meta

View File

@@ -218,6 +218,7 @@ class FlexKernelOptions(TypedDict, total=False):
waves_per_eu: NotRequired[int]
"""ROCm-specific waves per execution unit."""
# pyrefly: ignore # invalid-annotation
force_flash: NotRequired[bool]
""" If True, forces use of the cute-dsl flash attention kernel.

View File

@@ -1,5 +1,5 @@
from . import parametrizations, parametrize, rnn, stateless
from .clip_grad import (
from .clip_grad import ( # pyrefly: ignore # deprecated
_clip_grads_with_norm_ as clip_grads_with_norm_,
_get_total_norm as get_total_norm,
clip_grad_norm,

View File

@@ -283,6 +283,7 @@ def clip_grad_value_(
clip_value = float(clip_value)
grads = [p.grad for p in parameters if p.grad is not None]
# pyrefly: ignore # bad-argument-type
grouped_grads = _group_tensors_by_device_and_dtype([grads])
for (device, _), ([grads], _) in grouped_grads.items():

View File

@@ -111,8 +111,10 @@ class _Orthogonal(Module):
Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
if hasattr(self, "base"):
# pyrefly: ignore # unbound-name
Q = self.base @ Q
if transposed:
# pyrefly: ignore # unbound-name
Q = Q.mT
return Q # type: ignore[possibly-undefined]

View File

@@ -170,6 +170,7 @@ class TorchTensor(ir.Tensor):
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
raise TypeError(
# pyrefly: ignore # missing-attribute
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
"with a tensor backed by real data using ONNXProgram.apply_weights() "
"or save the model without initializers by setting include_initializers=False."

View File

@@ -297,6 +297,7 @@ class AveragedModel(Module):
avg_fn = get_swa_avg_fn()
n_averaged = self.n_averaged.to(device)
for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment]
# pyrefly: ignore # missing-attribute
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
else:
for p_averaged, p_model in zip( # type: ignore[assignment]

View File

@@ -71,6 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
nrows // 16, 16
)
).view(-1)
# pyrefly: ignore # unbound-name
outp = outp.index_copy(1, cols_permuted, outp)
# interleave_column_major_tensor

View File

@@ -67,6 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
# Because we cannot go from the compressed representation back to the dense representation currently,
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
# pyrefly: ignore # no-matching-overload
return self.__class__(
torch.Size([self.shape[-1], self.shape[0]]),
packed=self.packed_t,

View File

@@ -184,6 +184,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
outer_stride,
) -> torch.Tensor:
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
# pyrefly: ignore # no-matching-overload
return cls(
shape=shape,
packed=inner_tensors.get("packed", None),
@@ -413,6 +414,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
sparse_tensor_cutlass,
meta_tensor_cutlass,
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
# pyrefly: ignore # no-matching-overload
return cls(
original_tensor.shape,
packed=sparse_tensor_cutlass,
@@ -499,6 +501,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
original_tensor, algorithm=algorithm, use_cutlass=True
)
# pyrefly: ignore # no-matching-overload
return cls(
original_tensor.shape,
packed=packed,
@@ -560,6 +563,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
# pyrefly: ignore # no-matching-overload
return cls(
shape=original_tensor.shape,
packed=torch._cslt_compress(original_tensor),
@@ -626,6 +630,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
packed = packed.view(original_tensor.shape[0], -1)
packed_t = packed_t.view(original_tensor.shape[1], -1)
# pyrefly: ignore # no-matching-overload
return cls(
original_tensor.shape,
packed=packed,

View File

@@ -1336,6 +1336,7 @@ class Identity(sympy.Function):
def _sympystr(self, printer):
"""Controls how sympy's StrPrinter prints this"""
# pyrefly: ignore # missing-attribute
return f"({printer.doprint(self.args[0])})"
def _eval_is_real(self):