Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-04 08:00:58 +08:00

Compare commits

151 Commits

gh/mlazos/ ... main-enabl
| SHA1 | Author | Date | |
|---|---|---|---|
| e752a29afd | |||
| 36b622bb72 | |||
| 83a04f38a4 | |||
| 6579829bee | |||
| 2b856676f3 | |||
| 5746261c97 | |||
| b3c94fd0fc | |||
| 6fd366b2c7 | |||
| fe25f6ab59 | |||
| ca89e5732f | |||
| f12cb265d4 | |||
| 7dc6bf5377 | |||
| e5ba464808 | |||
| 7d95185044 | |||
| 77fb3c1cac | |||
| 11a3d1d87b | |||
| 8c6d9feb26 | |||
| 003dd13073 | |||
| c2bd41ac9f | |||
| ca8bd5dbed | |||
| 26f3803433 | |||
| 48064acf37 | |||
| e5a9c247bc | |||
| 36371b8ec7 | |||
| 7e6721fb0a | |||
| 901bbcba12 | |||
| febb603230 | |||
| 568d2f3ae7 | |||
| b54e466fd0 | |||
| 53f9ae0e50 | |||
| b42fe389b9 | |||
| 66ea76ec44 | |||
| e787d532b6 | |||
| b3f6d49b69 | |||
| bc1f2108d7 | |||
| f071f17911 | |||
| fa1539594b | |||
| dfc8a1c5dd | |||
| 7f9b745494 | |||
| 83f9baf413 | |||
| ffc7552e01 | |||
| 78f5a1ec60 | |||
| 2b71b62045 | |||
| 8c4b528403 | |||
| 066f818eea | |||
| 14af1dc3da | |||
| 2395d7d7da | |||
| 0aa7ebaf03 | |||
| 7a97832585 | |||
| 84d141e910 | |||
| 7c6c5d04fe | |||
| b509fb9b5d | |||
| 331b7cc054 | |||
| 815d641599 | |||
| ffe3cb226a | |||
| 7ae123d72c | |||
| 7719cb75bf | |||
| 712f54d453 | |||
| f58f301313 | |||
| 5c583e2573 | |||
| 0c14f55de6 | |||
| 8e510e1095 | |||
| 59d30d1b75 | |||
| 3915898c22 | |||
| 3044e1a460 | |||
| b11593c31b | |||
| 36871622f1 | |||
| b4fd47179e | |||
| 4f400ab520 | |||
| 839f6facdb | |||
| ca65023b90 | |||
| 132ae8e6dd | |||
| a20afb6100 | |||
| 47524dcc48 | |||
| 9ffba8a2f9 | |||
| 3681312ce0 | |||
| 7778a58e7c | |||
| e7091a47da | |||
| bcfea48ab7 | |||
| d2e1dbc8f2 | |||
| 89298ada83 | |||
| c467e59cb0 | |||
| bbb902c8dd | |||
| e6f766c7d7 | |||
| 13b621d87c | |||
| 01738a3fea | |||
| a2f34bdd7c | |||
| a63ab0b8cd | |||
| 102b7885ff | |||
| 382d04a51e | |||
| 1ec0755a7e | |||
| 058782c6ab | |||
| 2b4ef6b4d6 | |||
| 3f83e8915e | |||
| d7e3f493d9 | |||
| 08f09d9543 | |||
| 74acf92648 | |||
| cbf212e9c7 | |||
| d18e068fd6 | |||
| 3401665110 | |||
| 8c60f4ae08 | |||
| c4565c3b94 | |||
| 6918f17114 | |||
| 9b6be53326 | |||
| 7fee6bbf34 | |||
| 6adaa328f4 | |||
| 4a7eed527f | |||
| d2494cbb2b | |||
| 5eddbb5e47 | |||
| c9b2a09530 | |||
| bf5aeb3148 | |||
| 45b8c0f75c | |||
| c733072874 | |||
| fbe0d20a17 | |||
| 1fa11f42b1 | |||
| 6f713e25bb | |||
| 09a4187b8e | |||
| 306c55ba27 | |||
| 56d6229ff9 | |||
| 74db92b218 | |||
| c48843e4c6 | |||
| 9e89b1c4c7 | |||
| c5972ebdfb | |||
| 18b3658df9 | |||
| 5fbf93b774 | |||
| a856a17799 | |||
| bc6e08954d | |||
| 45a96b2081 | |||
| 04e36611bb | |||
| f15c25d5c3 | |||
| e93981c243 | |||
| 496adf9f9c | |||
| 33bfec27ff | |||
| f44935cc14 | |||
| 39116409a1 | |||
| 515d1326c1 | |||
| ac529df244 | |||
| fa3916f466 | |||
| 267348fe7f | |||
| 1803d40c99 | |||
| 29c5368e0f | |||
| e71c75680f | |||
| ca96c67500 | |||
| 770e6b910c | |||
| 37d57ac9cb | |||
| 9166f6120f | |||
| fb0291d14b | |||
| f3683453ae | |||
| 1191e51c44 | |||
| 3edd94485f | |||
| a701c937bf | |||

@@ -187,19 +187,22 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
            export USE_CUFILE=0
        else
            DEPS_LIST+=(
                "/usr/local/cuda/lib64/libnvToolsExt.so.1"
                "/usr/local/cuda/lib64/libcublas.so.12"
                "/usr/local/cuda/lib64/libcublasLt.so.12"
                "/usr/local/cuda/lib64/libcudart.so.12"
                "/usr/local/cuda/lib64/libnvrtc.so.12"
                "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
            DEPS_SONAME+=(
                "libnvToolsExt.so.1"
                "libcublas.so.12"
                "libcublasLt.so.12"
                "libcudart.so.12"
                "libnvrtc.so.12"
                "libcupti.so.12")

            if [[ $CUDA_VERSION != 12.9* ]]; then
                DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
                DEPS_SONAME+=("libnvToolsExt.so.1")
            fi
        fi
    else
        echo "Using nvidia libs from pypi."

.github/ISSUE_TEMPLATE/ci-sev.md | 1 (vendored)
@@ -8,6 +8,7 @@ assignees: ''
---

> NOTE: Remember to label this issue with "`ci: sev`"
>       If you want autorevert to be disabled, keep the ci: disable-autorevert label

 <!-- Add the `merge blocking` label to this PR to prevent PRs from being merged while this issue is open -->

.github/ISSUE_TEMPLATE/disable-autorevert.md | 4 (vendored)
@@ -1,7 +1,7 @@
---
name: DISABLE AUTOREVERT
name: "D❌\U0001F519 ISABLE AUTOREVERT"
about: Disables autorevert when open
title: "❌\U0001F519 [DISABLE AUTOREVERT]"
title: "[DISABLE AUTOREVERT]"
labels: 'ci: disable-autorevert'
assignees: ''

@@ -65,7 +65,7 @@ runs:
          cd .ci/lumen_cli
          python3 -m pip install -e .
        )
        MAX_JOBS="$(nproc --ignore=6)"
        MAX_JOBS="$(nproc --ignore=10)"
        export MAX_JOBS

        # Split the comma-separated list and build each target

.github/ci_commit_pins/audio.txt | 2 (vendored)
@@ -1 +1 @@
8ad2aa5d354d1bf432339113860185d5a5d1abbd
1b013f5b5a87a1882eb143c26d79d091150d6a37

.github/ci_commit_pins/vision.txt | 2 (vendored)
@@ -1 +1 @@
f5c6c2ec6490455e86f67b2a25c10390d60a27f7
faffd5cf673615583da6517275e361cb3dbc77e6

.github/pytorch-probot.yml | 4 (vendored)
@@ -3,6 +3,7 @@ ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
@@ -15,7 +16,8 @@ ciflow_push_tags:
- ciflow/inductor-micro-benchmark
- ciflow/inductor-micro-benchmark-cpu-x86
- ciflow/inductor-perf-compare
- ciflow/inductor-perf-test-nightly-rocm
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-periodic
- ciflow/inductor-rocm

.github/workflows/b200-distributed.yml | 62 (vendored, new file)
@@ -0,0 +1,62 @@
name: CI for distributed tests on B200

on:
  pull_request:
    paths:
      - .github/workflows/b200-distributed.yml
  workflow_dispatch:
  push:
    tags:
      - ciflow/b200-distributed/*
  schedule:
    - cron: 46 8 * * *  # about 1:46am PDT

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:

  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
    name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '10.0'
      test-matrix: |
        { include: [
          { config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
          { config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
        ]}
    secrets: inherit

  linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
    name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
    uses: ./.github/workflows/_linux-test.yml
    needs:
      - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
    with:
      timeout-minutes: 1200
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
    secrets: inherit

.github/workflows/build-vllm-wheel.yml | 19 (vendored)
@@ -27,9 +27,8 @@ jobs:
      fail-fast: false
      matrix:
        python-version: [ '3.12' ]
        # TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
        platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
        device: [ 'cu128', 'cu129' ]
        device: [ 'cu128', 'cu129', 'cu130' ]
        include:
          - platform: manylinux_2_28_x86_64
            device: cu128
@@ -39,6 +38,10 @@ jobs:
            device: cu129
            manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
            runner: linux.12xlarge.memory
          - platform: manylinux_2_28_x86_64
            device: cu130
            manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
            runner: linux.12xlarge.memory
          - platform: manylinux_2_28_aarch64
            device: cu128
            manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
@@ -47,6 +50,11 @@ jobs:
            device: cu129
            manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
            runner: linux.arm64.r7g.12xlarge.memory
        exclude:
          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
          # xformers is update to support 13.0
          - platform: manylinux_2_28_aarch64
            device: cu130
    name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
    runs-on: ${{ matrix.runner }}
    timeout-minutes: 480
@@ -169,7 +177,12 @@ jobs:
      fail-fast: false
      matrix:
        platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
        device: [ 'cu128', 'cu129' ]
        device: [ 'cu128', 'cu129', 'cu130' ]
        exclude:
          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
          # xformers is update to support 13.0
          - platform: manylinux_2_28_aarch64
            device: cu130
    env:
      PLATFORM: ${{ matrix.platform }}
      BUILD_DEVICE: ${{ matrix.device }}

.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml | 132 (vendored, new file)
@@ -0,0 +1,132 @@
name: inductor-perf-nightly-rocm-mi300

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm-mi300/*
  schedule:
    - cron: 15 0 * * *
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
    inputs:
      training:
        description: Run training (on by default)?
        required: false
        type: boolean
        default: true
      inference:
        description: Run inference (on by default)?
        required: false
        type: boolean
        default: true
      default:
        description: Run inductor_default?
        required: false
        type: boolean
        default: false
      dynamic:
        description: Run inductor_dynamic_shapes?
        required: false
        type: boolean
        default: false
      cppwrapper:
        description: Run inductor_cpp_wrapper?
        required: false
        type: boolean
        default: false
      cudagraphs:
        description: Run inductor_cudagraphs?
        required: false
        type: boolean
        default: true
      freezing_cudagraphs:
        description: Run inductor_cudagraphs with freezing for inference?
        required: false
        type: boolean
        default: false
      aotinductor:
        description: Run aot_inductor for inference?
        required: false
        type: boolean
        default: false
      maxautotune:
        description: Run inductor_max_autotune?
        required: false
        type: boolean
        default: false
      benchmark_configs:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm_mi300,inductor_timm_perf_rocm_mi300,inductor_torchbench_perf_rocm_mi300

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
      opt_out_experiments: lf

  linux-jammy-rocm-py3_10-inductor-benchmark-build:
    if: github.repository_owner == 'pytorch'
    name: rocm-py3_10-inductor-benchmark-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-rocm-py3_10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-inductor-benchmark-test:
    permissions:
      id-token: write
      contents: read
    name: rocm-py3_10-inductor-benchmark-test
    uses: ./.github/workflows/_rocm-test.yml
    needs: linux-jammy-rocm-py3_10-inductor-benchmark-build
    with:
      build-environment: linux-jammy-rocm-py3_10
      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }}
      timeout-minutes: 720
      # Disable monitor in perf tests for more investigation
      disable-monitor: true
      monitor-log-interval: 10
      monitor-data-collect-interval: 2
    secrets: inherit

@@ -1,11 +1,11 @@
name: inductor-perf-nightly-rocm
name: inductor-perf-nightly-rocm-mi355

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm/*
      - ciflow/inductor-perf-test-nightly-rocm-mi355/*
  schedule:
    - cron: 0 7 * * 0,3
    - cron: 15 0 * * *
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
@@ -59,7 +59,7 @@ on:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm,inductor_timm_perf_rocm,inductor_torchbench_perf_rocm
        default: inductor_huggingface_perf_rocm_mi355,inductor_timm_perf_rocm_mi355,inductor_torchbench_perf_rocm_mi355

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
@@ -88,23 +88,27 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
        ]}
    secrets: inherit

.github/workflows/operator_benchmark.yml | 23 (vendored)
@@ -7,9 +7,11 @@ on:
  workflow_dispatch:
    inputs:
      test_mode:
        required: false
        type: string
        default: 'short'
        type: choice
        options:
          - 'short'
          - 'long'
          - 'all'
        description: tag filter for operator benchmarks, options from long, short, all
  schedule:
    # Run at 07:00 UTC every Sunday
@@ -37,20 +39,7 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  opbenchmark-on-demand-build:
    if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
    name: opbenchmark-on-demand-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
          { config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

.github/workflows/trunk.yml | 8 (vendored)
@@ -180,13 +180,13 @@ jobs:
      disable-monitor: false
    secrets: inherit

  win-vs2022-cuda12_6-py3-build:
    name: win-vs2022-cuda12.6-py3
  win-vs2022-cuda12_8-py3-build:
    name: win-vs2022-cuda12.8-py3
    uses: ./.github/workflows/_win-build.yml
    needs: get-label-type
    with:
      build-environment: win-vs2022-cuda12.6-py3
      cuda-version: "12.6"
      build-environment: win-vs2022-cuda12.8-py3
      cuda-version: "12.8"
      runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    secrets: inherit

.gitignore | 1 (vendored)
@@ -395,3 +395,4 @@ android/pytorch_android_torchvision/.cxx
CLAUDE.local.md
/test_*.py
/debug_*.py
CLAUDE_CONTEXT/

@@ -256,6 +256,7 @@ endif()
IF(USE_FBGEMM_GENAI)
  set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
  set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)

  if(USE_CUDA)
    # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
    # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@@ -292,58 +293,64 @@ IF(USE_FBGEMM_GENAI)
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
    )

    target_include_directories(fbgemm_genai PUBLIC
    target_include_directories(fbgemm_genai PRIVATE
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${fbgemm_genai_mx8mx8bf16_grouped}
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )
  else()
    if(USE_ROCM)
      # Only include the kernels we want to build to avoid increasing binary size.
      file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
      set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

      # Add additional HIPCC compiler flags for performance
      set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
        -mllvm
        -amdgpu-coerce-illegal-types=1
        -mllvm
        -enable-post-misched=0
        -mllvm
        -greedy-reverse-local-assignment=1
        -fhip-new-launch-api)
    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  elseif(USE_ROCM)
    # Only include the kernels we want to build to avoid increasing binary size.
    file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
    set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

      # Only compile for gfx942 for now.
      # This is rather hacky, I could not figure out a clean solution :(
      set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
      string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
      if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
        list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
      endif()
      set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
    # Add additional HIPCC compiler flags for performance
    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
      -mllvm
      -amdgpu-coerce-illegal-types=1
      -mllvm
      -enable-post-misched=0
      -mllvm
      -greedy-reverse-local-assignment=1
      -fhip-new-launch-api)

      hip_add_library(
        fbgemm_genai STATIC
        ${fbgemm_genai_native_rocm_hip}
        HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
      set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
      set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
      target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

      target_include_directories(fbgemm_genai PUBLIC
        # FBGEMM version of Composable Kernel is used due to some customizations
        ${FBGEMM_THIRD_PARTY}/composable_kernel/include
        ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
        ${FBGEMM_THIRD_PARTY}/cutlass/include
        ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
        ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
        ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
      )
    # Only compile for gfx942 for now.
    # This is rather hacky, I could not figure out a clean solution :(
    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
    if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
      list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
    endif()
    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    hip_add_library(
      fbgemm_genai STATIC
      ${fbgemm_genai_native_rocm_hip}
      HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

    target_include_directories(fbgemm_genai PRIVATE
      # FBGEMM version of Composable Kernel is used due to some customizations
      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )

    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()
endif()

@@ -692,12 +699,6 @@ if(USE_CUDA AND NOT USE_ROCM)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)

  # Add FBGEMM_GENAI include directories for torch_ops.h
  if(USE_FBGEMM_GENAI)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()

  if($ENV{ATEN_STATIC_CUDA})
    if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
      list(APPEND ATen_CUDA_DEPENDENCY_LIBS

@@ -389,37 +389,16 @@ void fillVersion<DLManagedTensorVersioned>(
// constructed out of ATen tensor
template <class T>
T* toDLPackImpl(const Tensor& src) {
  auto view = src;

  // Detect whether there is need to normalize the strides
  // Background: gh-83069
  //
  // However, normalizing strides can come at a high-cost
  // to slow down toDLPack conversion 3x, so we
  // only normalize if needed.
  //
  // The following code detects whether the src follows
  // a continuous pattern. If the src follows such pattern (common-case)
  // then we do not need to normalize the strides.
  bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
  // less common case, try normalizing the strides
  if (need_normalize_strides) {
    // create a new tensor with possibly normalized strides
    // gh-83069
    auto shape = src.sizes();
    view = src.as_strided(shape, {1}, src.storage_offset());
  }

  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
  atDLMTensor->handle = view;
  atDLMTensor->handle = src;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter<T>;
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
  atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  fillVersion(&atDLMTensor->tensor);
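
For context on the removed branch, here is a minimal Python sketch (not part of the diff) of the gh-83069 edge case the old normalization targeted: a one-dimensional, single-element view whose stride is not 1. With this change the exporter hands the source tensor's sizes and strides to the DLPack capsule unchanged.

```python
import torch
from torch.utils import dlpack

base = torch.arange(10.0)
view = base[::2][:1]             # shape (1,), stride (2,): the pattern gh-83069 discusses
print(view.shape, view.stride())

# Round-trip through DLPack; per this diff the capsule now carries the
# original strides instead of a normalized {1}.
capsule = dlpack.to_dlpack(view)
back = dlpack.from_dlpack(capsule)
print(torch.equal(view, back))
```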

@@ -624,7 +624,14 @@ struct TORCH_API IValue final {
  IValue(const c10::SymBool& i) {
    if (auto mi = i.maybe_as_bool()) {
      tag = Tag::Bool;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
      payload.u.as_int = *mi;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
      /* due to byteorder if value assigned as_int, as_bool actually is not set correctly */
      payload.u.as_bool = *mi;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
    } else {
      tag = Tag::SymBool;
      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();

@@ -13,6 +13,7 @@
#include <c10/core/ScalarType.h>

#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
      BLASType = "unknown";
  }
  return BLASType;

}

// Similar to Compute Type in GemmRocblas.h
@@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio

namespace detail {

static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {

  if (!config.enabled) {
    return true; // skip when disabled
  }

  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
  // comparison done as 1D tensor
  at::Tensor ref = at::from_blob(c,       {size}, options);
  at::Tensor oth = at::from_blob(other_c, {size}, options);
  at::Tensor ref_float = ref.to(at::kFloat);
  at::Tensor oth_float = oth.to(at::kFloat);
  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  double last_succeed_atol = 1;
  double last_succeed_rtol = 1;
  for (auto& atol : atols) {
    for (auto& rtol : rtols) {
      if (at::allclose(ref_float, oth_float, rtol, atol)) {
        last_succeed_atol = atol;
        last_succeed_rtol = rtol;
      }
    }
  }
  if (last_succeed_atol == 1) {
    return false;
  }
  else {
    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
  }

  return true;
  const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
  if (ok) {
    TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
  } else {
    TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
  }
  return ok;
}

}
@@ -355,8 +349,10 @@ struct GemmParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }

  char transa{};
@@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }

  char transa{};
@@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }

  char transa{};
@@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
  }

  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }

  char transa{};
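
As a rough Python analogue (illustrative only, not the C++ API), the reworked check compares the two candidate outputs as flattened float tensors against a single configured atol/rtol pair, instead of sweeping a grid of tolerances:

```python
import torch

def numerical_check(ref: torch.Tensor, other: torch.Tensor,
                    atol: float, rtol: float, enabled: bool = True) -> bool:
    # Mirrors config.enabled == false: skip the check entirely.
    if not enabled:
        return True
    # Comparison is done on flattened float copies, as in the C++ code above.
    ref_f = ref.reshape(-1).float()
    oth_f = other.reshape(-1).float()
    return torch.allclose(ref_f, oth_f, rtol=rtol, atol=atol)
```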

@@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@@ -173,10 +173,9 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| write_file_on_exit(val: bool) -> None | Default is True. |
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None: -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
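
A hedged usage sketch based on the knobs documented above; the values are examples rather than defaults, `set_numerical_check_tolerances` is the new entry in the table, and `torch.cuda.tunable.enable` is the existing toggle for TunableOp:

```python
import os

# Environment-variable route: the new atol_rtol format described above.
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-5_1e-5"

import torch

# Python API route, from the torch.cuda.tunable module.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_numerical_check_tolerances(True, atol=1e-5, rtol=1e-5)
```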
@ -107,14 +107,30 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
}

void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
  std::scoped_lock l{lock_};
  bool is_new = false;
  ResultEntry inserted = ResultEntry::Null();

  auto it = results_.find(op_signature);
  if (it == results_.end()) {
    it = results_.insert({op_signature, {}}).first;
  // ---- mutate maps under results lock ----
  {
    std::scoped_lock l{lock_};
    auto& km = results_[op_signature];  // creates if missing
    is_new = (km.find(params_signature) == km.end());
    AddImpl(op_signature, params_signature, std::move(best), km);
    if (is_new) {
      inserted = km.at(params_signature);  // snapshot for I/O after unlocking
    }
  }
  if (!is_new) return;  // only write once per unique (op, params)

  TuningContext* ctx = getTuningContext();
  if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
    InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());

    if (is_new && realtime_out_ && realtime_out_->good()) {
      AppendResultLine(op_signature, params_signature, inserted);
    }
  }

  AddImpl(op_signature, params_signature, std::move(best), it->second);
}

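In short, Add() now deduplicates per (op_signature, params_signature) while holding only the results lock and defers the file write until after the lock is released. A rough Python model of that flow, purely illustrative (ResultsStore and its members are hypothetical, not PyTorch API):

```python
# Hypothetical model of the Add() flow above: mutate the in-memory map
# under a lock, then perform file I/O only for entries seen for the first time.
import threading

class ResultsStore:
    def __init__(self, path):
        self.lock = threading.Lock()
        self.results = {}           # op_signature -> {params_signature: result}
        self.out = open(path, "a")  # realtime append target

    def add(self, op_sig, params_sig, result):
        with self.lock:             # mutate maps under the results lock
            km = self.results.setdefault(op_sig, {})
            is_new = params_sig not in km
            km[params_sig] = result
        if not is_new:
            return                  # only write once per unique (op, params)
        self.out.write(f"{op_sig},{params_sig},{result}\n")
        self.out.flush()            # ensure immediate write to disk
```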
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
 | 
			
		||||
@ -150,6 +166,77 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
 | 
			
		||||
  std::scoped_lock fl{realtime_file_mutex_};
 | 
			
		||||
 | 
			
		||||
  if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (realtime_out_ && realtime_filename_ != filename) {
 | 
			
		||||
    realtime_out_->flush();
 | 
			
		||||
    realtime_out_->close();
 | 
			
		||||
    realtime_out_.reset();
 | 
			
		||||
    validators_written_ = false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool file_exists = false;
 | 
			
		||||
  bool file_empty = true;
 | 
			
		||||
 | 
			
		||||
  {
 | 
			
		||||
    std::ifstream check_file(filename);
 | 
			
		||||
    if (check_file.good()) {
 | 
			
		||||
      file_exists = true;
 | 
			
		||||
      file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
 | 
			
		||||
 | 
			
		||||
  if (!realtime_out_->good()) {
 | 
			
		||||
    TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
 | 
			
		||||
    realtime_out_.reset();
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if(!file_exists || file_empty) {
 | 
			
		||||
    for(const auto& [key, val] : validators) {
 | 
			
		||||
      (*realtime_out_) << "Validator," << key << "," << val << std::endl;
 | 
			
		||||
      realtime_out_->flush();
 | 
			
		||||
    }
 | 
			
		||||
    validators_written_ = true;
 | 
			
		||||
 | 
			
		||||
    TUNABLE_LOG2("Wrote validators to realtime output file");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  realtime_filename_ = filename;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
  std::scoped_lock fl{realtime_file_mutex_};

  if (!realtime_out_ || !realtime_out_->good()) {
    return;
  }

  (*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
  realtime_out_->flush();  // ensure immediate write to disk

  TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
}

void TuningResultsManager::CloseRealtimeAppend() {
  std::scoped_lock fl{realtime_file_mutex_};

  if (realtime_out_) {
    realtime_out_->flush();
    realtime_out_->close();
    realtime_out_.reset();
    TUNABLE_LOG2("Closed realtime output file");
  }
}
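
The realtime-append path above produces the same CSV layout as the bulk writer: `Validator,<key>,<value>` header rows followed by `<op_signature>,<params_signature>,<result>` rows. A small reader sketch under that assumption (the helper name is illustrative):

```python
# Minimal sketch of reading a TunableOp results CSV, assuming the layout
# written above: Validator header rows, then op/params/result rows.
import csv

def read_tunableop_results(path):
    validators, results = {}, []
    with open(path, newline="") as f:
        for row in csv.reader(f):
            if not row:
                continue
            if row[0] == "Validator":
                validators[row[1]] = row[2]
            else:
                # the result field may itself contain commas, so re-join the tail
                results.append((row[0], row[1], ",".join(row[2:])))
    return validators, results

validators, results = read_tunableop_results("tunableop_results.csv")
print(len(validators), "validators;", len(results), "tuned entries")
```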
 | 
			
		||||
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
 | 
			
		||||
  std::scoped_lock l{lock_};
 | 
			
		||||
 | 
			
		||||
@ -396,7 +483,6 @@ TuningContext::TuningContext() :
 | 
			
		||||
    tuning_enable_{true},
 | 
			
		||||
    record_untuned_enable_{false},
 | 
			
		||||
    manager_initialized_{false},
 | 
			
		||||
    write_file_on_exit_{true},
 | 
			
		||||
    numerics_check_enable_{false},
 | 
			
		||||
    max_tuning_duration_ms_{30},
 | 
			
		||||
    max_tuning_iterations_{100},
 | 
			
		||||
@ -417,20 +503,8 @@ TuningContext::~TuningContext() {
 | 
			
		||||
    // but doesn't do any computation itself.
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  auto filename = GetFilename();
 | 
			
		||||
  if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
 | 
			
		||||
    if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
 | 
			
		||||
      if (results_count_from_input_file_ > 0) {
 | 
			
		||||
        TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
 | 
			
		||||
      }
 | 
			
		||||
      else {
 | 
			
		||||
        TUNABLE_LOG1("writing file ", filename);
 | 
			
		||||
      }
 | 
			
		||||
      if (!WriteFile(filename)) {
 | 
			
		||||
        TUNABLE_LOG1("failed to write file ", filename);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  TUNABLE_LOG1("Closing File");
 | 
			
		||||
  GetTuningResultsManager().CloseRealtimeAppend(); // Since, we do instant logging by default now.
 | 
			
		||||
 | 
			
		||||
  if (untuned_file_.good()) {
 | 
			
		||||
    untuned_file_.close();
 | 
			
		||||
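
Because results are now appended as they are found, the exit-time rewrite is mostly a fallback. If you prefer to control it explicitly from Python, the tabled API allows something like this (illustrative):

```python
# Illustrative: disable the destructor-time rewrite and flush manually,
# using only functions listed in the API table above.
import torch.cuda.tunable as tunable

tunable.write_file_on_exit(False)   # skip the rewrite on shutdown
# ... run workloads; realtime append still records new results ...
tunable.write_file()                # explicit flush to get_filename()
```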
@ -511,20 +585,54 @@ std::ofstream& TuningContext::GetUntunedFile(){
 | 
			
		||||
  return untuned_file_;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningContext::WriteFileOnExit(bool value) {
 | 
			
		||||
  write_file_on_exit_ = value;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningContext::EnableNumericsCheck(bool value) {
 | 
			
		||||
  numerics_check_enable_ = value;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool TuningContext::IsNumericsCheckEnabled() const {
 | 
			
		||||
  const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
 | 
			
		||||
  if (env == "1") {
 | 
			
		||||
    return true;
 | 
			
		||||
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
 | 
			
		||||
  const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
 | 
			
		||||
 | 
			
		||||
  if (!env_opt.has_value()) {
 | 
			
		||||
    return numerics_cfg_;
 | 
			
		||||
  }
 | 
			
		||||
  return numerics_check_enable_;
 | 
			
		||||
 | 
			
		||||
  const std::string& env = env_opt.value();
 | 
			
		||||
 | 
			
		||||
  if (env == "0") {
 | 
			
		||||
    return NumericalCheckConfig(false, 1e-5, 1e-5);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const size_t underscore = env.find('_');
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      underscore != std::string::npos,
 | 
			
		||||
      "Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
 | 
			
		||||
      "Expected 'atol_rtol', got: ",
 | 
			
		||||
      env);
 | 
			
		||||
 | 
			
		||||
  double atol = 0.0;
 | 
			
		||||
  double rtol = 0.0;
 | 
			
		||||
 | 
			
		||||
  try {
 | 
			
		||||
    atol = std::stod(env.substr(0, underscore));
 | 
			
		||||
    rtol = std::stod(env.substr(underscore + 1));
 | 
			
		||||
  } catch (const std::exception& e) {
 | 
			
		||||
    TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
 | 
			
		||||
  return NumericalCheckConfig(true, atol, rtol);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
 | 
			
		||||
  TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
 | 
			
		||||
  numerics_cfg_ = {enabled, atol, rtol};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool TuningContext::IsNumericsCheckEnabled() const {
 | 
			
		||||
  const auto cfg = GetNumericalCheckConfig();
 | 
			
		||||
  return cfg.enabled || numerics_check_enable_;
 | 
			
		||||
}
 | 
			
		||||
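
For reference, the `atol_rtol` environment format parsed above, mirrored in Python ("0" disables the check and both tolerances must be positive); this illustrates the accepted syntax and is not the implementation:

```python
# Mirror of the PYTORCH_TUNABLEOP_NUMERICAL_CHECK parsing logic above.
def parse_numerical_check(env: str):
    if env == "0":
        return (False, 1e-5, 1e-5)            # disabled, default tolerances
    atol_s, sep, rtol_s = env.partition("_")
    if not sep:
        raise ValueError(f"Expected 'atol_rtol', got: {env}")
    atol, rtol = float(atol_s), float(rtol_s)
    if atol <= 0.0 or rtol <= 0.0:
        raise ValueError(f"Tolerance values must be positive. atol={atol}, rtol={rtol}")
    return (True, atol, rtol)

assert parse_numerical_check("1e-5_1e-5") == (True, 1e-5, 1e-5)
```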
 | 
			
		||||
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
 | 
			
		||||
@ -634,11 +742,6 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
 | 
			
		||||
    auto filename = GetFilename();
 | 
			
		||||
    if (!filename.empty() && !IsRecordUntunedEnabled()) {
 | 
			
		||||
      ReadFile(filename);
 | 
			
		||||
      // attempt immediately to open file for writing to catch errors early
 | 
			
		||||
      std::ofstream file(filename, std::ios::out | std::ios::app);
 | 
			
		||||
      if (!file.good()) {
 | 
			
		||||
        TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  });
 | 
			
		||||
  return manager_;
 | 
			
		||||
@ -744,27 +847,6 @@ bool TuningContext::ReadFile(const std::string& filename_) {
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool TuningContext::WriteFile(const std::string& filename_) {
 | 
			
		||||
  std::string filename = filename_.empty() ? GetFilename() : filename_;
 | 
			
		||||
  std::ofstream file(filename, std::ios::out | std::ios::trunc);
 | 
			
		||||
  if (!file.good()) {
 | 
			
		||||
    TUNABLE_LOG1("error opening tuning results file for writing ", filename);
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  auto validators = GetTuningResultsValidator().GetAllValidators();
 | 
			
		||||
  for (const auto& [key, val] : validators) {
 | 
			
		||||
    file << "Validator," << key << "," << val << std::endl;
 | 
			
		||||
  }
 | 
			
		||||
  auto results = GetTuningResultsManager().Dump();
 | 
			
		||||
  for (const auto& [op_sig, kernelmap] : results) {
 | 
			
		||||
    for (const auto& [param_sig, result] : kernelmap) {
 | 
			
		||||
      file << op_sig << "," << param_sig << "," << result << std::endl;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  file.close();
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
struct MaybeDelete {
 | 
			
		||||
 | 
			
		||||
@ -103,10 +103,24 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
 | 
			
		||||
 | 
			
		||||
    void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
 | 
			
		||||
      const std::string& params_signature, const std::string& blas_signature);
 | 
			
		||||
 | 
			
		||||
    void InitRealtimeAppend(
 | 
			
		||||
        const std::string& filename,
 | 
			
		||||
        const std::unordered_map<std::string, std::string>& validators);
 | 
			
		||||
 | 
			
		||||
    void AppendResultLine(const std::string& op_sig,
 | 
			
		||||
                         const std::string& param_sig,
 | 
			
		||||
                         const ResultEntry& result);
 | 
			
		||||
 | 
			
		||||
    void CloseRealtimeAppend();  // For clean shutdown
 | 
			
		||||
  private:
 | 
			
		||||
    std::mutex lock_;
 | 
			
		||||
    std::mutex realtime_file_mutex_;
 | 
			
		||||
    std::unique_ptr<std::ofstream> realtime_out_;
 | 
			
		||||
    std::string realtime_filename_;
 | 
			
		||||
    ResultsMap results_;
 | 
			
		||||
    UntunedMap untuned_results_;
 | 
			
		||||
    bool validators_written_ = false;
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@ -134,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
 | 
			
		||||
    GetValidateFuncs validators_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct NumericalCheckConfig {
 | 
			
		||||
  bool   enabled{false};
 | 
			
		||||
  double atol{1e-5};
 | 
			
		||||
  double rtol{1e-5};
 | 
			
		||||
 | 
			
		||||
  NumericalCheckConfig() = default;
 | 
			
		||||
  NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
  public:
 | 
			
		||||
    TuningContext();
 | 
			
		||||
@ -155,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
 | 
			
		||||
    void EnableNumericsCheck(bool value);
 | 
			
		||||
    bool IsNumericsCheckEnabled() const;
 | 
			
		||||
    void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
 | 
			
		||||
    NumericalCheckConfig GetNumericalCheckConfig() const;
 | 
			
		||||
 | 
			
		||||
    void SetMaxTuningDurationMs(int max_duration_ms);
 | 
			
		||||
    int GetMaxTuningDurationMs() const;
 | 
			
		||||
@ -185,10 +211,7 @@ class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
    void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
 | 
			
		||||
    std::string GetFilename() const;
 | 
			
		||||
 | 
			
		||||
    void WriteFileOnExit(bool value);
 | 
			
		||||
 | 
			
		||||
    bool ReadFile(const std::string& filename={});
 | 
			
		||||
    bool WriteFile(const std::string& filename={});
 | 
			
		||||
 | 
			
		||||
    template<class... Types>
 | 
			
		||||
    void Log(int level, Types... args) {
 | 
			
		||||
@ -207,7 +230,6 @@ class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
    bool tuning_enable_;
 | 
			
		||||
    bool record_untuned_enable_;
 | 
			
		||||
    bool manager_initialized_;
 | 
			
		||||
    bool write_file_on_exit_;
 | 
			
		||||
    bool numerics_check_enable_;
 | 
			
		||||
    int max_tuning_duration_ms_;
 | 
			
		||||
    int max_tuning_iterations_;
 | 
			
		||||
@ -222,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
    std::ofstream untuned_file_;
 | 
			
		||||
    size_t results_count_from_input_file_;
 | 
			
		||||
    bool is_shutting_down_;
 | 
			
		||||
 | 
			
		||||
    NumericalCheckConfig numerics_cfg_{};
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
 | 
			
		||||
 | 
			
		||||
@ -267,27 +267,10 @@ class TunableOp {
 | 
			
		||||
      for (size_t i = 0; i < op_names_.size(); i++) {
 | 
			
		||||
        auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
 | 
			
		||||
 | 
			
		||||
        if (do_numerics_check) {
 | 
			
		||||
          ParamsT* numerical_params = params->DeepCopy(false);
 | 
			
		||||
          auto status = candidate->Call(numerical_params);
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            numerical_params->Delete();
 | 
			
		||||
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
          status = reference_params->NumericalCheck(numerical_params);
 | 
			
		||||
          numerical_params->Delete();
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
          auto status = candidate->Call(reusable_params[0]);
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
        auto status = candidate->Call(reusable_params[0]);
 | 
			
		||||
        if (status != OK) {
 | 
			
		||||
          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
          continue;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // collect a small profile
 | 
			
		||||
@ -310,6 +293,22 @@ class TunableOp {
 | 
			
		||||
          continue;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (do_numerics_check) {
 | 
			
		||||
          ParamsT* numerical_params = params->DeepCopy(false);
 | 
			
		||||
          auto status = candidate->Call(numerical_params);
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            numerical_params->Delete();
 | 
			
		||||
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
          status = reference_params->NumericalCheck(numerical_params);
 | 
			
		||||
          numerical_params->Delete();
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // for warmup does user set max duration, max iters, or both?
 | 
			
		||||
        // warmup is skipped by default, i.e. warmup_iter = 0
 | 
			
		||||
        // warmup will be set to the non-zero value of max_warmup_duration
 | 
			
		||||
 | 
			
		||||
@ -213,40 +213,22 @@ static cudnn_grid_sample_backward_batch_rule(
 | 
			
		||||
  return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: replace with targetable functionalization
// uses functional formulation for one_hot under vmap to be compatible with
// fakeTensor/dynamic shapes and compiled functorch transforms.
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
// but requires explicit positive num_classes under vmap to avoid
// data-dependent output shapes.
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
    TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
    auto shape = self.sym_sizes().vec();

    // empty tensor could be converted to one hot representation,
    // but shape inference is not possible.
    if (self.sym_numel() == 0) {
        if (num_classes <= 0) {
            TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
        } else {
            shape.emplace_back(num_classes);
            return at::empty_symint(shape, self.options());
        }
    }

    // disallow implicit inference under vmap; this would be data-dependent
    // and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
    TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
        "provide an explicit positive num_classes argument.");

    // Disabling all of the following checks. This is OK because scatter has checks too.
    // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
    // // non-empty tensor
    // if (self.device().type() != at::kCUDA) {
    //   //for cuda, rely on device assert thrown by scatter
    //   TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
    // }
    // if (self.device().type() != at::kCUDA) {
    //   //rely on device asserts from scatter to avoid sync here
    //   TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
    // }

    shape.emplace_back(num_classes);
    Tensor ret = at::zeros_symint(shape, self.options());
    return ret.scatter(-1, self.unsqueeze(-1), 1);
    const auto options = self.options();
    at::Tensor index = at::arange(num_classes, options);
    return at::eq(self.unsqueeze(-1), index).to(at::kLong);
}
 | 
			
		||||
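
The Python equivalent of the functional formulation used above: comparing the indices against an `arange` avoids `scatter` and data-dependent shapes, which is what lets it compose with `torch.vmap` when `num_classes` is passed explicitly. A quick check (illustrative):

```python
# Functional one_hot: eq against an arange instead of scatter.
import torch
import torch.nn.functional as F

def functional_one_hot(x: torch.Tensor, num_classes: int) -> torch.Tensor:
    index = torch.arange(num_classes, device=x.device)
    return (x.unsqueeze(-1) == index).to(torch.long)

x = torch.tensor([[0, 2], [1, 3]])
assert torch.equal(functional_one_hot(x, 4), F.one_hot(x, num_classes=4))

# Under vmap, an explicit positive num_classes is required.
batched = torch.vmap(lambda t: F.one_hot(t, num_classes=4))(x)
assert torch.equal(batched, F.one_hot(x, num_classes=4))
```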
 | 
			
		||||
template <typename A, A a, typename C>
 | 
			
		||||
 | 
			
		||||
@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto shape = self.sizes().vec();
 | 
			
		||||
    auto shape = self.sym_sizes().vec();
 | 
			
		||||
 | 
			
		||||
    // empty tensor could be converted to one hot representation,
 | 
			
		||||
    // but shape inference is not possible.
 | 
			
		||||
    if (self.numel() == 0) {
 | 
			
		||||
    if (self.sym_numel() == 0) {
 | 
			
		||||
        if (num_classes <= 0) {
 | 
			
		||||
            TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
 | 
			
		||||
        } else {
 | 
			
		||||
            shape.push_back(num_classes);
 | 
			
		||||
            return at::empty(shape, self.options());
 | 
			
		||||
            shape.emplace_back(num_classes);
 | 
			
		||||
            return at::empty_symint(shape, self.options());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    shape.push_back(num_classes);
 | 
			
		||||
    Tensor ret = at::zeros(shape, self.options());
 | 
			
		||||
    shape.emplace_back(num_classes);
 | 
			
		||||
    Tensor ret = at::zeros_symint(shape, self.options());
 | 
			
		||||
    ret.scatter_(-1, self.unsqueeze(-1), 1);
 | 
			
		||||
    return ret;
 | 
			
		||||
}
 | 
			
		||||
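
As the empty-tensor branch above notes, an empty index tensor can still be one-hot encoded, but only with an explicit `num_classes`, since the class count cannot be inferred. For example:

```python
import torch
import torch.nn.functional as F

empty = torch.empty(0, dtype=torch.long)
print(F.one_hot(empty, num_classes=5).shape)   # torch.Size([0, 5])
# F.one_hot(empty) would raise: the number of classes cannot be inferred.
```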
 | 
			
		||||
@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
 | 
			
		||||
  } else if (dtype == ScalarType::Half) {
 | 
			
		||||
    [&]() {
 | 
			
		||||
      using scalar_t =
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
 | 
			
		||||
          c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
 | 
			
		||||
      const auto exp = exp_scalar.to<scalar_t>();
 | 
			
		||||
      using Vec = Vectorized<scalar_t>;
 | 
			
		||||
      cpu_kernel_vec(iter,
 | 
			
		||||
 | 
			
		||||
@ -1230,8 +1230,205 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor&
 | 
			
		||||
_tunable_scaled_gemm_rocm(
 | 
			
		||||
          cublasCommonArgs& args,
 | 
			
		||||
          const Tensor& mat1, const Tensor& mat2,
 | 
			
		||||
          const Tensor& scale_a, const Tensor& scale_b,
 | 
			
		||||
          const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
 | 
			
		||||
          const std::optional<Tensor>& bias,
 | 
			
		||||
          const bool use_fast_accum,
 | 
			
		||||
          const at::ScalarType out_dtype,
 | 
			
		||||
          Tensor& out) {
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B)                            \
 | 
			
		||||
      if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) {        \
 | 
			
		||||
        if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \
 | 
			
		||||
          static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
              at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t,     \
 | 
			
		||||
              BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
          scaledgemm(¶ms);                                        \
 | 
			
		||||
        }                                                             \
 | 
			
		||||
        else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
 | 
			
		||||
          static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
              at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t,     \
 | 
			
		||||
              BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
          scaledgemm(¶ms);                                        \
 | 
			
		||||
        }                                                             \
 | 
			
		||||
      }                                                               \
 | 
			
		||||
      else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) {   \
 | 
			
		||||
        if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \
 | 
			
		||||
          static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
              at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t,     \
 | 
			
		||||
              BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
          scaledgemm(¶ms);                                        \
 | 
			
		||||
        }                                                             \
 | 
			
		||||
        else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
 | 
			
		||||
          static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
              at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t,     \
 | 
			
		||||
              BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
          scaledgemm(¶ms);                                        \
 | 
			
		||||
        }                                                             \
 | 
			
		||||
      }                                                               \
 | 
			
		||||
      else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) {     \
 | 
			
		||||
        if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \
 | 
			
		||||
          static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
              at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t,         \
 | 
			
		||||
              BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
          scaledgemm(¶ms);                                        \
 | 
			
		||||
        }                                                             \
 | 
			
		||||
        else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \
 | 
			
		||||
          static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
              at::Float8_e4m3fn, at::Float8_e5m2, scalar_t,           \
 | 
			
		||||
              BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
          scaledgemm(¶ms);                                        \
 | 
			
		||||
        }                                                             \
 | 
			
		||||
      }                                                               \
 | 
			
		||||
      else if (mat1.scalar_type() == ScalarType::Float8_e5m2) {       \
 | 
			
		||||
        if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \
 | 
			
		||||
          static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
              at::Float8_e5m2, at::Float8_e4m3fn, scalar_t,           \
 | 
			
		||||
              BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
          scaledgemm(¶ms);                                        \
 | 
			
		||||
        }                                                             \
 | 
			
		||||
        else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \
 | 
			
		||||
          static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
              at::Float8_e5m2, at::Float8_e5m2, scalar_t,             \
 | 
			
		||||
              BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
          scaledgemm(¶ms);                                        \
 | 
			
		||||
        }                                                             \
 | 
			
		||||
      }
 | 
			
		||||
  AT_DISPATCH_V2(out_dtype, "_tunable_scaled_gemm", AT_WRAP([&] {
 | 
			
		||||
    bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
 | 
			
		||||
    bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
 | 
			
		||||
    at::cuda::tunable::ScaledGemmParams<scalar_t> params;
 | 
			
		||||
    params.transa = args.transa;
 | 
			
		||||
    params.transb = args.transb;
 | 
			
		||||
    params.m = args.m;
 | 
			
		||||
    params.n = args.n;
 | 
			
		||||
    params.k = args.k;
 | 
			
		||||
    params.a = args.mata->data_ptr();
 | 
			
		||||
    params.a_scale_ptr = args.scale_mata_ptr;
 | 
			
		||||
    params.a_scale_dtype = args.scale_mata_dtype.value();
 | 
			
		||||
    params.lda = args.lda;
 | 
			
		||||
    params.a_dtype = args.mata->scalar_type();
 | 
			
		||||
    params.a_scale_dtype = args.scale_mata_dtype.value();
 | 
			
		||||
    params.a_scaling_type = args.scaling_mata_type.value();
 | 
			
		||||
    params.b = args.matb->data_ptr();
 | 
			
		||||
    params.b_scale_ptr = args.scale_matb_ptr;
 | 
			
		||||
    params.b_scale_dtype = args.scale_matb_dtype.value();
 | 
			
		||||
    params.ldb = args.ldb;
 | 
			
		||||
    params.b_dtype = args.matb->scalar_type();
 | 
			
		||||
    params.b_scale_dtype = args.scale_matb_dtype.value();
 | 
			
		||||
    params.b_scaling_type = args.scaling_matb_type.value();
 | 
			
		||||
    params.bias_ptr = bias ? bias->data_ptr(): nullptr;
 | 
			
		||||
    params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype) ? at::ScalarType::Half : out_dtype;
 | 
			
		||||
    params.c = args.result->data_ptr();
 | 
			
		||||
    params.c_scale_ptr = args.scale_result_ptr;
 | 
			
		||||
    params.ldc = args.result_ld;
 | 
			
		||||
    params.c_dtype = out_dtype;
 | 
			
		||||
    params.use_fast_accum = use_fast_accum;
 | 
			
		||||
    if (transa_ && transb_) {
 | 
			
		||||
      TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
 | 
			
		||||
    }
 | 
			
		||||
    else if (transa_ && !transb_) {
 | 
			
		||||
      TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
 | 
			
		||||
    }
 | 
			
		||||
    else if (!transa_ && transb_) {
 | 
			
		||||
      TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
 | 
			
		||||
    }
 | 
			
		||||
    else if (!transa_ && !transb_) {
 | 
			
		||||
      TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
 | 
			
		||||
    }
 | 
			
		||||
    else {
 | 
			
		||||
      TORCH_CHECK(false, "unreachable");
 | 
			
		||||
    }
 | 
			
		||||
  }),
 | 
			
		||||
  kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
 | 
			
		||||
#undef TUNABLE_DISPATCH
 | 
			
		||||
  return out;
 | 
			
		||||
#else
 | 
			
		||||
  TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_gemm_rocm only callable on ROCM devices");
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor&
 | 
			
		||||
_scaled_gemm(
 | 
			
		||||
          const Tensor& mat1, const Tensor& mat2,
 | 
			
		||||
          const Tensor& scale_a, const Tensor& scale_b,
 | 
			
		||||
          const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
 | 
			
		||||
          const std::optional<Tensor>& bias,
 | 
			
		||||
          const bool use_fast_accum,
 | 
			
		||||
          Tensor& out) {
 | 
			
		||||
  cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
 | 
			
		||||
  const auto out_dtype_ = args.result->scalar_type();
 | 
			
		||||
  TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
 | 
			
		||||
 | 
			
		||||
// ROCM enables the TunableOp path only
 | 
			
		||||
// but can fallback to at::cuda::blas::scaled_gemm
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
 | 
			
		||||
  bool tunable_op_enabled = tuning_ctx->IsTunableOpEnabled();
 | 
			
		||||
#else
 | 
			
		||||
  bool tunable_op_enabled = false;
 | 
			
		||||
#endif
 | 
			
		||||
  if (tunable_op_enabled) {
 | 
			
		||||
      // Only available on ROCM
 | 
			
		||||
      return _tunable_scaled_gemm_rocm(
 | 
			
		||||
          args,
 | 
			
		||||
          mat1, mat2,
 | 
			
		||||
          scale_a, scale_b,
 | 
			
		||||
          scaling_choice_a, scaling_choice_b,
 | 
			
		||||
          bias,
 | 
			
		||||
          use_fast_accum,
 | 
			
		||||
          out_dtype_,
 | 
			
		||||
          out);
 | 
			
		||||
  }
 | 
			
		||||
  else
 | 
			
		||||
  {
 | 
			
		||||
      at::cuda::blas::scaled_gemm(
 | 
			
		||||
          args.transa,
 | 
			
		||||
          args.transb,
 | 
			
		||||
          args.m,
 | 
			
		||||
          args.n,
 | 
			
		||||
          args.k,
 | 
			
		||||
          args.mata->data_ptr(),
 | 
			
		||||
          args.scale_mata_ptr,
 | 
			
		||||
          args.lda,
 | 
			
		||||
          args.mata->scalar_type(),
 | 
			
		||||
          args.scale_mata_dtype.value(),
 | 
			
		||||
          args.scaling_mata_type.value(),
 | 
			
		||||
          args.matb->data_ptr(),
 | 
			
		||||
          args.scale_matb_ptr,
 | 
			
		||||
          args.ldb,
 | 
			
		||||
          args.matb->scalar_type(),
 | 
			
		||||
          args.scale_matb_dtype.value(),
 | 
			
		||||
          args.scaling_matb_type.value(),
 | 
			
		||||
          bias ? bias->data_ptr(): nullptr,
 | 
			
		||||
          bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
 | 
			
		||||
          args.result->data_ptr(),
 | 
			
		||||
          args.scale_result_ptr,
 | 
			
		||||
          args.result_ld,
 | 
			
		||||
          out_dtype_,
 | 
			
		||||
          use_fast_accum);
 | 
			
		||||
      return out;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
// NOTE(slayton58): This is defined as part of the _v2 code (way) below - declare the signature here
 | 
			
		||||
//                  to help cleanup v1 call structure.
 | 
			
		||||
Tensor&
 | 
			
		||||
_scaled_rowwise_rowwise(
 | 
			
		||||
          const Tensor&, const Tensor&,
 | 
			
		||||
          const Tensor&, const Tensor&,
 | 
			
		||||
          const std::optional<Tensor>&,
 | 
			
		||||
          const c10::ScalarType,
 | 
			
		||||
          bool,
 | 
			
		||||
          Tensor&);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// Computes matrix multiply + bias while applying scaling to input and output matrices
 | 
			
		||||
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
 | 
			
		||||
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
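
For orientation, the user-facing entry point for this path is the private `torch._scaled_mm`; its exact signature has changed across releases, so the sketch below (tensor-wise scalar scales, bf16 output) is a hedged illustration rather than a stable recipe. Note the row-major × column-major layout requirement mentioned above, hence the transpose:

```python
# Hedged sketch: fp8 scaled matmul through the private torch._scaled_mm API.
import torch

x = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn)   # row-major
w = torch.randn(32, 128, device="cuda").to(torch.float8_e4m3fn)
scale_x = torch.tensor(1.0, device="cuda")
scale_w = torch.tensor(1.0, device="cuda")

# w.t() is column-major, satisfying the cuBLASLt layout requirement.
out = torch._scaled_mm(x, w.t(), scale_a=scale_x, scale_b=scale_w,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([64, 32])
```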
 | 
			
		||||
@ -1273,6 +1470,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 | 
			
		||||
  // by decreasing priority. We prefer "simpler" schemes as they are supported
 | 
			
		||||
  // more broadly (more GPU archs, more CUDA versions) and because they are more
 | 
			
		||||
  // efficient. This tends to matter only for small matmuls (e.g., 1x1x128).
 | 
			
		||||
 | 
			
		||||
  // List of supported BlockWise pairs for FP8:
 | 
			
		||||
  // https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
 | 
			
		||||
 | 
			
		||||
  auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
 | 
			
		||||
    {
 | 
			
		||||
      std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
 | 
			
		||||
@ -1305,7 +1506,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 | 
			
		||||
  TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type());
 | 
			
		||||
#ifndef USE_ROCM
 | 
			
		||||
  // Type restrictions imposed by CuBLASLt as of CUDA-12.1
 | 
			
		||||
  TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
 | 
			
		||||
  TORCH_CHECK_VALUE(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
 | 
			
		||||
        "Multiplication of two Float8_e5m2 matrices is not supported");
 | 
			
		||||
#endif
 | 
			
		||||
  if (use_fast_accum) {
 | 
			
		||||
@ -1371,41 +1572,44 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 | 
			
		||||
 | 
			
		||||
  // NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
 | 
			
		||||
  // and only for compute capability 9.0+. In other cases we use CUTLASS.
 | 
			
		||||
#ifndef USE_ROCM
 | 
			
		||||
  // We are doing row-wise scaling
 | 
			
		||||
  auto dprops = at::cuda::getCurrentDeviceProperties();
 | 
			
		||||
  if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
 | 
			
		||||
      && ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
 | 
			
		||||
      // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
 | 
			
		||||
      ||  (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) {
 | 
			
		||||
    TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
 | 
			
		||||
    at::cuda::detail::f8f8bf16_rowwise(
 | 
			
		||||
        mat1,
 | 
			
		||||
        mat2,
 | 
			
		||||
        scale_a,
 | 
			
		||||
        scale_b,
 | 
			
		||||
        bias,
 | 
			
		||||
        use_fast_accum,
 | 
			
		||||
        out);
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
 | 
			
		||||
#ifndef USE_ROCM
 | 
			
		||||
    auto dprops = at::cuda::getCurrentDeviceProperties();
 | 
			
		||||
    if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
 | 
			
		||||
        // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
 | 
			
		||||
        ||  (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) {
 | 
			
		||||
      TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
 | 
			
		||||
      return _scaled_rowwise_rowwise(
 | 
			
		||||
          mat1,
 | 
			
		||||
          mat2,
 | 
			
		||||
          scale_a,
 | 
			
		||||
          scale_b,
 | 
			
		||||
          bias,
 | 
			
		||||
          out.scalar_type(),
 | 
			
		||||
          use_fast_accum,
 | 
			
		||||
          out);
 | 
			
		||||
    }
 | 
			
		||||
#else
 | 
			
		||||
    // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
 | 
			
		||||
    Tensor b = mat2;
 | 
			
		||||
    if (_scaled_mm_is_fnuz()) {
 | 
			
		||||
      TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
 | 
			
		||||
      TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fnuz,
 | 
			
		||||
          "Expected b.dtype() == at::kFloat8_e4m3fnuz, got: ", b.dtype());
 | 
			
		||||
    }
 | 
			
		||||
    else {
 | 
			
		||||
      TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
 | 
			
		||||
      TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fn,
 | 
			
		||||
          "Expected b.dtype() == at::kFloat8_e4m3fn, got: ", b.dtype());
 | 
			
		||||
    }
 | 
			
		||||
    // Until more than bf16 is supported.
 | 
			
		||||
    TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
 | 
			
		||||
    TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16,
 | 
			
		||||
         "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
  else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    #if ROCM_VERSION >= 70000
 | 
			
		||||
    TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
 | 
			
		||||
    TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
 | 
			
		||||
                "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
 | 
			
		||||
 | 
			
		||||
    int packed_factor = 1;
 | 
			
		||||
@ -1414,163 +1618,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 | 
			
		||||
      // effectively packing two elements into one byte.
 | 
			
		||||
      packed_factor = 2;
 | 
			
		||||
    }
 | 
			
		||||
    TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
 | 
			
		||||
    TORCH_CHECK_VALUE(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
 | 
			
		||||
                mat2.size(1) % 16 == 0,
 | 
			
		||||
                "M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling");
 | 
			
		||||
 | 
			
		||||
    TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
 | 
			
		||||
    TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
 | 
			
		||||
                out.scalar_type() == ScalarType::Half,
 | 
			
		||||
                "Block-wise scaling only supports BFloat16 or Half output types");
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
 | 
			
		||||
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
 | 
			
		||||
#endif
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
 | 
			
		||||
  const auto out_dtype_ = args.result->scalar_type();
 | 
			
		||||
  TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
 | 
			
		||||
  if (tuning_ctx->IsTunableOpEnabled()) {
 | 
			
		||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B)                            \
 | 
			
		||||
        if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) {        \
 | 
			
		||||
          if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
                at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t,     \
 | 
			
		||||
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
            scaledgemm(¶ms);                                        \
 | 
			
		||||
          }                                                             \
 | 
			
		||||
          else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
                at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t,     \
 | 
			
		||||
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
            scaledgemm(¶ms);                                        \
 | 
			
		||||
          }                                                             \
 | 
			
		||||
        }                                                               \
 | 
			
		||||
        else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) {   \
 | 
			
		||||
          if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
                at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t,     \
 | 
			
		||||
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
            scaledgemm(¶ms);                                        \
 | 
			
		||||
          }                                                             \
 | 
			
		||||
          else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
                at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t,     \
 | 
			
		||||
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
            scaledgemm(¶ms);                                        \
 | 
			
		||||
          }                                                             \
 | 
			
		||||
        }                                                               \
 | 
			
		||||
        else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) {     \
 | 
			
		||||
          if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
                at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t,         \
 | 
			
		||||
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
            scaledgemm(¶ms);                                        \
 | 
			
		||||
          }                                                             \
 | 
			
		||||
          else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
                at::Float8_e4m3fn, at::Float8_e5m2, scalar_t,           \
 | 
			
		||||
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
            scaledgemm(¶ms);                                        \
 | 
			
		||||
          }                                                             \
 | 
			
		||||
        }                                                               \
 | 
			
		||||
        else if (mat1.scalar_type() == ScalarType::Float8_e5m2) {       \
 | 
			
		||||
          if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
                at::Float8_e5m2, at::Float8_e4m3fn, scalar_t,           \
 | 
			
		||||
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
            scaledgemm(¶ms);                                        \
 | 
			
		||||
          }                                                             \
 | 
			
		||||
          else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
 | 
			
		||||
                at::Float8_e5m2, at::Float8_e5m2, scalar_t,             \
 | 
			
		||||
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
 | 
			
		||||
            scaledgemm(¶ms);                                        \
 | 
			
		||||
          }                                                             \
 | 
			
		||||
        }
 | 
			
		||||
    AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
 | 
			
		||||
      bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
 | 
			
		||||
      bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
 | 
			
		||||
      at::cuda::tunable::ScaledGemmParams<scalar_t> params;
 | 
			
		||||
      params.transa = args.transa;
 | 
			
		||||
      params.transb = args.transb;
 | 
			
		||||
      params.m = args.m;
 | 
			
		||||
      params.n = args.n;
 | 
			
		||||
      params.k = args.k;
 | 
			
		||||
      params.a = args.mata->data_ptr();
 | 
			
		||||
      params.a_scale_ptr = args.scale_mata_ptr;
 | 
			
		||||
      params.a_scale_dtype = args.scale_mata_dtype.value();
 | 
			
		||||
      params.lda = args.lda;
 | 
			
		||||
      params.a_dtype = args.mata->scalar_type();
 | 
			
		||||
      params.a_scale_dtype = args.scale_mata_dtype.value();
 | 
			
		||||
      params.a_scaling_type = args.scaling_mata_type.value();
 | 
			
		||||
      params.b = args.matb->data_ptr();
 | 
			
		||||
      params.b_scale_ptr = args.scale_matb_ptr;
 | 
			
		||||
      params.b_scale_dtype = args.scale_matb_dtype.value();
 | 
			
		||||
      params.ldb = args.ldb;
 | 
			
		||||
      params.b_dtype = args.matb->scalar_type();
 | 
			
		||||
      params.b_scale_dtype = args.scale_matb_dtype.value();
 | 
			
		||||
      params.b_scaling_type = args.scaling_matb_type.value();
 | 
			
		||||
      params.bias_ptr = bias ? bias->data_ptr(): nullptr;
 | 
			
		||||
      params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
 | 
			
		||||
      params.c = args.result->data_ptr();
 | 
			
		||||
      params.c_scale_ptr = args.scale_result_ptr;
 | 
			
		||||
      params.ldc = args.result_ld;
 | 
			
		||||
      params.c_dtype = out_dtype_;
 | 
			
		||||
      params.use_fast_accum = use_fast_accum;
 | 
			
		||||
      if (transa_ && transb_) {
 | 
			
		||||
        TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
 | 
			
		||||
      }
 | 
			
		||||
      else if (transa_ && !transb_) {
 | 
			
		||||
        TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
 | 
			
		||||
      }
 | 
			
		||||
      else if (!transa_ && transb_) {
 | 
			
		||||
        TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
 | 
			
		||||
      }
 | 
			
		||||
      else if (!transa_ && !transb_) {
 | 
			
		||||
        TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
 | 
			
		||||
      }
 | 
			
		||||
      else {
 | 
			
		||||
        TORCH_CHECK(false, "unreachable");
 | 
			
		||||
      }
 | 
			
		||||
    }),
 | 
			
		||||
    kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
 | 
			
		||||
#undef TUNABLE_DISPATCH
 | 
			
		||||
  }
 | 
			
		||||
  else
 | 
			
		||||
#endif
 | 
			
		||||
 {
 | 
			
		||||
    at::cuda::blas::scaled_gemm(
 | 
			
		||||
        args.transa,
 | 
			
		||||
        args.transb,
 | 
			
		||||
        args.m,
 | 
			
		||||
        args.n,
 | 
			
		||||
        args.k,
 | 
			
		||||
        args.mata->data_ptr(),
 | 
			
		||||
        args.scale_mata_ptr,
 | 
			
		||||
        args.lda,
 | 
			
		||||
        args.mata->scalar_type(),
 | 
			
		||||
        args.scale_mata_dtype.value(),
 | 
			
		||||
        args.scaling_mata_type.value(),
 | 
			
		||||
        args.matb->data_ptr(),
 | 
			
		||||
        args.scale_matb_ptr,
 | 
			
		||||
        args.ldb,
 | 
			
		||||
        args.matb->scalar_type(),
 | 
			
		||||
        args.scale_matb_dtype.value(),
 | 
			
		||||
        args.scaling_matb_type.value(),
 | 
			
		||||
        bias ? bias->data_ptr(): nullptr,
 | 
			
		||||
        bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
 | 
			
		||||
        args.result->data_ptr(),
 | 
			
		||||
        args.scale_result_ptr,
 | 
			
		||||
        args.result_ld,
 | 
			
		||||
        out_dtype_,
 | 
			
		||||
        use_fast_accum);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return out;
 | 
			
		||||
  return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
@ -1910,159 +1971,6 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
 | 
			
		||||
  { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
 | 
			
		||||
  { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
 | 
			
		||||
 | 
			
		||||
Tensor&
 | 
			
		||||
_cutlass_scaled_gemm(
 | 
			
		||||
          const Tensor& mat1, const Tensor& mat2,
 | 
			
		||||
          const Tensor& scale_a, const Tensor& scale_b,
 | 
			
		||||
          const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
 | 
			
		||||
          const std::optional<Tensor>& bias,
 | 
			
		||||
          const bool use_fast_accum,
 | 
			
		||||
          Tensor& out) {
 | 
			
		||||
  cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
 | 
			
		||||
  const auto out_dtype_ = args.result->scalar_type();
 | 
			
		||||
  TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
 | 
			
		||||
  if (tuning_ctx->IsTunableOpEnabled()) {
 | 
			
		||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B)                            \
 | 
			
		||||
        if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) {        \
 | 
			
		||||
          if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \
 | 
			
		||||
            static at::cuda::tunable::ScaledGemmTunableOp<              \
                at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t,     \
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
            scaledgemm(&params);                                        \
          }                                                             \
          else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
            static at::cuda::tunable::ScaledGemmTunableOp<              \
                at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t,     \
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
            scaledgemm(&params);                                        \
          }                                                             \
        }                                                               \
        else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) {   \
          if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \
            static at::cuda::tunable::ScaledGemmTunableOp<              \
                at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t,     \
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
            scaledgemm(&params);                                        \
          }                                                             \
          else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
            static at::cuda::tunable::ScaledGemmTunableOp<              \
                at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t,     \
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
            scaledgemm(&params);                                        \
          }                                                             \
        }                                                               \
        else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) {     \
          if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \
            static at::cuda::tunable::ScaledGemmTunableOp<              \
                at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t,         \
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
            scaledgemm(&params);                                        \
          }                                                             \
          else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \
            static at::cuda::tunable::ScaledGemmTunableOp<              \
                at::Float8_e4m3fn, at::Float8_e5m2, scalar_t,           \
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
            scaledgemm(&params);                                        \
          }                                                             \
        }                                                               \
        else if (mat1.scalar_type() == ScalarType::Float8_e5m2) {       \
          if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \
            static at::cuda::tunable::ScaledGemmTunableOp<              \
                at::Float8_e5m2, at::Float8_e4m3fn, scalar_t,           \
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
            scaledgemm(&params);                                        \
          }                                                             \
          else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \
            static at::cuda::tunable::ScaledGemmTunableOp<              \
                at::Float8_e5m2, at::Float8_e5m2, scalar_t,             \
                BLASOP_A, BLASOP_B> scaledgemm{};                       \
            scaledgemm(&params);                                        \
          }                                                             \
        }
    AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
      bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
      bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
      at::cuda::tunable::ScaledGemmParams<scalar_t> params;
      params.transa = args.transa;
      params.transb = args.transb;
      params.m = args.m;
      params.n = args.n;
      params.k = args.k;
      params.a = args.mata->data_ptr();
      params.a_scale_ptr = args.scale_mata_ptr;
      params.a_scale_dtype = args.scale_mata_dtype.value();
      params.lda = args.lda;
      params.a_dtype = args.mata->scalar_type();
      params.a_scale_dtype = args.scale_mata_dtype.value();
      params.a_scaling_type = args.scaling_mata_type.value();
      params.b = args.matb->data_ptr();
      params.b_scale_ptr = args.scale_matb_ptr;
      params.b_scale_dtype = args.scale_matb_dtype.value();
      params.ldb = args.ldb;
      params.b_dtype = args.matb->scalar_type();
      params.b_scale_dtype = args.scale_matb_dtype.value();
      params.b_scaling_type = args.scaling_matb_type.value();
      params.bias_ptr = bias ? bias->data_ptr() : nullptr;
      params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
      params.c = args.result->data_ptr();
      params.c_scale_ptr = args.scale_result_ptr;
      params.ldc = args.result_ld;
      params.c_dtype = out_dtype_;
      params.use_fast_accum = use_fast_accum;
      if (transa_ && transb_) {
        TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
      }
      else if (transa_ && !transb_) {
        TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
      }
      else if (!transa_ && transb_) {
        TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
      }
      else if (!transa_ && !transb_) {
        TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
      }
      else {
        TORCH_CHECK(false, "unreachable");
      }
    }),
    kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
  }
  else
#endif
  {
    at::cuda::blas::scaled_gemm(
        args.transa,
        args.transb,
        args.m,
        args.n,
        args.k,
        args.mata->data_ptr(),
        args.scale_mata_ptr,
        args.lda,
        args.mata->scalar_type(),
        args.scale_mata_dtype.value(),
        args.scaling_mata_type.value(),
        args.matb->data_ptr(),
        args.scale_matb_ptr,
        args.ldb,
        args.matb->scalar_type(),
        args.scale_matb_dtype.value(),
        args.scaling_matb_type.value(),
        bias ? bias->data_ptr() : nullptr,
        bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
        args.result->data_ptr(),
        args.scale_result_ptr,
        args.result_ld,
        out_dtype_,
        use_fast_accum);
  }
  return out;
}

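The nested dtype branches above are there so that every (mat1 dtype, mat2 dtype, scalar_t, BLASOP_A, BLASOP_B) combination instantiates its own function-local static ScaledGemmTunableOp, i.e. each concrete instantiation keeps its own tuning state. A minimal sketch of that pattern, using hypothetical toy names rather than the real ATen types:

// Sketch only: per-instantiation static dispatch, analogous to the macro above.
// ToyTunableOp, DType and the tag strings are invented for illustration.
#include <cstdio>

enum class DType { E4M3, E5M2 };

template <typename AType, typename BType>
struct ToyTunableOp {
  // A real TunableOp would cache the best kernel it has found; the static
  // storage below means that cache is keyed by the template arguments.
  void operator()(const char* tag) { std::printf("running %s\n", tag); }
};

struct E4M3Tag {};
struct E5M2Tag {};

void dispatch(DType a, DType b) {
  if (a == DType::E4M3 && b == DType::E4M3) {
    static ToyTunableOp<E4M3Tag, E4M3Tag> op; op("e4m3 x e4m3");
  } else if (a == DType::E4M3 && b == DType::E5M2) {
    static ToyTunableOp<E4M3Tag, E5M2Tag> op; op("e4m3 x e5m2");
  } else if (a == DType::E5M2 && b == DType::E4M3) {
    static ToyTunableOp<E5M2Tag, E4M3Tag> op; op("e5m2 x e4m3");
  } else {
    static ToyTunableOp<E5M2Tag, E5M2Tag> op; op("e5m2 x e5m2");
  }
}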
Tensor&
_scaled_tensorwise_tensorwise(
          const Tensor& mat_a, const Tensor& mat_b,
@@ -2082,7 +1990,7 @@ _scaled_tensorwise_tensorwise(
  auto scaling_choice_a = ScalingType::TensorWise;
  auto scaling_choice_b = ScalingType::TensorWise;

  _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
  _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

  return out;
}
@@ -2118,7 +2026,7 @@ _scaled_rowwise_rowwise(
  if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
      // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
      ||  (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
    TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
    TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
    at::cuda::detail::f8f8bf16_rowwise(
        mat_a,
        mat_b,
@@ -2144,11 +2052,38 @@ _scaled_rowwise_rowwise(
       "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
#endif

  _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
  _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

  return out;
}

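Tensor-wise and row-wise scaling share the same _scaled_gemm call; only the shapes of scale_a/scale_b differ (one scalar per operand vs. one scale per row of mat_a and per column of mat_b). A hedged sketch of how such scales could be produced ahead of the call - the amax-based recipe and the e4m3 max value here are assumptions for illustration, not code from this file:

// Illustration only: building FP8 tensor-wise vs. row-wise scales with ATen.
#include <ATen/ATen.h>

constexpr float kE4M3Max = 448.0f;  // assumed representable max for e4m3

// one scale for the whole tensor -> 0-dim float tensor
at::Tensor tensorwise_scale(const at::Tensor& t) {
  return (t.abs().amax() / kE4M3Max).clamp_min(1e-12).to(at::kFloat);
}

// one scale per row of an [M, K] operand -> shape [M, 1]
at::Tensor rowwise_scale(const at::Tensor& t) {
  return (t.abs().amax({-1}, /*keepdim=*/true) / kE4M3Max)
      .clamp_min(1e-12)
      .to(at::kFloat);
}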
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
  if (scale_type == ScalingType::BlockWise1x128) {
    TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
        "at dim=0 scale should have ", t.size(0), " elements and stride(0) ", 1, " if ", t.size(0), " > 1 - Got: ",
        "shape=", scale.sizes(), ", stride=", scale.strides());
    auto expected_size = ceil_div<int64_t>(t.size(1), 128);
    TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
        "at dim=1 scale should have ", expected_size, " elements and stride ", t.size(0), " if ", expected_size, " > 1 - Got: ",
        "shape=", scale.sizes(), ", stride=", scale.strides());
  } else if (scale_type == ScalingType::BlockWise128x128) {
      TORCH_CHECK_VALUE(check_size_stride(
          scale,
          0,
          ceil_div<int64_t>(t.size(0), 128),
          ceil_div<int64_t>(t.size(1), 128)),
        "at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), " elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), " if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
        "shape=", scale.sizes(), ", stride=", scale.strides());
      TORCH_CHECK(check_size_stride(
          scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
        "at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), " elements and stride(1) ", 1, " if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
        "shape=", scale.sizes(), ", stride=", scale.strides());
  }
}

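For the deepseek-style recipes the helper above validates, the expected scale shapes follow directly from ceil_div over 128-wide blocks. A worked example with assumed sizes (M=256, K=512, N=1024; not taken from this file):

// Worked shape example for the 1x128 / 128x128 recipes (illustration only).
#include <cstdint>
#include <cstdio>

static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t M = 256, K = 512, N = 1024;  // assumed problem sizes
  // BlockWise1x128 on mat_a [M, K]: one scale per row per 128-wide K block.
  std::printf("scale_a: %lld x %lld\n", (long long)M, (long long)ceil_div(K, 128));                   // 256 x 4
  // BlockWise128x128 on mat_b [K, N]: one scale per 128x128 tile.
  std::printf("scale_b: %lld x %lld\n", (long long)ceil_div(K, 128), (long long)ceil_div(N, 128));    // 4 x 8
  return 0;
}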
Tensor&
_scaled_block1x128_block1x128(
          const Tensor& mat_a, const Tensor& mat_b,
@@ -2166,15 +2101,14 @@ _scaled_block1x128_block1x128(
  TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
      "scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())

  TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
  TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
  TORCH_CHECK(scale_b.stride(0) == scale_b.size(1),
      "expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));

  auto scaling_choice_a = ScalingType::BlockWise1x128;
  auto scaling_choice_b = ScalingType::BlockWise1x128;

  _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
  // Check scale strides (including stride=1 small cases)
  _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
  _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);

  _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

  return out;
}
@@ -2189,6 +2123,8 @@ _scaled_block128x128_block1x128(
          Tensor& out) {
  // Restrictions:
  // A, B are FP8, scales are fp32, shape K//128
  std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
  std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
  TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
      mat_a.scalar_type(), mat_b.scalar_type());
  TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
@@ -2196,15 +2132,14 @@ _scaled_block128x128_block1x128(
  TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
      "scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())

  TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
  TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
  TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1),
      "expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));

  auto scaling_choice_a = ScalingType::BlockWise128x128;
  auto scaling_choice_b = ScalingType::BlockWise1x128;

  _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
  // Check scale strides (including stride=1 small cases)
  _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
  _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);

  _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

  return out;
}
@@ -2226,15 +2161,14 @@ _scaled_block1x128_block128x128(
  TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
      "scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())

  TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
  TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0));
  TORCH_CHECK_VALUE(scale_b.stride(1) == scale_b.size(0),
      "expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1));

  auto scaling_choice_a = ScalingType::BlockWise1x128;
  auto scaling_choice_b = ScalingType::BlockWise128x128;

  _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
  // Check scale strides (including stride=1 small cases)
  _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
  _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);

  _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

  return out;
}
@@ -2288,7 +2222,7 @@ _scaled_mxfp8_mxfp8(
#endif
#endif

  return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}

Tensor&
@@ -2325,7 +2259,7 @@ _scaled_nvfp4_nvfp4(

  auto scaling_choice_a = ScalingType::BlockWise1x16;
  auto scaling_choice_b = ScalingType::BlockWise1x16;
  return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}

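For the BlockWise1x16 recipe the NVFP4 path selects above, the same arithmetic applies with 16-element blocks along K (sizes here are assumed purely for illustration, and any swizzle/layout requirements the kernels impose on the scales are ignored): a 128 x 64 mat_a needs a scale of shape 128 x ceil(64 / 16) = 128 x 4, and a 64 x 256 mat_b needs ceil(64 / 16) x 256 = 4 x 256 scale entries.
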
@ -2574,7 +2508,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
 | 
			
		||||
        const Tensor& mat_a,
 | 
			
		||||
        const Tensor& mat_b,
 | 
			
		||||
        const Tensor& scale_a,
 | 
			
		||||
        const SwizzleType& swizzle_a,
 | 
			
		||||
        const Tensor& scale_b,
 | 
			
		||||
        const SwizzleType& swizzle_b,
 | 
			
		||||
        const std::optional<at::Tensor>& offs,
 | 
			
		||||
        Tensor& out) {
 | 
			
		||||
    const bool a_is_2d = mat_a.dim() == 2;
 | 
			
		||||
@ -2585,6 +2521,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
 | 
			
		||||
    TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
 | 
			
		||||
    TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
 | 
			
		||||
    TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
 | 
			
		||||
    // MXFP8 expects float8_e8m0fnu scales.
 | 
			
		||||
    TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
 | 
			
		||||
        "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
 | 
			
		||||
        "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
 | 
			
		||||
        "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
 | 
			
		||||
    fbgemm_gpu::mx8mx8bf16_grouped_mm(
 | 
			
		||||
@ -2669,6 +2615,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
 | 
			
		||||
      const std::optional<Tensor>& bias,
 | 
			
		||||
      bool use_fast_accum,
 | 
			
		||||
      Tensor& out) {
 | 
			
		||||
  // FP8 per-tensor and per-row scaling expect fp32 scales.
 | 
			
		||||
  TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
 | 
			
		||||
      "For grouped FP8 rowwise, both scales must be float32 tensors");
 | 
			
		||||
#ifndef USE_ROCM
 | 
			
		||||
  return _f8_f8_bf16_rowwise_grouped_mm_cuda(
 | 
			
		||||
      mat_a,
 | 
			
		||||
@ -2768,11 +2717,15 @@ _scaled_grouped_mm_cuda(
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  if (is_mx8mx8bf16) {
 | 
			
		||||
    // Note: Passing implied SwizzleType here, correctness of scale previously checked
 | 
			
		||||
    //       in `check_scale` call
 | 
			
		||||
    return _mx8_mx8_bf16_grouped_mm_fbgemm(
 | 
			
		||||
        mat_a,
 | 
			
		||||
        mat_b,
 | 
			
		||||
        scale_a,
 | 
			
		||||
        SwizzleType::SWIZZLE_32_4_4,
 | 
			
		||||
        scale_b,
 | 
			
		||||
        SwizzleType::SWIZZLE_32_4_4,
 | 
			
		||||
        offs.value(),
 | 
			
		||||
        out);
 | 
			
		||||
  }
 | 
			
		||||
@ -2789,6 +2742,140 @@ _scaled_grouped_mm_cuda(
 | 
			
		||||
      out);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
 | 
			
		||||
  { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
 | 
			
		||||
  { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
 | 
			
		||||
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
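The anonymous-namespace table above pairs a human-readable recipe name with an acceptance predicate and the concrete implementation to run; the v2 entry point walks it and takes the first predicate that accepts the (dtype, recipe, scale) combination. A minimal sketch of that dispatch shape, with hypothetical predicate logic standing in for check_rowwise_recipe/check_mxfp8_recipe:

// Sketch only: acceptance-table dispatch; names and predicates are invented.
#include <array>
#include <functional>
#include <string>
#include <tuple>

enum class Impl { NONE, ROWWISE_ROWWISE, MXFP8_MXFP8 };
using AcceptFn = std::function<bool(int /*recipe_a*/, int /*recipe_b*/)>;

const std::array<std::tuple<std::string, AcceptFn, Impl>, 2> kDispatch = {{
    {"rowwise_rowwise", [](int a, int b) { return a == 1 && b == 1; }, Impl::ROWWISE_ROWWISE},
    {"mxfp8_mxfp8",     [](int a, int b) { return a == 2 && b == 2; }, Impl::MXFP8_MXFP8},
}};

Impl pick(int recipe_a, int recipe_b) {
  for (const auto& [name, accept, impl] : kDispatch) {
    if (accept(recipe_a, recipe_b)) {
      return impl;  // first matching recipe wins
    }
  }
  return Impl::NONE;  // caller then raises "No gemm implementation was found"
}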
 | 
			
		||||
Tensor
 | 
			
		||||
_scaled_grouped_mm_cuda_v2(
 | 
			
		||||
          const Tensor& mat_a, const Tensor& mat_b,
 | 
			
		||||
          ArrayRef<Tensor> scale_a,
 | 
			
		||||
          IntArrayRef scale_recipe_a,
 | 
			
		||||
          IntArrayRef swizzle_a,
 | 
			
		||||
          ArrayRef<Tensor> scale_b,
 | 
			
		||||
          IntArrayRef scale_recipe_b,
 | 
			
		||||
          IntArrayRef swizzle_b,
 | 
			
		||||
          const std::optional<Tensor>& offs,
 | 
			
		||||
          const std::optional<Tensor>& bias,
 | 
			
		||||
          const std::optional<c10::ScalarType> out_dtype,
 | 
			
		||||
          IntArrayRef contraction_dim,
 | 
			
		||||
          bool use_fast_accum) {
 | 
			
		||||
  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
 | 
			
		||||
  TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
 | 
			
		||||
  TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
 | 
			
		||||
  const bool a_is_2d = mat_a.dim() == 2;
 | 
			
		||||
  const bool b_is_2d = mat_b.dim() == 2;
 | 
			
		||||
 | 
			
		||||
  // NOTE(slayton): For sub-1B formats want contraction_dim argument?
 | 
			
		||||
  if (!a_is_2d || !b_is_2d) {
 | 
			
		||||
    if (contraction_dim.size() > 0) {
 | 
			
		||||
      const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
 | 
			
		||||
      TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
 | 
			
		||||
          "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
 | 
			
		||||
          mat_b.size(dim_b));
 | 
			
		||||
      // Note: only (-1, -2) is currently supported
 | 
			
		||||
      TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
 | 
			
		||||
    } else {
 | 
			
		||||
      TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK_VALUE(
 | 
			
		||||
    mat_a.size(-1) % 16 == 0,
 | 
			
		||||
    "Expected trailing dimension of mat_a to be divisible by 16 ",
 | 
			
		||||
    "but got mat1 shape: (",
 | 
			
		||||
    mat_a.sizes(),
 | 
			
		||||
    ").");
 | 
			
		||||
  TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
 | 
			
		||||
    "Expected mat_b shape to be divisible by 16 ",
 | 
			
		||||
    "but got mat_b shape: (",
 | 
			
		||||
    mat_b.sizes(),
 | 
			
		||||
    ").");
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
 | 
			
		||||
  TORCH_CHECK_VALUE(offs.has_value() ==  (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
 | 
			
		||||
 | 
			
		||||
  // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
 | 
			
		||||
  //       for rowwise, no offsets implies 3d-3d and is handled by lower-level
 | 
			
		||||
  //       routines
 | 
			
		||||
  if (offs.has_value()) {
 | 
			
		||||
    TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
 | 
			
		||||
    TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
 | 
			
		||||
  TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
 | 
			
		||||
 | 
			
		||||
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
 | 
			
		||||
 | 
			
		||||
  // Conversion of implicitly-defined enums to explicit
 | 
			
		||||
  auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
 | 
			
		||||
  auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
 | 
			
		||||
  auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
 | 
			
		||||
  auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
 | 
			
		||||
 | 
			
		||||
  // at this point we can start working out what we want to be doing
 | 
			
		||||
  // Try to do as few steps as possible.
 | 
			
		||||
  // NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
 | 
			
		||||
  // Do this via a list of defined (name, acceptance, concrete_impl) tuples.
 | 
			
		||||
  ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
 | 
			
		||||
  for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
 | 
			
		||||
    const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
 | 
			
		||||
    bool ok = accept_fn(mat_a.scalar_type(),
 | 
			
		||||
                        scale_recipe_a_enum,
 | 
			
		||||
                        scale_a,
 | 
			
		||||
                        mat_b.scalar_type(),
 | 
			
		||||
                        scale_recipe_b_enum,
 | 
			
		||||
                        scale_b);
 | 
			
		||||
    if (ok) {
 | 
			
		||||
      gemm_impl = scaled_gemm_impl;
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
 | 
			
		||||
      "No gemm implementation was found");
 | 
			
		||||
 | 
			
		||||
  switch (gemm_impl) {
 | 
			
		||||
    case ScaledGemmImplementation::ROWWISE_ROWWISE: {
 | 
			
		||||
      const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
 | 
			
		||||
      _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
 | 
			
		||||
      _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
 | 
			
		||||
      return _f8_f8_bf16_rowwise_grouped_mm(
 | 
			
		||||
          mat_a,
 | 
			
		||||
          mat_b,
 | 
			
		||||
          scale_a[0],
 | 
			
		||||
          scale_b[0],
 | 
			
		||||
          offs,
 | 
			
		||||
          bias,
 | 
			
		||||
          use_fast_accum,
 | 
			
		||||
          out);
 | 
			
		||||
    }
 | 
			
		||||
    case ScaledGemmImplementation::MXFP8_MXFP8: {
 | 
			
		||||
      _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
 | 
			
		||||
      _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
 | 
			
		||||
      return _mx8_mx8_bf16_grouped_mm_fbgemm(
 | 
			
		||||
          mat_a,
 | 
			
		||||
          mat_b,
 | 
			
		||||
          scale_a[0],
 | 
			
		||||
          swizzle_a_enum[0],
 | 
			
		||||
          scale_b[0],
 | 
			
		||||
          swizzle_b_enum[0],
 | 
			
		||||
          offs.value(),
 | 
			
		||||
          out);
 | 
			
		||||
    }
 | 
			
		||||
    default:
 | 
			
		||||
      TORCH_CHECK_NOT_IMPLEMENTED(false,
 | 
			
		||||
          "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
 | 
			
		||||
const std::optional<at::Tensor>& offs,
 | 
			
		||||
const std::optional<at::Tensor>& bias,
 | 
			
		||||
 | 
			
		||||
@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
      out_calc_t output_offset_calculator,
 | 
			
		||||
      loader_t loader,
 | 
			
		||||
      storer_t storer) {
 | 
			
		||||
    if (ret_t == rt_binary_specializations[arg_index][0] &&
 | 
			
		||||
        arg0_t == rt_binary_specializations[arg_index][1] &&
 | 
			
		||||
        arg1_t == rt_binary_specializations[arg_index][2])
 | 
			
		||||
    constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
 | 
			
		||||
    constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
 | 
			
		||||
    constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
 | 
			
		||||
    if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
 | 
			
		||||
      using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
 | 
			
		||||
      using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
 | 
			
		||||
      using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
 | 
			
		||||
      launch_vectorized_templated_kernel<
 | 
			
		||||
          func_t,
 | 
			
		||||
          array_t,
 | 
			
		||||
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
          out_calc_t,
 | 
			
		||||
          loader_t,
 | 
			
		||||
          storer_t,
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][0]>::t),
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][1]>::t),
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][2]>::t)>(
 | 
			
		||||
          cret_t,
 | 
			
		||||
          carg0_t,
 | 
			
		||||
          carg1_t>(
 | 
			
		||||
          numel,
 | 
			
		||||
          f,
 | 
			
		||||
          data,
 | 
			
		||||
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
          output_offset_calculator,
 | 
			
		||||
          loader,
 | 
			
		||||
          storer);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
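The rewrite above replaces the decltype(c10::impl::ScalarTypeToCPPType<...>::t) spelling with constexpr ScalarType locals and the ScalarTypeToCPPTypeT alias. A small sketch of what that alias gives you at compile time:

// Sketch: constexpr ScalarType -> C++ type mapping used by the launcher above.
#include <c10/core/ScalarType.h>

constexpr c10::ScalarType kSt = c10::ScalarType::Half;
using cpp_t = c10::impl::ScalarTypeToCPPTypeT<kSt>;  // c10::Half
static_assert(sizeof(cpp_t) == 2, "Half maps to a 16-bit C++ type");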
 | 
			
		||||
 | 
			
		||||
@ -655,8 +655,14 @@ struct ReduceOp {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
    // Intra-warp reduction. On CUDA the shuffle offsets now decrease from
    // dim_x/2 down to 1, which gives better numerics and matches Triton;
    // ROCm still uses increasing offsets (TODO: switch AMD as well).
 | 
			
		||||
    #ifdef USE_ROCM
 | 
			
		||||
    for (int offset = 1; offset < dim_x; offset <<= 1) {
 | 
			
		||||
    #else
 | 
			
		||||
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
 | 
			
		||||
    #endif
 | 
			
		||||
      #pragma unroll
 | 
			
		||||
      for (int i = 0; i < output_vec_size; i++) {
 | 
			
		||||
        arg_t other = ops.warp_shfl_down(value[i], offset);
 | 
			
		||||
 | 
			
		||||
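Both loop variants above perform the usual warp-level tree reduction via shuffles; the CUDA branch now walks offsets from dim_x/2 down to 1 while ROCm keeps the ascending order, so the same lane pairs are combined but in a different association (hence the note about numerics). A standalone sketch of the two orders, assuming a full 32-lane warp and a float accumulator:

// Standalone CUDA sketch (assumes all 32 lanes of the warp are active).
__device__ float warp_sum_descending(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) {  // 16, 8, 4, 2, 1
    v += __shfl_down_sync(0xffffffff, v, offset);
  }
  return v;  // lane 0 holds the warp sum
}

__device__ float warp_sum_ascending(float v) {
  for (int offset = 1; offset < 32; offset <<= 1) {  // 1, 2, 4, 8, 16
    v += __shfl_down_sync(0xffffffff, v, offset);
  }
  return v;  // same total on lane 0, summed in a different order
}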
@@ -77,8 +77,8 @@ struct nansum_functor_complex {
#if AT_USE_JITERATOR()
  void operator()(TensorIterator& iter) {
    std::string func = jiterator_stringify(
        arg_t combine(arg_t a, scalar_t b) {
          return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
        arg_t combine(arg_t a, arg_t b) {
          return a + (std::isnan(b) ? arg_t{0.} : b);
        }
    );
    jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(

@ -464,6 +464,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
    int32_t trailingSize;
 | 
			
		||||
    int nDimsLocal = nDims;
 | 
			
		||||
    TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
 | 
			
		||||
    if (isInOutAligned) {
 | 
			
		||||
      // in this case we can and should flatten the tensors after the cat dim
 | 
			
		||||
@ -477,7 +478,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 | 
			
		||||
      // and divide all strides except last by elems_per_vec (last stride is 1 always)
 | 
			
		||||
      // for input, we will fix up the sizes and strides in the kernel directly
 | 
			
		||||
      kernelOutputParam = outputParam;
 | 
			
		||||
      nDims = dimension + 1;
 | 
			
		||||
      nDimsLocal = dimension + 1;
 | 
			
		||||
      constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
 | 
			
		||||
      auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
 | 
			
		||||
      kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
 | 
			
		||||
@ -494,7 +495,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 | 
			
		||||
      case 0:
 | 
			
		||||
        break;
 | 
			
		||||
      case 1:
 | 
			
		||||
        cat_dim = nDims - cat_dim;
 | 
			
		||||
        cat_dim = nDimsLocal - cat_dim;
 | 
			
		||||
        break;
 | 
			
		||||
      default:
 | 
			
		||||
        cat_dim--;
 | 
			
		||||
@ -525,7 +526,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 | 
			
		||||
              data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
 | 
			
		||||
    }\
 | 
			
		||||
    C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
    switch (nDims) {
 | 
			
		||||
    switch (nDimsLocal) {
 | 
			
		||||
      case 1:
 | 
			
		||||
        HANDLE_CASE(1);
 | 
			
		||||
        break;
 | 
			
		||||
 | 
			
		||||
@@ -21,9 +21,15 @@ namespace {
struct offset_t {
  int stride;
  int begin;
  __device__ int operator[](int i) {
  __device__ int operator[](int i) const {
    return stride * (begin + i);
  }
#if CCCL_VERSION >= 3001000
  __device__ offset_t& operator+=(int i) {
    begin += i;
    return *this;
  }
#endif
};
// Segmented sort by full sort algorithm:
// Say we are sorting a (2, 3) tensor. We have in flattened form:

@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
// Helper function to compute output pixel range that can contribute to input pixel
 | 
			
		||||
template <typename accscalar_t>
 | 
			
		||||
__device__ __forceinline__ void compute_output_range(
 | 
			
		||||
    int input_pos,
 | 
			
		||||
    accscalar_t scale,
 | 
			
		||||
    int output_size,
 | 
			
		||||
    bool align_corners,
 | 
			
		||||
    int& min_output,
 | 
			
		||||
    int& max_output) {
 | 
			
		||||
  accscalar_t lo, hi;
 | 
			
		||||
  if (align_corners) {
 | 
			
		||||
      lo = static_cast<accscalar_t>(input_pos - 1) / scale;
 | 
			
		||||
      hi = static_cast<accscalar_t>(input_pos + 1) / scale;
 | 
			
		||||
  } else {
 | 
			
		||||
      lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
 | 
			
		||||
      hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
 | 
			
		||||
  }
 | 
			
		||||
  min_output = max(0, static_cast<int>(ceil(lo)));
 | 
			
		||||
  max_output = min(output_size - 1, static_cast<int>(floor(hi)));
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
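compute_output_range inverts the forward sampling map used by the rest of this kernel. For align_corners == false the forward map is src = (dst + 0.5) * scale - 0.5, and an output pixel dst reads input pixels floor(src) and floor(src) + 1, so a given input_pos can only be touched when src lies in [input_pos - 1, input_pos + 1). Solving dst = (src + 0.5) / scale - 0.5 at those endpoints yields exactly the bounds above:

  lo = (input_pos - 0.5) / scale - 0.5
  hi = (input_pos + 1.5) / scale - 0.5
  min_output = max(0, ceil(lo)),  max_output = min(output_size - 1, floor(hi))

For align_corners == true the same argument with src = dst * scale gives lo = (input_pos - 1) / scale and hi = (input_pos + 1) / scale.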
 | 
			
		||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
 | 
			
		||||
template <typename scalar_t, typename accscalar_t>
 | 
			
		||||
C10_LAUNCH_BOUNDS_1(1024)
 | 
			
		||||
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
 | 
			
		||||
    const bool align_corners,
 | 
			
		||||
    scalar_t* __restrict__ idata,
 | 
			
		||||
    const scalar_t* __restrict__ odata) {
 | 
			
		||||
  const size_t o_numel = nc * width2 * height2;
 | 
			
		||||
  // Total number of input elements; the ROCm path below iterates over inputs rather than outputs.
 | 
			
		||||
  const size_t i_numel = nc * width1 * height1;
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
 | 
			
		||||
       index += blockDim.x * gridDim.x) {
 | 
			
		||||
    // Decode input pixel coordinates
 | 
			
		||||
    size_t index_temp = index;
 | 
			
		||||
    const int w1 = index_temp % width1;
 | 
			
		||||
    index_temp /= width1;
 | 
			
		||||
    const int h1 = index_temp % height1;
 | 
			
		||||
    const size_t nc_idx = index_temp / height1;
 | 
			
		||||
 | 
			
		||||
    accscalar_t grad_sum = 0;
 | 
			
		||||
 | 
			
		||||
    // Find range of output pixels that could interpolate from this input pixel
 | 
			
		||||
    int h2_min, h2_max, w2_min, w2_max;
 | 
			
		||||
    compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
 | 
			
		||||
    compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
 | 
			
		||||
 | 
			
		||||
    // Iterate over potential output pixels
 | 
			
		||||
    for (int h2 = h2_min; h2 <= h2_max; h2++) {
 | 
			
		||||
      for (int w2 = w2_min; w2 <= w2_max; w2++) {
 | 
			
		||||
        // Compute source coordinates for this output pixel
 | 
			
		||||
        const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
 | 
			
		||||
            rheight, h2, align_corners, /*cubic=*/false);
 | 
			
		||||
        const int h1_base = (int)h1r;
 | 
			
		||||
        const int h1p = (h1_base < height1 - 1) ? 1 : 0;
 | 
			
		||||
        const accscalar_t h1lambda = h1r - h1_base;
 | 
			
		||||
        const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
 | 
			
		||||
 | 
			
		||||
        const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
 | 
			
		||||
            rwidth, w2, align_corners, /*cubic=*/false);
 | 
			
		||||
        const int w1_base = (int)w1r;
 | 
			
		||||
        const int w1p = (w1_base < width1 - 1) ? 1 : 0;
 | 
			
		||||
        const accscalar_t w1lambda = w1r - w1_base;
 | 
			
		||||
        const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
 | 
			
		||||
 | 
			
		||||
        // Check if our input pixel participates in this interpolation and accumulate all weights
 | 
			
		||||
        // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
 | 
			
		||||
        // to the same pixel, so we need to accumulate weights from all matching positions
 | 
			
		||||
        accscalar_t weight = 0;
 | 
			
		||||
 | 
			
		||||
        // Check all four interpolation positions and accumulate weights
 | 
			
		||||
        if (h1 == h1_base && w1 == w1_base) {
 | 
			
		||||
          weight += h0lambda * w0lambda;  // top-left
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base && w1 == w1_base + w1p) {
 | 
			
		||||
          weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0)
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base + h1p && w1 == w1_base) {
 | 
			
		||||
          weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0)
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
 | 
			
		||||
          weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions)
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (weight > 0) {
 | 
			
		||||
          const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
 | 
			
		||||
          grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Write accumulated gradient (no atomics needed)
 | 
			
		||||
    idata[index] = static_cast<scalar_t>(grad_sum);
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  const size_t o_numel = nc * width2 * height2;
 | 
			
		||||
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
 | 
			
		||||
       index += blockDim.x * gridDim.x) {
 | 
			
		||||
    size_t index_temp = index;
 | 
			
		||||
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
 | 
			
		||||
        static_cast<scalar_t>(h1lambda * w1lambda * d2val),
 | 
			
		||||
        true);
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, typename accscalar_t>
 | 
			
		||||
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
  // threads are not covering the whole input tensor.
 | 
			
		||||
  grad_input.zero_();
 | 
			
		||||
 | 
			
		||||
  const size_t num_kernels = nbatch * channels * output_height * output_width;
 | 
			
		||||
  const int num_threads = std::min(
 | 
			
		||||
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
 | 
			
		||||
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  constexpr bool use_input = true;
 | 
			
		||||
#else
 | 
			
		||||
  constexpr bool use_input = false;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  AT_DISPATCH_FLOATING_TYPES_AND2(
 | 
			
		||||
      at::ScalarType::Half, at::ScalarType::BFloat16,
 | 
			
		||||
      grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
 | 
			
		||||
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
 | 
			
		||||
          input_width, output_width, align_corners, scales_w);
 | 
			
		||||
 | 
			
		||||
      const size_t num_kernels = nbatch * channels * output_height * output_width;
 | 
			
		||||
 | 
			
		||||
      upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
 | 
			
		||||
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
 | 
			
		||||
              input_height,
 | 
			
		||||
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
 | 
			
		||||
          input_width, output_width, align_corners, scales_w);
 | 
			
		||||
 | 
			
		||||
      const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
 | 
			
		||||
 | 
			
		||||
      upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
 | 
			
		||||
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
 | 
			
		||||
             num_threads,
 | 
			
		||||
 | 
			
		||||
@@ -466,7 +466,11 @@ struct ReduceJitOp {

    __syncthreads();

    #ifdef USE_ROCM
    for (int offset = 1; offset < dim_x; offset <<= 1) {
    #else
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
    #endif
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = reducer::warp_shfl_down(value[i], offset);

@ -3,6 +3,9 @@
 | 
			
		||||
#include <ATen/core/Tensor.h>
 | 
			
		||||
#include <ATen/native/DispatchStub.h>
 | 
			
		||||
#include <c10/util/accumulate.h>
 | 
			
		||||
#include <c10/core/SymBool.h>
 | 
			
		||||
#include <c10/util/StringUtil.h>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
namespace at::native {
 | 
			
		||||
 | 
			
		||||
@ -19,28 +22,30 @@ C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint(
 | 
			
		||||
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
 | 
			
		||||
      "containing at least one element, but got normalized_shape = ",
 | 
			
		||||
      normalized_shape);
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      !weight.defined() || weight.sym_sizes().equals(normalized_shape),
 | 
			
		||||
      "Expected weight to be of same shape as normalized_shape, but got ",
 | 
			
		||||
      "weight of shape ",
 | 
			
		||||
      weight.sym_sizes(),
 | 
			
		||||
      " and normalized_shape = ",
 | 
			
		||||
      normalized_shape);
 | 
			
		||||
  if (weight.defined()) {
 | 
			
		||||
    TORCH_SYM_CHECK(
 | 
			
		||||
        sym_equals(weight.sym_sizes(), normalized_shape),
 | 
			
		||||
        "Expected weight to be of same shape as normalized_shape, but got ",
 | 
			
		||||
        "weight of shape ",
 | 
			
		||||
        weight.sym_sizes(),
 | 
			
		||||
        " and normalized_shape = ",
 | 
			
		||||
        normalized_shape);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const auto input_ndim = input.dim();
 | 
			
		||||
  const auto input_shape = input.sym_sizes();
 | 
			
		||||
  if (input_ndim < normalized_ndim ||
 | 
			
		||||
      !input_shape.slice(input_ndim - normalized_ndim)
 | 
			
		||||
           .equals(normalized_shape)) {
 | 
			
		||||
    std::stringstream ss;
 | 
			
		||||
    ss << "Given normalized_shape=" << normalized_shape
 | 
			
		||||
       << ", expected input with shape [*";
 | 
			
		||||
    for (auto size : normalized_shape) {
 | 
			
		||||
      ss << ", " << size;
 | 
			
		||||
    }
 | 
			
		||||
    ss << "], but got input of size" << input_shape;
 | 
			
		||||
    TORCH_CHECK(false, ss.str());
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK_VALUE(
 | 
			
		||||
      input_ndim >= normalized_ndim,
 | 
			
		||||
      "Input tensor must have at least ", normalized_ndim, " dimensions, but got ", input_ndim);
 | 
			
		||||
 | 
			
		||||
  auto expect_input_shape_msg = c10::str(
 | 
			
		||||
      "Given normalized_shape=", normalized_shape,
 | 
			
		||||
      ", expected input with shape [*", c10::Join(", ", normalized_shape),
 | 
			
		||||
      "], but got input of size", input_shape);
 | 
			
		||||
 | 
			
		||||
  TORCH_SYM_CHECK(
 | 
			
		||||
      sym_equals(input_shape.slice(input_ndim - normalized_ndim), normalized_shape),
 | 
			
		||||
      expect_input_shape_msg);
 | 
			
		||||
}
 | 
			
		||||
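The rewrite above drops the std::stringstream loop in favour of c10::str plus c10::Join, which build the same message eagerly from variadic pieces. A minimal sketch of those two helpers (both live in c10/util/StringUtil.h, included above):

// Sketch: building the same kind of message with c10::str and c10::Join.
#include <c10/util/StringUtil.h>
#include <string>
#include <vector>

std::string example_message() {
  std::vector<int64_t> normalized_shape{3, 4};
  // Produces: "Given normalized_shape=[3, 4], expected input with shape [*, 3, 4]"
  return c10::str("Given normalized_shape=[", c10::Join(", ", normalized_shape),
                  "], expected input with shape [*, ",
                  c10::Join(", ", normalized_shape), "]");
}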
 | 
			
		||||
C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(
 | 
			
		||||
 | 
			
		||||
@@ -512,7 +512,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
}

static MPSGraphTensor* normcdf(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
  // (1.0f + erf(x*SQRT1_2)) * 0.5f * x;
  // (1.0f + erf(x*SQRT1_2)) * 0.5f;
  auto dataType = [inputTensor dataType];
  const float SQRT1_2 = 0.707106781186547524400844362104849039f;
  MPSGraphTensor* sqrt1_2 = [mpsGraph constantWithScalar:SQRT1_2 shape:@[ @1 ] dataType:dataType];

@@ -54,6 +54,10 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
  using namespace mps;
  using CachedGraph = MPSBinaryCachedGraph;

  if (self.numel() == 0 && other.numel() == 0) {
    return zeros({}, self.options());
  }

  dot_check(self, other);

  auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);

@@ -4545,6 +4545,7 @@
- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
  dispatch:
    CPU, CUDA: _cdist_forward
    MTIA: _cdist_forward_mtia
    MPS: _cdist_forward_mps
  autogen: _cdist_forward.out
  tags: core
@@ -7182,6 +7183,12 @@
    CUDA: _scaled_grouped_mm_cuda
  tags: needs_exact_strides

- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
  variants: function
  dispatch:
    CUDA: _scaled_grouped_mm_cuda_v2
  tags: needs_exact_strides

- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
  variants: function
  dispatch:

@ -178,24 +178,30 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_b
 | 
			
		||||
          0 & \text{ else }
 | 
			
		||||
        \end{cases}
 | 
			
		||||
  */
 | 
			
		||||
  auto zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max);
 | 
			
		||||
  bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
 | 
			
		||||
  at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
 | 
			
		||||
  at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
 | 
			
		||||
  at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
 | 
			
		||||
  at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
 | 
			
		||||
  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
 | 
			
		||||
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
 | 
			
		||||
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
 | 
			
		||||
  auto zero_point_rounded = _get_rounded_zero_point(zero_point_, quant_min, quant_max);
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
 | 
			
		||||
  TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
 | 
			
		||||
  TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
 | 
			
		||||
  TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
 | 
			
		||||
  TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(X_.sizes() == dY_.sizes(), "`X` and `dY` are not the same size");
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      quant_min <= 0 && quant_max >= 0,
 | 
			
		||||
      "Expecting `quant_min` <= 0 and `quant_max` >= 0");
 | 
			
		||||
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
 | 
			
		||||
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
 | 
			
		||||
  TORCH_CHECK(scale_.dim() == 1, "scale should be a 1-D tensor");
 | 
			
		||||
  TORCH_CHECK(zero_point_.dim() == 1, "zero point should be a 1-D tensor");
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      scale.numel() == zero_point.numel(),
 | 
			
		||||
      scale_.numel() == zero_point_.numel(),
 | 
			
		||||
      "scale and zero-point need to have the same dimensions");
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      scale.numel() == X.size(axis),
 | 
			
		||||
      scale_.numel() == X_.size(axis),
 | 
			
		||||
      "dimensions of scale and zero-point are not consistent with input tensor")
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
@ -204,42 +210,42 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_b
 | 
			
		||||
      "`zero_point` must be between `quant_min` and `quant_max`.");
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      axis >= 0 && axis < X.dim(),
 | 
			
		||||
      axis >= 0 && axis < X_.dim(),
 | 
			
		||||
      "`axis` must be between 0 and number of dimensions of input");
 | 
			
		||||
 | 
			
		||||
  if (X.numel() <= 0) {
 | 
			
		||||
  if (X_.numel() <= 0) {
 | 
			
		||||
    return std::make_tuple(X, scale, zero_point);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
 | 
			
		||||
  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
 | 
			
		||||
  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
 | 
			
		||||
  auto numDimensions = X.ndimension();
 | 
			
		||||
  auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
 | 
			
		||||
  auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
 | 
			
		||||
  auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
 | 
			
		||||
  auto numDimensions = X_.ndimension();
 | 
			
		||||
 | 
			
		||||
  // Create an axis mask for vectorizing and reshaping the scale and zero point tensors
 | 
			
		||||
  // into the same shapes as X along the channel axis.
 | 
			
		||||
  c10::DimVector axis_mask(numDimensions);
 | 
			
		||||
  for (const auto i : c10::irange(numDimensions)) {
 | 
			
		||||
    axis_mask[i] = (i == axis) ? X.size(axis) : 1;
 | 
			
		||||
    axis_mask[i] = (i == axis) ? X_.size(axis) : 1;
 | 
			
		||||
  }
 | 
			
		||||
  auto X_shape = X.sizes();
 | 
			
		||||
  auto scale_vectorized = scale.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
 | 
			
		||||
  auto X_shape = X_.sizes();
 | 
			
		||||
  auto scale_vectorized = scale_.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
 | 
			
		||||
  auto zero_point_vectorized = zero_point_rounded.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
 | 
			
		||||
 | 
			
		||||
  auto iter = TensorIteratorConfig()
 | 
			
		||||
    .add_output(dX)
 | 
			
		||||
    .add_output(dScale_vec)
 | 
			
		||||
    .add_output(dZeroPoint_vec)
 | 
			
		||||
    .add_input(X)
 | 
			
		||||
    .add_input(dY)
 | 
			
		||||
    .add_input(X_)
 | 
			
		||||
    .add_input(dY_)
 | 
			
		||||
    .add_input(scale_vectorized)
 | 
			
		||||
    .add_input(zero_point_vectorized)
 | 
			
		||||
    .build();
 | 
			
		||||
 | 
			
		||||
  fake_quant_grad_learnable_channel_stub(
 | 
			
		||||
    X.device().type(), iter, quant_min, quant_max, grad_factor);
 | 
			
		||||
    X_.device().type(), iter, quant_min, quant_max, grad_factor);
 | 
			
		||||
 | 
			
		||||
  auto numElements = X.ndimension() - 1;
 | 
			
		||||
  auto numElements = X_.ndimension() - 1;
 | 
			
		||||
 | 
			
		||||
  // Create a collection of axes that include all but the channel axis for
 | 
			
		||||
  // reduction when summing over the dScale and dZeroPoint tensors.
 | 
			
		||||
 | 
			
		||||
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,fail_accuracy,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
vit_base_patch16_siglip_256,pass,7
@ -1060,6 +1060,8 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
            frozen_model_iter_fn = export_nativert(model, example_inputs)
        elif args.torchscript_jit_trace:
            frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs)
        elif args.aot_precompile:
            frozen_model_iter_fn = aot_precompile(model, example_inputs)
        else:
            if kwargs["hf_llm"]:
                # If it's an llm, we want to optimize model.forward, and use

@ -1495,6 +1497,37 @@ def export(model, example_inputs):
    return opt_export


def aot_precompile(model, example_inputs):
    example_args, example_kwargs = _normalize_bench_inputs(example_inputs)

    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        save_path = f.name

    with fresh_cache(), torch._dynamo.config.patch("enable_aot_compile", True):
        compiled_fn = torch.compile(
            model,
            fullgraph=True,
            options={"guard_filter_fn": lambda guards: [False for _ in guards]},
        ).forward.aot_compile((example_args, example_kwargs))

        compiled_fn.save_compiled_function(save_path)

        torch._dynamo.reset()
        with open(save_path, "rb") as f:
            load_start_time = time.perf_counter()
            loaded_fn = torch.compiler.load_compiled_function(f)
            load_end_time = time.perf_counter()
            print(
                f"AOT Precompile loading time: {load_end_time - load_start_time} seconds"
            )

            def opt_aot_precompile(_, example_inputs, collect_outputs=False):
                example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
                return loaded_fn(model, *example_args, **example_kwargs)

            return opt_aot_precompile


def export_nativert(model, example_inputs):
    optimized = NativeRTCache.load(model, example_inputs)

@ -2274,6 +2307,7 @@ class BenchmarkRunner:
                    or self.args.export_aot_inductor
                    or self.args.export_nativert
                    or self.args.torchscript_jit_trace
                    or self.args.aot_precompile
                ):
                    # apply export on module directly
                    # no need for n iterations

@ -2729,6 +2763,7 @@ class BenchmarkRunner:
                self.args.export_aot_inductor
                or self.args.export_nativert
                or self.args.torchscript_jit_trace
                or self.args.aot_precompile
            ):
                optimized_model_iter_fn = optimize_ctx
            else:

@ -3505,6 +3540,11 @@ def parse_args(args=None):
        action="store_true",
        help="Measure pass rate with Export+AOTInductor",
    )
    group.add_argument(
        "--aot-precompile",
        action="store_true",
        help="Measure pass rate with AOT Precompile",
    )
    group.add_argument(
        "--export-nativert",
        action="store_true",

@ -3935,6 +3975,10 @@ def run(runner, args, original_dir=None):
        optimize_ctx = export
        experiment = speedup_experiment
        output_filename = "export.csv"
    elif args.aot_precompile:
        optimize_ctx = aot_precompile
        experiment = speedup_experiment
        output_filename = "aot_precompile.csv"
    elif args.export_nativert:
        optimize_ctx = export_nativert
        experiment = speedup_experiment
@ -271,8 +271,6 @@ class TimmRunner(BenchmarkRunner):
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,

@ -302,7 +300,6 @@ class TimmRunner(BenchmarkRunner):
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

@ -370,11 +367,6 @@ class TimmRunner(BenchmarkRunner):
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upsets accuracy checks.
@ -1,6 +1,8 @@
adv_inception_v3 128
beit_base_patch16_224 128
convnextv2_nano.fcmae_ft_in22k_in1k 128
deit_base_distilled_patch16_224 128
deit_tiny_patch16_224.fb_in1k 128
dm_nfnet_f0 128
ghostnet_100 512
inception_v3 128
@ -12,3 +14,5 @@ repvgg_a2 128
swin_base_patch4_window7_224 128
tf_efficientnet_b0 128
visformer_small 128
vit_base_patch14_dinov2.lvd142m 128
vit_base_patch16_siglip_256 128

@ -1,6 +1,8 @@
adv_inception_v3,128
beit_base_patch16_224,64
convnextv2_nano.fcmae_ft_in22k_in1k,128
deit_base_distilled_patch16_224,64
deit_tiny_patch16_224.fb_in1k,128
dm_nfnet_f0,128
ghostnet_100,128
inception_v3,128
@ -12,3 +14,5 @@ repvgg_a2,128
swin_base_patch4_window7_224,64
tf_efficientnet_b0,128
visformer_small,128
vit_base_patch14_dinov2.lvd142m,128
ViT-B-16-SigLIP-i18n-256,128
@ -28,101 +28,8 @@

namespace c10 {

// [dtype Macros note] For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information, you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
//
// For adding a new dtype: In the beginning, we had an idea that there was a
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them.  But over the years we added weird types which couldn't
// be handled uniformly everywhere and so in the end we ended up with some
// mish-mosh of some helper macros, but mostly use sites making a call about
// what dtypes they can or can't support.  So if you want to add a new dtype,
// the preferred resolution is to find a dtype similar to what you want,
// grep for it and edit all the sites you find this way.  If you need to add
// a completely new kind of dtype, you're going to have to laboriously audit
// all of the sites everywhere to figure out how it should work.  Consulting
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.

// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name).  But beware: convert()
// doesn't work for all the conversions you need...
//
// TODO: To add unsigned int types here, we must define accumulate type.
// But uint8 currently accumulates into int64, so we would have to make
// an inconsistent choice for the larger types.  Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
  _(uint8_t, Byte)                                                      \
  _(int8_t, Char)                                                       \
  _(int16_t, Short)                                                     \
  _(int, Int)                                                           \
  _(int64_t, Long)                                                      \
  _(at::Half, Half)                                                     \
  _(float, Float)                                                       \
  _(double, Double)                                                     \
  _(c10::complex<float>, ComplexFloat)                                  \
  _(c10::complex<double>, ComplexDouble)                                \
  _(bool, Bool)                                                         \
  _(at::BFloat16, BFloat16)                                             \
  _(at::Float8_e5m2, Float8_e5m2)                                       \
  _(at::Float8_e4m3fn, Float8_e4m3fn)

// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
  _(uint8_t, Byte)                             \
  _(int8_t, Char)                              \
  _(int16_t, Short)                            \
  _(int, Int)                                  \
  _(int64_t, Long)                             \
  _(at::Half, Half)                            \
  _(float, Float)                              \
  _(double, Double)                            \
  _(c10::complex<c10::Half>, ComplexHalf)      \
  _(c10::complex<float>, ComplexFloat)         \
  _(c10::complex<double>, ComplexDouble)       \
  _(bool, Bool)                                \
  _(at::BFloat16, BFloat16)                    \
  _(at::Float8_e5m2, Float8_e5m2)              \
  _(at::Float8_e4m3fn, Float8_e4m3fn)          \
  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz)      \
  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz)      \
  _(at::Float8_e8m0fnu, Float8_e8m0fnu)

namespace impl {

// These are used to map ScalarTypes to C++ types.

template <c10::ScalarType N>
struct ScalarTypeToCPPType;

#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type)                \
  template <>                                                                \
  struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> {                 \
    using type = cpp_type;                                                   \
                                                                             \
    /* This is a workaround for the CUDA bug which prevents */               \
    /* ::detail::ScalarTypeToCType<T>::type being used directly due to */    \
    /* ambiguous reference which can't to be resolved. For some reason it */ \
    /* can't pick between at::detail and at::cuda::detail. */                \
    /* For repro example, please see: */                                     \
    /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */    \
    /* TODO: remove once the bug is fixed. */                                \
    static type t;                                                           \
  };

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)

#undef SPECIALIZE_ScalarTypeToCPPType

template <c10::ScalarType N>
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;

} // namespace impl
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
// regarding macros.

template <typename T>
struct CppTypeToScalarType;
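The `AT_FORALL_*` macros above are X-macros: each one invokes the callback `_` once per `(cpp_type, ScalarType name)` pair, so a single callback definition stamps out a case, an overload, or a template instantiation per dtype. A minimal sketch of that pattern (a hypothetical helper, not part of this diff; it assumes the macros are still reachable through `c10/core/ScalarType.h` after the move to `torch/headeronly`), mirroring the `elementSize` switch that appears later in this header:

```cpp
#include <cstddef>
#include <c10/core/ScalarType.h>

// Callback invoked once per (cpp_type, name) pair by the X-macro.
#define SIZE_CASE(ctype, name)  \
  case c10::ScalarType::name:   \
    return sizeof(ctype);

// Hypothetical helper: byte size of each dtype covered by the macro.
inline std::size_t element_size_sketch(c10::ScalarType t) {
  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(SIZE_CASE)
    default:
      return 0;  // quantized and other types not covered by this macro
  }
}
#undef SIZE_CASE
```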
@ -138,130 +45,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)

#undef SPECIALIZE_CppTypeToScalarType

// NB: despite its generic sounding name, the macros that don't take _AND
// are mostly only used by tensorexpr
#define AT_FORALL_INT_TYPES(_) \
  _(uint8_t, Byte)             \
  _(int8_t, Char)              \
  _(int16_t, Short)            \
  _(int, Int)                  \
  _(int64_t, Long)

#define AT_FORALL_SCALAR_TYPES(_) \
  _(uint8_t, Byte)                \
  _(int8_t, Char)                 \
  _(int16_t, Short)               \
  _(int, Int)                     \
  _(int64_t, Long)                \
  _(float, Float)                 \
  _(double, Double)

// These macros are often controlling how many template instantiations we
// create for kernels.  It is typically inappropriate to add new dtypes here,
// instead, new types should be added to use sites on a case-by-case basis.
// We generally are not accepting new dtypes due to binary size concerns.

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
  _(uint8_t, Byte)                                \
  _(int8_t, Char)                                 \
  _(int16_t, Short)                               \
  _(int, Int)                                     \
  _(int64_t, Long)                                \
  _(float, Float)                                 \
  _(double, Double)                               \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE>::t),  \
    SCALARTYPE)

#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
  _(uint8_t, Byte)                                               \
  _(int8_t, Char)                                                \
  _(int16_t, Short)                                              \
  _(int, Int)                                                    \
  _(int64_t, Long)                                               \
  _(float, Float)                                                \
  _(double, Double)                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                   \
             ::c10::ScalarType::SCALARTYPE1>::t),                \
    SCALARTYPE1)                                                 \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                   \
             ::c10::ScalarType::SCALARTYPE2>::t),                \
    SCALARTYPE2)

#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
  _(uint8_t, Byte)                                                            \
  _(int8_t, Char)                                                             \
  _(int16_t, Short)                                                           \
  _(int, Int)                                                                 \
  _(int64_t, Long)                                                            \
  _(float, Float)                                                             \
  _(double, Double)                                                           \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE1>::t),                             \
    SCALARTYPE1)                                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE2>::t),                             \
    SCALARTYPE2)                                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE3>::t),                             \
    SCALARTYPE3)

#define AT_FORALL_SCALAR_TYPES_AND7(              \
    SCALARTYPE1,                                  \
    SCALARTYPE2,                                  \
    SCALARTYPE3,                                  \
    SCALARTYPE4,                                  \
    SCALARTYPE5,                                  \
    SCALARTYPE6,                                  \
    SCALARTYPE7,                                  \
    _)                                            \
  _(uint8_t, Byte)                                \
  _(int8_t, Char)                                 \
  _(int16_t, Short)                               \
  _(int, Int)                                     \
  _(int64_t, Long)                                \
  _(float, Float)                                 \
  _(double, Double)                               \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE1>::t), \
    SCALARTYPE1)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE2>::t), \
    SCALARTYPE2)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE3>::t), \
    SCALARTYPE3)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE4>::t), \
    SCALARTYPE4)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE5>::t), \
    SCALARTYPE5)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE6>::t), \
    SCALARTYPE6)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE7>::t), \
    SCALARTYPE7)

#define AT_FORALL_QINT_TYPES(_) \
  _(c10::qint8, QInt8)          \
  _(c10::quint8, QUInt8)        \
  _(c10::qint32, QInt32)        \
  _(c10::quint4x2, QUInt4x2)    \
  _(c10::quint2x4, QUInt2x4)

#define AT_FORALL_FLOAT8_TYPES(_)         \
  _(at::Float8_e5m2, Float8_e5m2)         \
  _(at::Float8_e4m3fn, Float8_e4m3fn)     \
  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
  _(at::Float8_e8m0fnu, Float8_e8m0fnu)

#define AT_FORALL_COMPLEX_TYPES(_)     \
  _(c10::complex<float>, ComplexFloat) \
  _(c10::complex<double>, ComplexDouble)

#define DEFINE_CONSTANT(_, name) \
  constexpr ScalarType k##name = ScalarType::name;
@ -269,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT

inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
  case ScalarType::name:     \
    return #name;

  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
    default:
      return "UNKNOWN_SCALAR";
  }
#undef DEFINE_CASE
}

inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
  case ScalarType::name:                   \
@ -525,12 +295,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {

C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);

inline std::ostream& operator<<(
    std::ostream& stream,
    at::ScalarType scalar_type) {
  return stream << toString(scalar_type);
}

// Returns a pair of strings representing the names for each dtype.
// The returned pair is (name, legacy_name_if_applicable)
C10_API std::pair<std::string, std::string> getDtypeNames(
@ -86,4 +86,23 @@ inline SymIntArrayRef fromIntArrayRefSlow(IntArrayRef array_ref) {
      reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}

inline c10::SymBool sym_equals(SymIntArrayRef LHS, SymIntArrayRef RHS) {
  if (LHS.size() != RHS.size()) {
    return c10::SymBool(false);
  }

  c10::SymBool result = sym_eq(LHS.size(), RHS.size());
  for (size_t i = 0; i < RHS.size(); ++i) {
    c10::SymBool equals = sym_eq(LHS[i], RHS[i]);
    std::optional<bool> equals_bool = equals.maybe_as_bool();

    if (equals_bool.has_value() && !*equals_bool) {
      // Early return if element comparison is known to be false
      return equals;
    }
    result = result.sym_and(equals);
  }
  return result;
}

} // namespace c10
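A hedged usage sketch of the new helper (the tensors `a` and `b` are placeholders, not from the diff): `sym_equals` compares two symbolic shapes element-wise and returns a `SymBool` that is statically false as soon as any element is known to differ, and otherwise the conjunction of the per-element comparisons.

```cpp
#include <ATen/ATen.h>

// Sketch: decide whether two (possibly symbolically shaped) tensors have
// equal sizes, only inserting a guard when the answer is not statically known.
bool same_symbolic_shape(const at::Tensor& a, const at::Tensor& b) {
  c10::SymBool eq = c10::sym_equals(a.sym_sizes(), b.sym_sizes());
  return eq.guard_bool(__FILE__, __LINE__);
}
```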
@ -1080,19 +1080,12 @@ class RingBuffer {

  void getEntries(std::vector<T>& result) const {
    std::lock_guard<std::mutex> lk(alloc_trace_lock);
    result.reserve(alloc_trace->size());
    result.insert(
        result.end(),
        alloc_trace->begin() +
            static_cast<typename std::vector<T>::difference_type>(
                alloc_trace_next),
        alloc_trace->end());
    result.insert(
        result.end(),
    result.reserve(result.size() + alloc_trace->size());
    std::rotate_copy(
        alloc_trace->begin(),
        alloc_trace->begin() +
            static_cast<typename std::vector<T>::difference_type>(
                alloc_trace_next));
        std::next(alloc_trace->begin(), alloc_trace_next),
        alloc_trace->end(),
        std::back_inserter(result));
  }

  void clear() {
@ -4466,10 +4459,7 @@ struct BackendStaticInitializer {
          if (kv[0] == "backend") {
#ifdef USE_ROCM
            // convenience for ROCm users to allow either CUDA or HIP env var
            if (kv[1] ==
                    "cud"
                    "aMallocAsync" ||
                kv[1] == "hipMallocAsync")
            if (kv[1] == "cudaMallocAsync" || kv[1] == "hipMallocAsync")
#else
            if (kv[1] == "cudaMallocAsync")
#endif

@ -4491,9 +4481,7 @@ struct BackendStaticInitializer {
// HIPAllocatorMasqueradingAsCUDA because it needs to happen during static
// initialization, and doing so there may introduce static initialization
// order (SIOF) issues.
#define HIP_MASQUERADING_AS_CUDA \
  "cud"                          \
  "a"
#define HIP_MASQUERADING_AS_CUDA "cuda"
    at::SetAllocator(c10::Device(HIP_MASQUERADING_AS_CUDA).type(), r, 0);
    allocator.store(r);
#undef HIP_MASQUERADING_AS_CUDA
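The RingBuffer::getEntries change earlier in this diff replaces two insert calls with a single std::rotate_copy, which copies [middle, end) followed by [begin, middle) — i.e. it emits a ring buffer's entries oldest-first in one pass. A small self-contained illustration (standalone example, not from the diff):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  // Physical ring-buffer storage after wrap-around: {4, 5, 1, 2, 3}.
  // Index 2 is the next write slot, so it also marks the oldest entry.
  std::vector<int> ring{4, 5, 1, 2, 3};
  std::ptrdiff_t next = 2;

  std::vector<int> in_order;
  in_order.reserve(ring.size());
  // Copies [begin + next, end) first, then [begin, begin + next).
  std::rotate_copy(ring.begin(), ring.begin() + next, ring.end(),
                   std::back_inserter(in_order));

  for (int v : in_order) {
    std::cout << v << ' ';  // prints: 1 2 3 4 5
  }
  std::cout << '\n';
}
```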
@ -65,7 +65,7 @@ struct default_constructible

namespace impl {
  template <typename T>
  constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>*)
  constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>* /*unused*/)
  {
    return true;
  }

@ -76,7 +76,7 @@ class type : public modifier<M, type<T, Tag, M...>>...
{
public:
  template <typename TT = T, typename = std::enable_if_t<std::is_trivially_constructible<TT>{}>>
  explicit type(uninitialized_t)
  explicit type(uninitialized_t /*unused*/)
    noexcept
  {
  }

@ -138,7 +138,7 @@ private:

namespace impl {
  template <typename T, typename Tag, typename ... Ms>
  constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>*) { return true;}
  constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>* /*unused*/) { return true;}
  constexpr bool is_strong_type_func(...) { return false;}
  template <typename T, typename Tag, typename ... Ms>
  constexpr T underlying_type(strong::type<T, Tag, Ms...>*);
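The three hunks above all make the same cosmetic change: parameters that exist only for overload selection stay unnamed and gain a /*unused*/ comment, which keeps -Wunused-parameter and lint checks quiet without changing behavior. A minimal sketch of the idiom (hypothetical names, not from the library):

```cpp
// Overload chosen when a pointer-to-int is passed; the parameter is never
// read, so it is left unnamed and annotated instead of being given a name.
constexpr bool accepts_int_ptr(const int* /*unused*/) {
  return true;
}

// Fallback for everything else.
constexpr bool accepts_int_ptr(...) {
  return false;
}

static_assert(accepts_int_ptr(static_cast<int*>(nullptr)), "pointer overload");
static_assert(!accepts_int_ptr(1.5), "ellipsis fallback");
```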
@ -1019,6 +1019,8 @@ coverage_ignore_functions = [
    "loop_pass",
    "these_before_those_pass_constraint",
    "this_before_that_pass_constraint",
    # torch.fx.passes.regional_inductor
    "regional_inductor",
    # torch.fx.passes.reinplace
    "reinplace",
    # torch.fx.passes.split_module
@ -68,14 +68,6 @@
.. autofunction:: get_validators
```

```{eval-rst}
.. autofunction:: write_file_on_exit
```

```{eval-rst}
.. autofunction:: write_file
```

```{eval-rst}
.. autofunction:: read_file
```

@ -95,3 +87,7 @@
```{eval-rst}
.. autofunction:: get_rotating_buffer_size
```

```{eval-rst}
.. autofunction:: set_numerical_check_tolerances
```
@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
.. autoclass:: CPUOffloadPolicy
    :members:
```

```{eval-rst}
.. autofunction:: share_comm_ctx
```
@ -51,7 +51,7 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it.
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| reduce_scatter | ✓   | ✓   | ✘   | ✘   | ✘   | ✓   | ✘   | ✓   |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| all_to_all     | ✓   | ✓   | ✓   | ?   | ✘   | ✓   | ✘   | ✓   |
| all_to_all     | ✘   | ✘   | ✓   | ?   | ✘   | ✓   | ✘   | ✓   |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| barrier        | ✓   | ✘   | ✓   | ?   | ✘   | ✓   | ✘   | ✓   |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
@ -1169,6 +1169,7 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.passes.operator_support
.. py:module:: torch.fx.passes.param_fetch
.. py:module:: torch.fx.passes.pass_manager
.. py:module:: torch.fx.passes.regional_inductor
.. py:module:: torch.fx.passes.reinplace
.. py:module:: torch.fx.passes.runtime_assert
.. py:module:: torch.fx.passes.shape_prop
@ -23,6 +23,7 @@ Submodules
    flex_attention
    bias
    experimental
    varlen

.. toctree::
    :hidden:
@ -30,3 +31,4 @@ Submodules
    nn.attention.flex_attention
    nn.attention.bias
    nn.attention.experimental
    nn.attention.varlen

docs/source/nn.attention.varlen.md (new file, 17 lines)
@ -0,0 +1,17 @@
```{eval-rst}
.. role:: hidden
    :class: hidden-section
```

# torch.nn.attention.varlen

```{eval-rst}
.. automodule:: torch.nn.attention.varlen
.. currentmodule:: torch.nn.attention.varlen
```
```{eval-rst}
.. autofunction:: varlen_attn
```
```{eval-rst}
.. autoclass:: AuxRequest
```
@ -228,3 +228,4 @@ Low-Precision functions
    ScalingType
    SwizzleType
    scaled_mm
    scaled_grouped_mm
@ -1,14 +1,12 @@
```{eval-rst}
.. currentmodule:: torch.compiler.config

```

# torch.compiler.config

```{eval-rst}
.. automodule:: torch.compiler.config
```

```{eval-rst}
.. autodata:: torch.compiler.config.job_id
   :members:
   :undoc-members:
   :show-inheritance:
```
@ -816,6 +816,10 @@ Operator Tags
.. py:module:: torch.types
.. py:module:: torch.version

.. Compiler configuration module - documented in torch.compiler.config.md
.. py:module:: torch.compiler.config
   :noindex:

.. Hidden aliases (e.g. torch.functional.broadcast_tensors()). We want `torch.broadcast_tensors()` to
   be visible only.
.. toctree::
@ -10,6 +10,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
)

test/cpp/aoti_abi_check/test_scalartype.cpp (new file, 76 lines)
@ -0,0 +1,76 @@
#include <gtest/gtest.h>

#include <torch/headeronly/core/ScalarType.h>

TEST(TestScalarType, ScalarTypeToCPPTypeT) {
  using torch::headeronly::ScalarType;
  using torch::headeronly::impl::ScalarTypeToCPPTypeT;

#define DEFINE_CHECK(TYPE, SCALARTYPE) \
  EXPECT_EQ(typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE));

  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}

#define DEFINE_CHECK(TYPE, SCALARTYPE)                                       \
  {                                                                          \
    EXPECT_EQ(                                                               \
        typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE)); \
    count++;                                                                 \
  }

#define TEST_FORALL(M, EXPECTEDCOUNT, ...)               \
  TEST(TestScalarType, M) {                              \
    using torch::headeronly::ScalarType;                 \
    using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
    int8_t count = 0;                                    \
    M(__VA_ARGS__ DEFINE_CHECK);                         \
    EXPECT_EQ(count, EXPECTEDCOUNT);                     \
  }

TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ, 14)
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, 18)
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, 46)
TEST_FORALL(AT_FORALL_INT_TYPES, 5)
TEST_FORALL(AT_FORALL_SCALAR_TYPES, 7)
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND, 8, Bool, )
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND2, 9, Bool, Half, )
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND3, 10, Bool, Half, ComplexFloat, )
TEST_FORALL(
    AT_FORALL_SCALAR_TYPES_AND7,
    14,
    Bool,
    Half,
    ComplexHalf,
    ComplexFloat,
    ComplexDouble,
    UInt16,
    UInt32, )
TEST_FORALL(AT_FORALL_QINT_TYPES, 5)
TEST_FORALL(AT_FORALL_FLOAT8_TYPES, 5)
TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)

#undef DEFINE_CHECK
#undef TEST_FORALL

TEST(TestScalarType, toString) {
  using torch::headeronly::ScalarType;

#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}

TEST(TestScalarType, operator_left_shift) {
  using torch::headeronly::ScalarType;

#define DEFINE_CHECK(_, name)   \
  {                             \
    std::stringstream ss;       \
    ss << ScalarType::name;     \
    EXPECT_EQ(ss.str(), #name); \
  }
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
@ -6,7 +6,7 @@ import functools
 | 
			
		||||
import itertools
 | 
			
		||||
import unittest
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from collections.abc import Iterable
 | 
			
		||||
from collections.abc import Callable, Iterable
 | 
			
		||||
from typing import Any, Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -24,6 +24,11 @@ from torch.distributed.fsdp import (
 | 
			
		||||
    fully_shard,
 | 
			
		||||
    OffloadPolicy,
 | 
			
		||||
    register_fsdp_forward_method,
 | 
			
		||||
    share_comm_ctx,
 | 
			
		||||
)
 | 
			
		||||
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
 | 
			
		||||
    foreach_all_gather,
 | 
			
		||||
    foreach_reduce,
 | 
			
		||||
)
 | 
			
		||||
from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 | 
			
		||||
from torch.distributed.tensor.debug import CommDebugMode
 | 
			
		||||
@ -39,6 +44,8 @@ from torch.testing._internal.common_fsdp import (
 | 
			
		||||
    MLP,
 | 
			
		||||
    MLPStack,
 | 
			
		||||
    patch_all_gather,
 | 
			
		||||
    patch_foreach_all_gather,
 | 
			
		||||
    patch_foreach_reduce,
 | 
			
		||||
    patch_reduce_scatter,
 | 
			
		||||
)
 | 
			
		||||
from torch.testing._internal.common_utils import (
 | 
			
		||||
@ -1487,6 +1494,116 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
 | 
			
		||||
        check_sharded_parity(self, ref_model, model)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestFullyShardShareCommContext(FSDPTest):
 | 
			
		||||
    @property
 | 
			
		||||
    def world_size(self) -> int:
 | 
			
		||||
        return min(torch.get_device_module(device_type).device_count(), 2)
 | 
			
		||||
 | 
			
		||||
    @skip_if_lt_x_gpu(2)
 | 
			
		||||
    def test_share_comm_context(self):
 | 
			
		||||
        torch.manual_seed(42)
 | 
			
		||||
        n_layers = 3
 | 
			
		||||
        lin_dim = 16
 | 
			
		||||
        model = nn.Sequential(
 | 
			
		||||
            *[MLP(lin_dim, torch.device("cpu")) for _ in range(n_layers)]
 | 
			
		||||
        )
 | 
			
		||||
        ref_model = copy.deepcopy(model).to(device_type)
 | 
			
		||||
        for layer in model:
 | 
			
		||||
            fully_shard(layer)
 | 
			
		||||
            layer._get_fsdp_state()._lazy_init()
 | 
			
		||||
        share_comm_ctx(list(model))
 | 
			
		||||
 | 
			
		||||
        torch.manual_seed(42 + self.rank + 1)
 | 
			
		||||
        inp = torch.randn(4, 3, lin_dim, device=device_type.type)
 | 
			
		||||
        ref_loss = ref_model(inp).sum()
 | 
			
		||||
 | 
			
		||||
        all_gather_streams = set()
 | 
			
		||||
        reduce_scatter_streams = set()
 | 
			
		||||
 | 
			
		||||
        from torch.distributed.fsdp._fully_shard._fsdp_api import (
 | 
			
		||||
            AllGather,
 | 
			
		||||
            ReduceScatter,
 | 
			
		||||
        )
 | 
			
		||||
        from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 | 
			
		||||
 | 
			
		||||
        orig_foreach_all_gather = foreach_all_gather
 | 
			
		||||
 | 
			
		||||
        def foreach_all_gather_with_assert(
 | 
			
		||||
            fsdp_params: list[FSDPParam],
 | 
			
		||||
            group: dist.ProcessGroup,
 | 
			
		||||
            async_op: bool,
 | 
			
		||||
            all_gather_copy_in_stream: torch.Stream,
 | 
			
		||||
            all_gather_stream: torch.Stream,
 | 
			
		||||
            device: torch.device,
 | 
			
		||||
            all_gather_comm: AllGather,
 | 
			
		||||
        ):
 | 
			
		||||
            nonlocal all_gather_streams
 | 
			
		||||
            all_gather_streams.add(all_gather_stream)
 | 
			
		||||
            return orig_foreach_all_gather(
 | 
			
		||||
                fsdp_params,
 | 
			
		||||
                group,
 | 
			
		||||
                async_op,
 | 
			
		||||
                all_gather_copy_in_stream,
 | 
			
		||||
                all_gather_stream,
 | 
			
		||||
                device,
 | 
			
		||||
                all_gather_comm,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        orig_foreach_reduce = foreach_reduce
 | 
			
		||||
 | 
			
		||||
        @torch.no_grad()
 | 
			
		||||
        def foreach_reduce_with_assert(
 | 
			
		||||
            fsdp_params: list[FSDPParam],
 | 
			
		||||
            unsharded_grads: list[torch.Tensor],
 | 
			
		||||
            reduce_scatter_group: dist.ProcessGroup,
 | 
			
		||||
            reduce_scatter_stream: torch.Stream,
 | 
			
		||||
            reduce_scatter_comm: ReduceScatter,
 | 
			
		||||
            orig_dtype: Optional[torch.dtype],
 | 
			
		||||
            reduce_dtype: Optional[torch.dtype],
 | 
			
		||||
            device: torch.device,
 | 
			
		||||
            gradient_divide_factor: Optional[float],
 | 
			
		||||
            all_reduce_group: Optional[dist.ProcessGroup],  # not `None` iff HSDP
 | 
			
		||||
            all_reduce_stream: torch.Stream,
 | 
			
		||||
            all_reduce_grads: bool,
 | 
			
		||||
            partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
 | 
			
		||||
            all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
 | 
			
		||||
            force_sum_reduction_for_comms: bool = False,
 | 
			
		||||
        ):
 | 
			
		||||
            nonlocal reduce_scatter_streams
 | 
			
		||||
            reduce_scatter_streams.add(reduce_scatter_stream)
 | 
			
		||||
            return orig_foreach_reduce(
 | 
			
		||||
                fsdp_params,
 | 
			
		||||
                unsharded_grads,
 | 
			
		||||
                reduce_scatter_group,
 | 
			
		||||
                reduce_scatter_stream,
 | 
			
		||||
                reduce_scatter_comm,
 | 
			
		||||
                orig_dtype,
 | 
			
		||||
                reduce_dtype,
 | 
			
		||||
                device,
 | 
			
		||||
                gradient_divide_factor,
 | 
			
		||||
                all_reduce_group,
 | 
			
		||||
                all_reduce_stream,
 | 
			
		||||
                all_reduce_grads,
 | 
			
		||||
                partial_reduce_output,
 | 
			
		||||
                all_reduce_hook,
 | 
			
		||||
                force_sum_reduction_for_comms,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        with (
 | 
			
		||||
            patch_foreach_all_gather(foreach_all_gather_with_assert),
 | 
			
		||||
            patch_foreach_reduce(foreach_reduce_with_assert),
 | 
			
		||||
        ):
 | 
			
		||||
            loss = model(inp).sum()
 | 
			
		||||
            self.assertEqual(ref_loss, loss)
 | 
			
		||||
            ref_loss.backward()
 | 
			
		||||
            loss.backward()
 | 
			
		||||
            for param in ref_model.parameters():
 | 
			
		||||
                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
 | 
			
		||||
        self.assertEqual(len(all_gather_streams), 1)
 | 
			
		||||
        self.assertEqual(len(reduce_scatter_streams), 1)
 | 
			
		||||
        check_sharded_parity(self, ref_model, model)
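The new `TestFullyShardShareCommContext` test doubles as a usage example: shard each submodule, then pass the sharded modules to `share_comm_ctx` so they reuse a single all-gather stream and a single reduce-scatter stream. A condensed sketch of that call pattern is below (plain `nn.Linear` layers stand in for the test's MLPs; the lazy-init call in the test is an internal detail and is omitted):

```python
# Sketch condensed from the test above; not new API beyond what the diff imports.
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(3)])
for layer in model:
    fully_shard(layer)       # each layer becomes its own FSDP unit
share_comm_ctx(list(model))  # all units now share one all-gather and one
                             # reduce-scatter stream, as the stream-count asserts check
```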
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestFullyShardWorldSize1(FSDPTest):
 | 
			
		||||
    @property
 | 
			
		||||
    def world_size(self) -> int:
 | 
			
		||||
 | 
			
		||||
@@ -559,7 +559,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
            FAIL = 138
            pc = start_processes(
                name="echo",
                entrypoint=bin("echo1.py"),
                entrypoint=bin("echo4.py"),
                args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                logs_specs=DefaultLogsSpecs(

@@ -9,7 +9,6 @@
import argparse
import os
import sys
import time


if __name__ == "__main__":
@@ -24,6 +23,5 @@ if __name__ == "__main__":
        print(f"exit {exitcode} from {rank}", file=sys.stderr)
        sys.exit(exitcode)
    else:
        time.sleep(1000)
        print(f"{args.msg} stdout from {rank}")
        print(f"{args.msg} stderr from {rank}", file=sys.stderr)

test/distributed/elastic/multiprocessing/bin/echo4.py (new executable file, 29 lines)
@@ -0,0 +1,29 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys
import time


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="test binary, exits with exitcode")
    parser.add_argument("--exitcode", type=int, default=0)
    parser.add_argument("msg", type=str)
    args = parser.parse_args()

    rank = int(os.environ["RANK"])
    exitcode = args.exitcode
    if exitcode != 0:
        print(f"exit {exitcode} from {rank}", file=sys.stderr)
        sys.exit(exitcode)
    else:
        time.sleep(1000)
        print(f"{args.msg} stdout from {rank}")
        print(f"{args.msg} stderr from {rank}", file=sys.stderr)

@ -536,6 +536,23 @@ class TestScheduleLowering(TestCase):
 | 
			
		||||
                "compute": ["0F0", "0F1", "   ", "0B0", "0B1"],
 | 
			
		||||
                "comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"],
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                "compute": ["0F0", "0F1", "1F0", "1F1", "1B0", "1B1", "0B0", "0B1"],
 | 
			
		||||
                "comms": [
 | 
			
		||||
                    "0UNSHARD",
 | 
			
		||||
                    "1UNSHARD",
 | 
			
		||||
                    "0F0",
 | 
			
		||||
                    "0F1",
 | 
			
		||||
                    "1F0",
 | 
			
		||||
                    "1F1",
 | 
			
		||||
                    "1B0",
 | 
			
		||||
                    "1B1",
 | 
			
		||||
                    "1RESHARD",
 | 
			
		||||
                    "0B0",
 | 
			
		||||
                    "0B1",
 | 
			
		||||
                    "0RESHARD",
 | 
			
		||||
                ],
 | 
			
		||||
            },
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
    def test_unshard_reshard(self, test_info):
 | 
			
		||||
 | 
			
		||||
@ -203,6 +203,34 @@ class DistConvolutionOpsTest(DTensorTestBase):
 | 
			
		||||
        self.assertTrue(b_dt.grad is not None)
 | 
			
		||||
        self.assertTrue(x_dt.grad is None)
 | 
			
		||||
 | 
			
		||||
    @with_comms
 | 
			
		||||
    def test_conv1d(self):
 | 
			
		||||
        device_mesh = self.build_device_mesh()
 | 
			
		||||
        model = nn.Conv1d(64, 64, 3, padding=1)
 | 
			
		||||
        model_gt = copy.deepcopy(model)
 | 
			
		||||
        x = torch.randn(1, 64, 8)
 | 
			
		||||
        x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
 | 
			
		||||
        model_dt = distribute_module(
 | 
			
		||||
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
 | 
			
		||||
        )
 | 
			
		||||
        out_dt = model_dt(x_dt)
 | 
			
		||||
        out = model_gt(x)
 | 
			
		||||
        self.assertEqual(out_dt.shape, out.shape)
 | 
			
		||||
 | 
			
		||||
    @with_comms
 | 
			
		||||
    def test_conv3d(self):
 | 
			
		||||
        device_mesh = self.build_device_mesh()
 | 
			
		||||
        model = nn.Conv3d(64, 64, 3, padding=1)
 | 
			
		||||
        model_gt = copy.deepcopy(model).to(device=self.device_type)
 | 
			
		||||
        x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
 | 
			
		||||
        x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
 | 
			
		||||
        model_dt = distribute_module(
 | 
			
		||||
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
 | 
			
		||||
        )
 | 
			
		||||
        out_dt = model_dt(x_dt)
 | 
			
		||||
        out = model_gt(x)
 | 
			
		||||
        self.assertEqual(out_dt.shape, out.shape)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    run_tests()
 | 
			
		||||
 | 
			
		||||
@ -3,23 +3,13 @@
 | 
			
		||||
 | 
			
		||||
import pathlib
 | 
			
		||||
import tempfile
 | 
			
		||||
import types
 | 
			
		||||
import unittest
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from numpy.testing import assert_array_equal
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.distributed as dist
 | 
			
		||||
import torch.distributed.distributed_c10d as c10d
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
 | 
			
		||||
from torch.distributed._local_tensor import (
 | 
			
		||||
    LocalIntNode,
 | 
			
		||||
    LocalTensorMode,
 | 
			
		||||
    maybe_run_for_local_tensor,
 | 
			
		||||
)
 | 
			
		||||
from torch.distributed.device_mesh import init_device_mesh
 | 
			
		||||
from torch.distributed.tensor import (
 | 
			
		||||
    DeviceMesh,
 | 
			
		||||
@ -46,7 +36,9 @@ from torch.distributed.tensor.placement_types import _StridedShard
 | 
			
		||||
from torch.testing import make_tensor
 | 
			
		||||
from torch.testing._internal.common_utils import IS_FBCODE, run_tests, skipIfHpu
 | 
			
		||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
 | 
			
		||||
    create_local_tensor_test_class,
 | 
			
		||||
    DTensorTestBase,
 | 
			
		||||
    map_local_tensor_for_rank,
 | 
			
		||||
    with_comms,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -54,11 +46,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 | 
			
		||||
c10d_functional = torch.ops.c10d_functional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@maybe_run_for_local_tensor
 | 
			
		||||
def map_tensor_for_rank(tensor, rank, func):
 | 
			
		||||
    return func(tensor, rank)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DummyMLP(torch.nn.Module):
 | 
			
		||||
    def __init__(self, device):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
@ -251,7 +238,7 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        dtensor = DTensor.from_local(
 | 
			
		||||
            tensor_list[self.rank],
 | 
			
		||||
            map_local_tensor_for_rank(tensor_list, self.rank, lambda tl, r: tl[r]),
 | 
			
		||||
            device_mesh,
 | 
			
		||||
            (Shard(0),),
 | 
			
		||||
            shape=global_tensor.size(),
 | 
			
		||||
@ -279,7 +266,7 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
            RuntimeError, "Please pass both shape and stride at the same time."
 | 
			
		||||
        ):
 | 
			
		||||
            DTensor.from_local(
 | 
			
		||||
                tensor_list[self.rank],
 | 
			
		||||
                map_local_tensor_for_rank(tensor_list, self.rank, lambda tl, r: tl[r]),
 | 
			
		||||
                device_mesh,
 | 
			
		||||
                (Shard(0),),
 | 
			
		||||
                shape=global_tensor.size(),
 | 
			
		||||
@ -289,7 +276,7 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
            RuntimeError, "Please pass both shape and stride at the same time."
 | 
			
		||||
        ):
 | 
			
		||||
            DTensor.from_local(
 | 
			
		||||
                tensor_list[self.rank],
 | 
			
		||||
                map_local_tensor_for_rank(tensor_list, self.rank, lambda tl, r: tl[r]),
 | 
			
		||||
                device_mesh,
 | 
			
		||||
                (Shard(0),),
 | 
			
		||||
                stride=global_tensor.stride(),
 | 
			
		||||
@ -609,7 +596,7 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
        local_tensor = sharded_tensor.to_local()
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            local_tensor,
 | 
			
		||||
            map_tensor_for_rank(
 | 
			
		||||
            map_local_tensor_for_rank(
 | 
			
		||||
                full_tensor, self.rank, lambda ft, r: ft[range(r, r + 1), :]
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
@ -622,7 +609,7 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
        local_tensor = sharded_tensor.to_local()
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            local_tensor,
 | 
			
		||||
            map_tensor_for_rank(
 | 
			
		||||
            map_local_tensor_for_rank(
 | 
			
		||||
                full_tensor, self.rank, lambda ft, r: ft[:, range(r, r + 1)]
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
@ -645,103 +632,17 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
        self.assertEqual(local_tensor.item(), self.rank)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LocalDTensorTest(DTensorTest):
 | 
			
		||||
    def get_local_tensor_mode(self):
 | 
			
		||||
        return LocalTensorMode(frozenset(range(0, self.world_size)))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def rank(self):
 | 
			
		||||
        return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
 | 
			
		||||
 | 
			
		||||
    @rank.setter
 | 
			
		||||
    def rank(self, rank):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def join_or_run(self, fn):
 | 
			
		||||
        @wraps(fn)
 | 
			
		||||
        def wrapper(self):
 | 
			
		||||
            fn()
 | 
			
		||||
 | 
			
		||||
        return types.MethodType(wrapper, self)
 | 
			
		||||
 | 
			
		||||
    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
 | 
			
		||||
        dist.init_process_group("fake", rank=0, world_size=self.world_size)
 | 
			
		||||
        self._pg = c10d._get_default_group()
 | 
			
		||||
 | 
			
		||||
    def destroy_pg(self, device_id: Optional[int] = None) -> None:
 | 
			
		||||
        dist.destroy_process_group(self._pg)
 | 
			
		||||
        self._pg = None
 | 
			
		||||
 | 
			
		||||
    def _spawn_processes(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_constructor(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_meta_dtensor(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_modules_w_meta_dtensor(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_stride(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_from_local(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_from_local_uneven_sharding(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_from_local_uneven_sharding_raise_error(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_from_local_negative_dim(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_to_local(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_to_local_grad_hint(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_full_tensor_sync(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_full_tensor_grad_hint(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_new_empty_strided(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_async_output(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_from_local_then_to_local(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_spec_read_only_after_set(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_spec_hash(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_properties(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_save_load(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_dtensor_save_load_import(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_shard_tensor_2d(self):
 | 
			
		||||
        with self.get_local_tensor_mode():
 | 
			
		||||
            super().test_shard_tensor_2d()
 | 
			
		||||
 | 
			
		||||
    def test_shard_tensor(self):
 | 
			
		||||
        with self.get_local_tensor_mode():
 | 
			
		||||
            super().test_shard_tensor()
 | 
			
		||||
DTensorTestWithLocalTensor = create_local_tensor_test_class(
 | 
			
		||||
    DTensorTest,
 | 
			
		||||
    skipped_tests=[
 | 
			
		||||
        # Async output in local mode is not supported
 | 
			
		||||
        "test_dtensor_async_output",
 | 
			
		||||
        # Disabling saving and loading in local mode since it requires a deeper
 | 
			
		||||
        # integration
 | 
			
		||||
        "test_dtensor_save_load",
 | 
			
		||||
        "test_dtensor_save_load_import",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DTensorMeshTest(DTensorTestBase):
 | 
			
		||||
@ -1119,6 +1020,19 @@ class DTensorMeshTest(DTensorTestBase):
 | 
			
		||||
            self.fail("Unexpected ValueError raised with run_check=False")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
 | 
			
		||||
    DTensorMeshTest,
 | 
			
		||||
    skipped_tests=[
 | 
			
		||||
        # Submeshes are not supported by local tensor mode
 | 
			
		||||
        "test_from_local_sub_mesh",
 | 
			
		||||
        "test_default_value_sub_mesh",
 | 
			
		||||
        "test_redistribute_sub_mesh",
 | 
			
		||||
        # Local tensor mode doesn't support tensors of different types on different ranks
 | 
			
		||||
        "test_metadata_consistency_check",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestDTensorPlacementTypes(DTensorTestBase):
 | 
			
		||||
    @property
 | 
			
		||||
    def world_size(self):
 | 
			
		||||
@ -1185,6 +1099,11 @@ class TestDTensorPlacementTypes(DTensorTestBase):
 | 
			
		||||
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TestDTensorPlacementTypesWithLocalTensor = create_local_tensor_test_class(
 | 
			
		||||
    TestDTensorPlacementTypes,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestDTensorSpec(DTensorTestBase):
 | 
			
		||||
    @property
 | 
			
		||||
    def world_size(self):
 | 
			
		||||
@ -1364,5 +1283,9 @@ class TestDTensorSpec(DTensorTestBase):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TestDTensorSpecWithLocalTensor = create_local_tensor_test_class(
 | 
			
		||||
    TestDTensorSpec,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    run_tests()
 | 
			
		||||
 | 
			
		||||
@ -959,6 +959,9 @@ def forward(self, primals_1):
 | 
			
		||||
        out_dt = torch.matmul(tmp_dt, y_dt)
 | 
			
		||||
        out_dt.sum().backward()
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(
 | 
			
		||||
        torch._inductor.config.triton.native_matmul, "Matmul is now generated"
 | 
			
		||||
    )
 | 
			
		||||
    def _test_tp_compile_comm_reordering(self):
 | 
			
		||||
        class FakeAttention(nn.Module):
 | 
			
		||||
            def __init__(self) -> None:
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,9 @@
 | 
			
		||||
# Owner(s): ["oncall: distributed"]
 | 
			
		||||
 | 
			
		||||
import contextlib
 | 
			
		||||
import copy
 | 
			
		||||
import itertools
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.distributed.device_mesh import init_device_mesh
 | 
			
		||||
@ -15,7 +17,10 @@ from torch.distributed.tensor import (
 | 
			
		||||
    Shard,
 | 
			
		||||
)
 | 
			
		||||
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
 | 
			
		||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
 | 
			
		||||
from torch.distributed.tensor._redistribute import redistribute_local_tensor
 | 
			
		||||
from torch.distributed.tensor.debug import CommDebugMode
 | 
			
		||||
from torch.distributed.tensor.placement_types import _StridedShard
 | 
			
		||||
from torch.testing._internal.common_utils import (
 | 
			
		||||
    instantiate_parametrized_tests,
 | 
			
		||||
    parametrize,
 | 
			
		||||
@ -27,6 +32,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 | 
			
		||||
    DTensorTestBase,
 | 
			
		||||
    with_comms,
 | 
			
		||||
)
 | 
			
		||||
from torch.utils._debug_mode import DebugMode
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
funcol = torch.ops.c10d_functional
 | 
			
		||||
@ -748,5 +754,414 @@ class MultiDimRedistributeTest(DTensorTestBase):
 | 
			
		||||
            self.assertEqual(local_out_dt, local_expected_dt)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DistributeWithDeviceOrderTest(DTensorTestBase):
 | 
			
		||||
    @property
 | 
			
		||||
    def world_size(self) -> int:
 | 
			
		||||
        return 8
 | 
			
		||||
 | 
			
		||||
    def _extract_redistribute_trace_from_debug_mode(self, s: str) -> str:
 | 
			
		||||
        import re
 | 
			
		||||
 | 
			
		||||
        match = re.search(r"trace:\s*(.*)\)", s)
 | 
			
		||||
        if match:
 | 
			
		||||
            trace_str = match.group(1)
 | 
			
		||||
            return trace_str
 | 
			
		||||
        else:
 | 
			
		||||
            return ""
 | 
			
		||||
 | 
			
		||||
    # TODO(zpcore): remove once the native redistribute supports shard_order arg
 | 
			
		||||
    def redistribute(
 | 
			
		||||
        self,
 | 
			
		||||
        dtensor_input,
 | 
			
		||||
        device_mesh,
 | 
			
		||||
        placements,
 | 
			
		||||
        shard_order,
 | 
			
		||||
        use_graph_based_transform=True,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        wrapper function to support shard_order for redistribution
 | 
			
		||||
        This is a simpler version of Redistribute, only considers the forward.
 | 
			
		||||
        """
 | 
			
		||||
        if placements is None:
 | 
			
		||||
            placements = self._shard_order_to_placement(shard_order, device_mesh)
 | 
			
		||||
        placements = tuple(placements)
 | 
			
		||||
        old_spec = dtensor_input._spec
 | 
			
		||||
        new_spec = copy.deepcopy(old_spec)
 | 
			
		||||
        new_spec.placements = placements
 | 
			
		||||
        if shard_order is not None:
 | 
			
		||||
            new_spec.shard_order = shard_order
 | 
			
		||||
        else:
 | 
			
		||||
            new_spec.shard_order = ()
 | 
			
		||||
        if old_spec == new_spec:
 | 
			
		||||
            return dtensor_input
 | 
			
		||||
        dtensor_input = DTensor.from_local(
 | 
			
		||||
            redistribute_local_tensor(
 | 
			
		||||
                dtensor_input.to_local(),
 | 
			
		||||
                old_spec,
 | 
			
		||||
                new_spec,
 | 
			
		||||
                use_graph_based_transform=use_graph_based_transform,
 | 
			
		||||
            ),
 | 
			
		||||
            device_mesh,
 | 
			
		||||
        )
 | 
			
		||||
        dtensor_input._spec = copy.deepcopy(new_spec)
 | 
			
		||||
        return dtensor_input  # returns DTensor
 | 
			
		||||
 | 
			
		||||
    # TODO(zpcore): remove once the native distribute_tensor supports
 | 
			
		||||
    # shard_order arg
 | 
			
		||||
    def distribute_tensor(
 | 
			
		||||
        self,
 | 
			
		||||
        input_tensor,
 | 
			
		||||
        device_mesh,
 | 
			
		||||
        placements,
 | 
			
		||||
        shard_order,
 | 
			
		||||
        use_graph_based_transform=True,
 | 
			
		||||
    ):
 | 
			
		||||
        """wrapper function to support shard_order for tensor distribution"""
 | 
			
		||||
        if placements is None:
 | 
			
		||||
            placements = self._shard_order_to_placement(shard_order, device_mesh)
 | 
			
		||||
        placements = tuple(placements)
 | 
			
		||||
        tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
 | 
			
		||||
        # fix the shard order
 | 
			
		||||
        return self.redistribute(
 | 
			
		||||
            tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # TODO(zpcore): remove once the native redistribute supports shard_order arg
 | 
			
		||||
    def full_tensor(self, dtensor_input):
 | 
			
		||||
        """wrapper function to support DTensor.full_tensor"""
 | 
			
		||||
        return self.redistribute(
 | 
			
		||||
            dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
 | 
			
		||||
        ).to_local()
 | 
			
		||||
 | 
			
		||||
    def _shard_order_to_placement(self, shard_order, mesh):
 | 
			
		||||
        """convert shard_order to placement with only Replicate() and Shard()"""
 | 
			
		||||
        placements = [Replicate() for _ in range(mesh.ndim)]
 | 
			
		||||
        if shard_order is not None:
 | 
			
		||||
            for entry in shard_order:
 | 
			
		||||
                tensor_dim = entry.tensor_dim
 | 
			
		||||
                mesh_dims = entry.mesh_dims
 | 
			
		||||
                for mesh_dim in mesh_dims:
 | 
			
		||||
                    placements[mesh_dim] = Shard(tensor_dim)
 | 
			
		||||
        return tuple(placements)
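For readers of the helper above: each `ShardOrderEntry` marks one tensor dim as sharded over the listed mesh dims, and any mesh dim not mentioned stays `Replicate()`. A standalone sketch of the same mapping (a restatement of the loop above for illustration, not library API):

```python
# Restates _shard_order_to_placement outside the test class, using imports shown in this diff.
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry

def shard_order_to_placement(shard_order, mesh_ndim):
    placements = [Replicate()] * mesh_ndim
    for entry in shard_order or ():
        for mesh_dim in entry.mesh_dims:
            placements[mesh_dim] = Shard(entry.tensor_dim)
    return tuple(placements)

# Tensor dim 0 sharded over mesh dims 1 then 2 on a 3-D mesh:
print(shard_order_to_placement((ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 2)),), 3))
# (Replicate(), Shard(0), Shard(0))
```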
 | 
			
		||||
 | 
			
		||||
    def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
 | 
			
		||||
        """Convert shard_order dict to ShardOrder"""
 | 
			
		||||
        return tuple(
 | 
			
		||||
            ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
 | 
			
		||||
            for tensor_dim, mesh_dims in shard_order.items()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @with_comms
 | 
			
		||||
    def test_ordered_redistribute(self):
 | 
			
		||||
        """Test ordered redistribution with various sharding syntaxes"""
 | 
			
		||||
        torch.manual_seed(21)
 | 
			
		||||
        mesh = init_device_mesh(self.device_type, (2, 2, 2))
 | 
			
		||||
        input_data = torch.randn((8, 8, 8), device=self.device_type)
 | 
			
		||||
        sharding_src_dst_pairs_with_expected_trace = [
 | 
			
		||||
            (
 | 
			
		||||
                (
 | 
			
		||||
                    [Shard(0), Shard(0), Shard(0)],
 | 
			
		||||
                    (ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 1, 2)),),
 | 
			
		||||
                ),
 | 
			
		||||
                (
 | 
			
		||||
                    [Replicate(), Shard(0), Shard(0)],
 | 
			
		||||
                    (ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 2)),),
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
            (
 | 
			
		||||
                (
 | 
			
		||||
                    [Shard(0), Shard(0), Shard(0)],
 | 
			
		||||
                    (ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0, 2)),),
 | 
			
		||||
                ),
 | 
			
		||||
                (
 | 
			
		||||
                    [Replicate(), Shard(0), Shard(0)],
 | 
			
		||||
                    (ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 2)),),
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
            (
 | 
			
		||||
                (
 | 
			
		||||
                    [Shard(0), Shard(0), Shard(0)],
 | 
			
		||||
                    (ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0, 2)),),
 | 
			
		||||
                ),
 | 
			
		||||
                (
 | 
			
		||||
                    [Shard(0), Shard(0), Replicate()],
 | 
			
		||||
                    (ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 1)),),
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
            # If we use the graph search solution, the redistribution path will
 | 
			
		||||
            # be S(0)[0, 1] -> S(0)[0]S(1)[1] -> S(1)[1] -> S(0)[2]S(1)[1],
 | 
			
		||||
            # which takes only 1 comm count. However, this placement follows the
 | 
			
		||||
            # default device order and the greedy solution will be triggered,
 | 
			
		||||
            # which results in path: S(0)[0, 1] -> S(0)[0]S(1)[1] -> S(1)[1] ->
 | 
			
		||||
            # S(0)[2]S(1)[1] with 2 comm count
 | 
			
		||||
            (
 | 
			
		||||
                (
 | 
			
		||||
                    [Shard(0), Shard(0), Replicate()],
 | 
			
		||||
                    (ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 1)),),
 | 
			
		||||
                ),
 | 
			
		||||
                (
 | 
			
		||||
                    [Replicate(), Shard(1), Shard(0)],
 | 
			
		||||
                    (
 | 
			
		||||
                        ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
 | 
			
		||||
                        ShardOrderEntry(tensor_dim=1, mesh_dims=(1,)),
 | 
			
		||||
                    ),
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
        for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
 | 
			
		||||
            sharding_src_dst_pairs_with_expected_trace
 | 
			
		||||
        ):
 | 
			
		||||
            sharded_dt = self.distribute_tensor(
 | 
			
		||||
                input_data.clone(), mesh, src_placement, shard_order=src_order
 | 
			
		||||
            )
 | 
			
		||||
            with DebugMode(record_torchfunction=False) as debug_mode:
 | 
			
		||||
                sharded_dt = self.redistribute(
 | 
			
		||||
                    sharded_dt, mesh, dst_placement, dst_order
 | 
			
		||||
                )
 | 
			
		||||
            trace_str = self._extract_redistribute_trace_from_debug_mode(
 | 
			
		||||
                debug_mode.debug_string()
 | 
			
		||||
            )
 | 
			
		||||
            if idx == 0:
 | 
			
		||||
                self.assertExpectedInline(
 | 
			
		||||
                    trace_str,
 | 
			
		||||
                    """S(0)[0]S(0)[1]S(0)[2]->S(0)[0]S(0)[1]S(1)->S(0)S(1)[1]S(1)[0]->RS(1)[1]S(1)[0]->RS(0)S(1)->RS(0)[0]S(0)[1]""",
 | 
			
		||||
                )
 | 
			
		||||
            elif idx == 1:
 | 
			
		||||
                self.assertExpectedInline(
 | 
			
		||||
                    trace_str,
 | 
			
		||||
                    """S(0)[1]S(0)[0]S(0)[2]->S(0)[1]S(0)[0]S(1)->RS(0)S(1)->RS(0)[0]S(0)[1]""",
 | 
			
		||||
                )
 | 
			
		||||
            elif idx == 2:
 | 
			
		||||
                self.assertExpectedInline(
 | 
			
		||||
                    trace_str,
 | 
			
		||||
                    """S(0)[1]S(0)[0]S(0)[2]->S(0)[1]S(0)[0]R->S(1)S(0)R->S(1)S(2)R->S(0)S(2)R->S(0)[0]S(0)[1]R""",
 | 
			
		||||
                )
 | 
			
		||||
            elif idx == 3:
 | 
			
		||||
                self.assertExpectedInline(
 | 
			
		||||
                    trace_str,
 | 
			
		||||
                    """S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
 | 
			
		||||
                )
 | 
			
		||||
            expected_dt = self.distribute_tensor(
 | 
			
		||||
                input_data.clone(), mesh, dst_placement, shard_order=dst_order
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())
 | 
			
		||||
 | 
			
		||||
    def generate_shard_orders(self, mesh, tensor_rank):
 | 
			
		||||
        # Generate all possible sharding placement of tensor with rank
 | 
			
		||||
        # `tensor_rank` over mesh.
 | 
			
		||||
        def _split_list(lst: list, N: int):
 | 
			
		||||
            def compositions(n, k):
 | 
			
		||||
                if k == 1:
 | 
			
		||||
                    yield [n]
 | 
			
		||||
                else:
 | 
			
		||||
                    for i in range(1, n - k + 2):
 | 
			
		||||
                        for tail in compositions(n - i, k - 1):
 | 
			
		||||
                            yield [i] + tail
 | 
			
		||||
 | 
			
		||||
            length = len(lst)
 | 
			
		||||
            for comp in compositions(length, N):
 | 
			
		||||
                result = []
 | 
			
		||||
                start = 0
 | 
			
		||||
                for size in comp:
 | 
			
		||||
                    result.append(lst[start : start + size])
 | 
			
		||||
                    start += size
 | 
			
		||||
                yield result
 | 
			
		||||
 | 
			
		||||
        all_mesh = list(range(mesh.ndim))
 | 
			
		||||
        all_device_order = list(itertools.permutations(all_mesh))
 | 
			
		||||
        for device_order in all_device_order:
 | 
			
		||||
            # split on device orders, and assign each device order segment to a tensor dim
 | 
			
		||||
            for num_split in range(1, mesh.ndim + 1):
 | 
			
		||||
                for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
 | 
			
		||||
                    for tensor_dims in itertools.combinations(
 | 
			
		||||
                        range(tensor_rank), len(splitted_list)
 | 
			
		||||
                    ):
 | 
			
		||||
                        shard_order = {}
 | 
			
		||||
                        assert len(tensor_dims) == len(splitted_list)
 | 
			
		||||
                        for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
 | 
			
		||||
                            shard_order[tensor_dim] = device_order[
 | 
			
		||||
                                mesh_dims[0] : mesh_dims[-1] + 1
 | 
			
		||||
                            ]
 | 
			
		||||
                        yield self._convert_shard_order_dict_to_ShardOrder(shard_order)
 | 
			
		||||
 | 
			
		||||
    @with_comms
 | 
			
		||||
    def test_generate_shard_orders(self):
 | 
			
		||||
        """Check if `generate_shard_orders` generates unique sharding combinations"""
 | 
			
		||||
        import math
 | 
			
		||||
 | 
			
		||||
        test_inputs = [
 | 
			
		||||
            {"mesh": init_device_mesh(self.device_type, (2, 2, 2)), "tensor_rank": 2},
 | 
			
		||||
            {"mesh": init_device_mesh(self.device_type, (2, 2, 2)), "tensor_rank": 3},
 | 
			
		||||
            {"mesh": init_device_mesh(self.device_type, (2, 2, 2)), "tensor_rank": 4},
 | 
			
		||||
        ]
 | 
			
		||||
        for test_input in test_inputs:
 | 
			
		||||
            all_combinations = []
 | 
			
		||||
            for shard_order in self.generate_shard_orders(
 | 
			
		||||
                test_input["mesh"], test_input["tensor_rank"]
 | 
			
		||||
            ):
 | 
			
		||||
                all_combinations.append(shard_order)  # noqa: PERF402
 | 
			
		||||
            for i in range(len(all_combinations)):
 | 
			
		||||
                for j in range(i + 1, len(all_combinations)):
 | 
			
		||||
                    assert all_combinations[i] != all_combinations[j], (
 | 
			
		||||
                        f"Duplicate elements found in all_combinations {all_combinations[i]}, {all_combinations[j]}"
 | 
			
		||||
                    )
 | 
			
		||||
            expected_total_combination = 0
 | 
			
		||||
            N = test_input["mesh"].ndim
 | 
			
		||||
            M = test_input["tensor_rank"]
 | 
			
		||||
            for i in range(1, N + 1):
 | 
			
		||||
                # assign total i split of device to tensor dims
 | 
			
		||||
                if M < i:
 | 
			
		||||
                    continue
 | 
			
		||||
                device_combination_count = math.comb(
 | 
			
		||||
                    N - 1, i - 1
 | 
			
		||||
                )  # choose i-1 non-empty segments from a list of size N
 | 
			
		||||
                tensor_dim_order_permutation = math.comb(M, i)  # choose i tensor dims
 | 
			
		||||
                expected_total_combination += (
 | 
			
		||||
                    device_combination_count * tensor_dim_order_permutation
 | 
			
		||||
                )
 | 
			
		||||
            # multiply by total possible permutation of device order
 | 
			
		||||
            expected_total_combination *= math.factorial(N)
 | 
			
		||||
            self.assertEqual(len(all_combinations), expected_total_combination)
 | 
			
		||||
 | 
			
		||||
    @with_comms
 | 
			
		||||
    def test_ordered_distribute_all_combination(self):
 | 
			
		||||
        """Exhaustively test all possible sharding combinations and verify correctness"""
 | 
			
		||||
        torch.manual_seed(21)
 | 
			
		||||
        mesh = init_device_mesh(self.device_type, (2, 2, 2))
 | 
			
		||||
        input_tensor_shape = [
 | 
			
		||||
            # even sharding
 | 
			
		||||
            (16, 8),
 | 
			
		||||
            (8, 16, 32),
 | 
			
		||||
            (8, 32, 16, 16),
 | 
			
		||||
            # uneven sharding with padding
 | 
			
		||||
            (17, 5),
 | 
			
		||||
            (13, 2, 13),
 | 
			
		||||
            (33, 16, 8, 1),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        # 1. Verify correctness of distribute_tensor from Tensor to DTensor.
 | 
			
		||||
        for tensor_shape in input_tensor_shape:
 | 
			
		||||
            input_data = torch.randn(tensor_shape, device=self.device_type)
 | 
			
		||||
            tensor_rank = input_data.ndim
 | 
			
		||||
            for shard_order in self.generate_shard_orders(mesh, tensor_rank):
 | 
			
		||||
                sharded_dt = self.distribute_tensor(
 | 
			
		||||
                    input_data.clone(), mesh, placements=None, shard_order=shard_order
 | 
			
		||||
                )
 | 
			
		||||
                self.assertEqual(self.full_tensor(sharded_dt), input_data)
 | 
			
		||||
 | 
			
		||||
        # 2. Verify the correctness of redistribution from DTensor to DTensor.
 | 
			
		||||
        # This test repeatedly redistributes a DTensor to various ordered
 | 
			
		||||
        # placements and checks that the resulting tensor matches the original
 | 
			
		||||
        # full tensor.
 | 
			
		||||
        for tensor_shape in input_tensor_shape:
 | 
			
		||||
            input_data = torch.randn(tensor_shape, device=self.device_type)
 | 
			
		||||
            tensor_rank = input_data.ndim
 | 
			
		||||
            prev_sharded_dt = None
 | 
			
		||||
            for shard_order in self.generate_shard_orders(mesh, tensor_rank):
 | 
			
		||||
                if prev_sharded_dt is None:
 | 
			
		||||
                    prev_sharded_dt = self.distribute_tensor(
 | 
			
		||||
                        input_data.clone(),
 | 
			
		||||
                        mesh,
 | 
			
		||||
                        placements=None,
 | 
			
		||||
                        shard_order=shard_order,
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    sharded_dt = self.redistribute(
 | 
			
		||||
                        prev_sharded_dt, mesh, placements=None, shard_order=shard_order
 | 
			
		||||
                    )
 | 
			
		||||
                    self.assertEqual(self.full_tensor(sharded_dt), input_data)
 | 
			
		||||
                    prev_sharded_dt = sharded_dt
 | 
			
		||||
 | 
			
		||||
    @with_comms
 | 
			
		||||
    def test_ordered_redistribute_with_partial(self):
 | 
			
		||||
        """Test mixing Partial in the original placements and do redistribute."""
 | 
			
		||||
        # This test takes 226s to complete on 8XA100...
 | 
			
		||||
        torch.manual_seed(21)
 | 
			
		||||
        mesh = init_device_mesh(self.device_type, (2, 2, 2))
 | 
			
		||||
        input_tensor_shape = [
 | 
			
		||||
            # even sharding
 | 
			
		||||
            (16, 8),
 | 
			
		||||
            (8, 16, 32),
 | 
			
		||||
            # uneven sharding with padding
 | 
			
		||||
            (17, 5),
 | 
			
		||||
            (13, 2, 13),
 | 
			
		||||
            (33, 16, 8, 1),
 | 
			
		||||
        ]
 | 
			
		||||
        placement_choice = [
 | 
			
		||||
            Shard(0),
 | 
			
		||||
            Shard(1),
 | 
			
		||||
            Shard(2),
 | 
			
		||||
            Partial("sum"),
 | 
			
		||||
            Partial("min"),
 | 
			
		||||
            Replicate(),
 | 
			
		||||
        ]
 | 
			
		||||
        # pick 3 for the 3D mesh
 | 
			
		||||
        partial_placement_comb = list(itertools.combinations(placement_choice, 3))
 | 
			
		||||
 | 
			
		||||
        def _is_valid_placement(placements, tensor_rank):
 | 
			
		||||
            # Check if placements is valid for tensor with rank `tensor_rank`
 | 
			
		||||
            for placement in placements:
 | 
			
		||||
                if isinstance(placement, Shard):
 | 
			
		||||
                    if placement.dim >= tensor_rank:
 | 
			
		||||
                        return False
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        for shape in input_tensor_shape:
 | 
			
		||||
            for placements in partial_placement_comb:
 | 
			
		||||
                if not _is_valid_placement(placements, len(shape)):
 | 
			
		||||
                    continue
 | 
			
		||||
                local_tensor = torch.randn(shape, device=self.device_type)
 | 
			
		||||
                full_tensor = DTensor.from_local(local_tensor, mesh, placements)
 | 
			
		||||
                for shard_order in self.generate_shard_orders(mesh, len(shape)):
 | 
			
		||||
                    sharded_dt = self.redistribute(
 | 
			
		||||
                        full_tensor, mesh, placements=None, shard_order=shard_order
 | 
			
		||||
                    )
 | 
			
		||||
                    self.assertEqual(
 | 
			
		||||
                        self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
    @unittest.skip(
 | 
			
		||||
        "Temporarily skipping until we support special placement types in "
 | 
			
		||||
        "graph based redistribution"
 | 
			
		||||
    )
 | 
			
		||||
    @with_comms
 | 
			
		||||
    def test_ordered_redistribute_for_special_placement(self):
 | 
			
		||||
        """Test ordered redistribution with special placement"""
 | 
			
		||||
        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
 | 
			
		||||
 | 
			
		||||
        torch.manual_seed(21)
 | 
			
		||||
        mesh = init_device_mesh(self.device_type, (8,))
 | 
			
		||||
        input_data = torch.randn((8, 8), device=self.device_type)
 | 
			
		||||
        src_placement = [Shard(1)]
 | 
			
		||||
        tgt_placement = [
 | 
			
		||||
            (_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
 | 
			
		||||
        ]
 | 
			
		||||
        sharded_dt = self.distribute_tensor(
 | 
			
		||||
            input_data.clone(),
 | 
			
		||||
            mesh,
 | 
			
		||||
            src_placement,
 | 
			
		||||
            shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
 | 
			
		||||
        )
 | 
			
		||||
        sharded_dt = self.redistribute(
 | 
			
		||||
            sharded_dt, mesh, tgt_placement, shard_order=None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @with_comms
 | 
			
		||||
    def test_shard_order_same_data_as_strided_shard(self):
 | 
			
		||||
        device_mesh = init_device_mesh(self.device_type, (4, 2))
 | 
			
		||||
        x = torch.randn(8, 4, device=self.device_type)
 | 
			
		||||
        # specify right-to-left order use _StridedShard
 | 
			
		||||
        strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
 | 
			
		||||
        x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
 | 
			
		||||
        # specify right-to-left order use ordered shard
 | 
			
		||||
        x_ordered_dt = self.distribute_tensor(
 | 
			
		||||
            x,
 | 
			
		||||
            device_mesh,
 | 
			
		||||
            placements=[Shard(0), Shard(0)],
 | 
			
		||||
            shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    run_tests()
 | 
			
		||||
 | 
			
		||||
@ -64,11 +64,14 @@ def get_patches():
 | 
			
		||||
    return {
 | 
			
		||||
        "test_configs.estimate_aten_runtime": estimate_aten_runtime,
 | 
			
		||||
        "reorder_for_locality": False,
 | 
			
		||||
        "triton.native_matmul": False,
 | 
			
		||||
        "reorder_for_compute_comm_overlap_passes": [],
 | 
			
		||||
        "compile_threads": 1,
 | 
			
		||||
        "force_disable_caches": True,
 | 
			
		||||
        # Messes up existing test strings
 | 
			
		||||
        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
 | 
			
		||||
        # interferes with testing, / custom estimation
 | 
			
		||||
        "test_configs.assume_bucketing_reduces_latency": False,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -357,11 +360,14 @@ def get_bucket_patches(compute_multiplier=1.0):
 | 
			
		||||
        "test_configs.estimate_aten_runtime": estimate_aten_runtime_part,
 | 
			
		||||
        "test_configs.aten_fx_overlap_preserving_bucketing": True,
 | 
			
		||||
        "reorder_for_locality": False,
 | 
			
		||||
        "triton.native_matmul": False,
 | 
			
		||||
        "reorder_for_compute_comm_overlap_passes": [],
 | 
			
		||||
        "compile_threads": 1,
 | 
			
		||||
        "force_disable_caches": True,
 | 
			
		||||
        # messes up test strings
 | 
			
		||||
        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
 | 
			
		||||
        # interferes with testing, / custom estimation
 | 
			
		||||
        "test_configs.assume_bucketing_reduces_latency": False,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -577,7 +583,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
 | 
			
		||||
    @torch._inductor.config.patch(get_bucket_patches(2.0))
 | 
			
		||||
    def test_bucketing_split_for_overlap_blocking(self):
 | 
			
		||||
    def test_bucketing_split_for_overlap_blocking_no_deps(self):
 | 
			
		||||
        """Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
 | 
			
		||||
 | 
			
		||||
        def func(a, b, c, d, *, ranks):
 | 
			
		||||
 | 
			
		||||
@ -938,6 +938,9 @@ class CompileTest(TestCase):
 | 
			
		||||
        assert "torch.ops._c10d_functional.wait_tensor.default" in code
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
 | 
			
		||||
    @unittest.skipIf(
 | 
			
		||||
        torch._inductor.config.triton.native_matmul, "no extern_kernels.mm"
 | 
			
		||||
    )
 | 
			
		||||
    @fresh_cache()
 | 
			
		||||
    def test_inductor_reuse_buffer_after_inplace_collective(self):
 | 
			
		||||
        def func(arg: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
@ -78,6 +78,10 @@ def create_grouped_node_for_allreduce_and_its_deps(snodes):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@requires_accelerator_dist_backend()
 | 
			
		||||
@unittest.skipIf(
 | 
			
		||||
    torch._inductor.config.triton.native_matmul,
 | 
			
		||||
    "native matmul is fused with surrounding ops",
 | 
			
		||||
)
 | 
			
		||||
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
 | 
			
		||||
    """
 | 
			
		||||
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
 | 
			
		||||
@ -367,6 +371,10 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
 | 
			
		||||
            self.assertTrue(same(out, correct))
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
 | 
			
		||||
    @unittest.skipIf(
 | 
			
		||||
        torch._inductor.config.triton.native_matmul,
 | 
			
		||||
        "native matmul is fused with surrounding ops",
 | 
			
		||||
    )
 | 
			
		||||
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
 | 
			
		||||
    @patch.object(torch._inductor.config, "compile_threads", 1)
 | 
			
		||||
    @patch.object(
 | 
			
		||||
 | 
			
		||||
@@ -7,8 +7,13 @@ from dataclasses import dataclass

import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.testing._internal.common_cuda import SM100OrLater
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests
from torch.testing._internal.common_utils import (
    requires_cuda_p2p_access,
    run_tests,
    skip_but_pass_in_sandcastle_if,
)


# So that tests are written in device-agnostic way

@@ -59,6 +64,10 @@ class CupyAsTensorTest(MultiProcContinuousTest):
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skip_but_pass_in_sandcastle_if(
        SM100OrLater,
        "Fails if ran in docker environment without privileged access (https://github.com/pytorch/pytorch/issues/165170)",
    )
    def test_cupy_as_tensor(self) -> None:
        """
        Test that torch.as_tensor works for cupy array interface

@@ -2,6 +2,7 @@
# Owner(s): ["oncall: distributed"]
import os
import unittest
from datetime import timedelta

import torch
import torch.distributed as dist

@@ -40,6 +41,13 @@ from torch.utils._typing_utils import not_none
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()

try:
    import torch._C._distributed_c10d.ProcessGroupNCCL

    _NCCL_AVAILABLE = True
except ImportError:
    _NCCL_AVAILABLE = False


def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
    os.environ["MASTER_ADDR"] = addr

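The try/except probe added above sets `_NCCL_AVAILABLE` so the backend-override portion of the unflatten test below can bail out cleanly when NCCL is not built in. As a side note, a simpler probe exists; a minimal sketch with the same gating intent, not what this diff does:

import torch.distributed as dist

# Alternative availability check; unlike the try/except above it does not
# touch the private torch._C._distributed_c10d module directly.
_NCCL_AVAILABLE = dist.is_nccl_available()
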
@@ -962,6 +970,85 @@ class TestDeviceMeshGetItem(DTensorTestBase):
        # check flattened mesh dependency
        self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_4d)

    @with_comms
    def test_unflatten_mesh_2d(self):
        mesh_shape = (4, 2)
        mesh_dim_names = ("dp", "tp")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        unflatten_mesh = mesh_2d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
        self.assertEqual(
            unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"]
        )
        self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh)
        self.assertEqual(mesh_2d["tp"].get_group(), unflatten_mesh["tp"].get_group())

        # Not supporting slicing out unflatten dim name from root mesh.
        with self.assertRaises(KeyError):
            self.assertEqual(mesh_2d["dp_shard"].mesh, unflatten_mesh["dp_shard"].mesh)

    @with_comms
    def test_unflatten_mesh_3d(self):
        # Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism(EP).
        global_mesh = init_device_mesh(
            self.device_type,
            (8,),
            mesh_dim_names=("world",),
        )
        non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
        ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
        self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
        self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
        mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp"))
        unflatten_mesh = mesh_3d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
        self.assertEqual(
            unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "cp", "tp"]
        )
        self.assertEqual(mesh_3d["tp"].mesh, unflatten_mesh["tp"].mesh)
        self.assertEqual(mesh_3d["tp"].get_group(), unflatten_mesh["tp"].get_group())
        self.assertEqual(mesh_3d["cp"].mesh, unflatten_mesh["cp"].mesh)
        self.assertEqual(mesh_3d["cp"].get_group(), unflatten_mesh["cp"].get_group())

        # Test unflatten with backend override set.
        if not _NCCL_AVAILABLE:
            return
        opts = dist.ProcessGroupNCCL.Options()
        opts._timeout = timedelta(seconds=30)
        mesh_2d = global_mesh._unflatten(
            0,
            (1, 8),
            ("pp", "spmd"),
            backend_override={"pp": "fake", "spmd": ("nccl", opts)},
        )
        opts = dist.ProcessGroupNCCL.Options()
        opts._timeout = timedelta(seconds=60)
        mesh_4d = mesh_2d._unflatten(
            1,
            (2, 2, 2),
            ("dp", "cp", "tp"),
            backend_override={"dp": "nccl", "cp": "nccl", "tp": ("nccl", opts)},
        )
        self.assertEqual(mesh_4d["pp"].get_group()._get_backend_name(), "custom")
        spmd_pg = mesh_2d["spmd"].get_group()
        self.assertEqual(spmd_pg._get_backend_name(), "nccl")
        w = spmd_pg.allreduce(torch.rand(10).cuda(self.rank))
        self.assertTrue(
            spmd_pg._get_backend(
                torch.device(f"cuda:{self.rank}")
            )._verify_work_timeout(w, timedelta(seconds=30))
        )
        w.wait()
        tp_pg = mesh_4d["tp"].get_group()
        self.assertEqual(tp_pg._get_backend_name(), "nccl")
        w = tp_pg.allreduce(torch.rand(10).cuda(self.rank))
        self.assertTrue(
            tp_pg._get_backend(torch.device(f"cuda:{self.rank}"))._verify_work_timeout(
                w, timedelta(seconds=60)
            )
        )
        w.wait()

    @with_comms
    def test_reconstruct_mesh_with_flatten_dim(self):
        mesh_3d = init_device_mesh(