mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-31 12:15:03 +08:00

Compare commits: 7 commits
main-enabl ... gh/fffrog/
| Author | SHA1 | Date |
|---|---|---|
| | 74f91b46a0 | |
| | 224fe1ffe6 | |
| | b2660ebfd0 | |
| | b17827f74c | |
| | f441900db8 | |
| | e67f5582eb | |
| | 29ca2cfccf | |
@@ -187,22 +187,19 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
            export USE_CUFILE=0
        else
            DEPS_LIST+=(
                "/usr/local/cuda/lib64/libnvToolsExt.so.1"
                "/usr/local/cuda/lib64/libcublas.so.12"
                "/usr/local/cuda/lib64/libcublasLt.so.12"
                "/usr/local/cuda/lib64/libcudart.so.12"
                "/usr/local/cuda/lib64/libnvrtc.so.12"
                "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
            DEPS_SONAME+=(
                "libnvToolsExt.so.1"
                "libcublas.so.12"
                "libcublasLt.so.12"
                "libcudart.so.12"
                "libnvrtc.so.12"
                "libcupti.so.12")

            if [[ $CUDA_VERSION != 12.9* ]]; then
                DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
                DEPS_SONAME+=("libnvToolsExt.so.1")
            fi
        fi
    else
        echo "Using nvidia libs from pypi."

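Note on the version guard above: `[[ $CUDA_VERSION != 12.9* ]]` is a bash glob match, not a numeric comparison. Any version string beginning with `12.9` (12.9, 12.9.1, ...) skips bundling `libnvToolsExt.so.1`, while every other 12.x/13.x version still appends it to `DEPS_LIST`/`DEPS_SONAME`.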
.github/ISSUE_TEMPLATE/ci-sev.md (vendored, 1 line changed)
@@ -8,7 +8,6 @@ assignees: ''
---

> NOTE: Remember to label this issue with "`ci: sev`"
>       If you want autorevert to be disabled, keep the ci: disable-autorevert label

 <!-- Add the `merge blocking` label to this PR to prevent PRs from being merged while this issue is open -->

.github/ISSUE_TEMPLATE/disable-autorevert.md (vendored, 4 lines changed)
@@ -1,7 +1,7 @@
---
name: "D❌\U0001F519 ISABLE AUTOREVERT"
name: DISABLE AUTOREVERT
about: Disables autorevert when open
title: "[DISABLE AUTOREVERT]"
title: "❌\U0001F519 [DISABLE AUTOREVERT]"
labels: 'ci: disable-autorevert'
assignees: ''


@@ -65,7 +65,7 @@ runs:
          cd .ci/lumen_cli
          python3 -m pip install -e .
        )
        MAX_JOBS="$(nproc --ignore=10)"
        MAX_JOBS="$(nproc --ignore=6)"
        export MAX_JOBS

        # Split the comma-separated list and build each target

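For context on the `MAX_JOBS` change: `nproc --ignore=N` (GNU coreutils) prints the number of available processing units minus N, and never prints less than 1. Dropping `--ignore` from 10 to 6 therefore gives the build four more parallel jobs on the same runner.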
.github/ci_commit_pins/audio.txt (vendored, 2 lines changed)
@@ -1 +1 @@
1b013f5b5a87a1882eb143c26d79d091150d6a37
8ad2aa5d354d1bf432339113860185d5a5d1abbd

.github/ci_commit_pins/vision.txt (vendored, 2 lines changed)
@@ -1 +1 @@
faffd5cf673615583da6517275e361cb3dbc77e6
f5c6c2ec6490455e86f67b2a25c10390d60a27f7

.github/pytorch-probot.yml (vendored, 4 lines changed)
@@ -3,7 +3,6 @@ ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
@@ -16,8 +15,7 @@ ciflow_push_tags:
- ciflow/inductor-micro-benchmark
- ciflow/inductor-micro-benchmark-cpu-x86
- ciflow/inductor-perf-compare
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-rocm
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-periodic
- ciflow/inductor-rocm

.github/workflows/b200-distributed.yml (vendored, 62 lines changed)
@@ -1,62 +0,0 @@
name: CI for distributed tests on B200

on:
  pull_request:
    paths:
      - .github/workflows/b200-distributed.yml
  workflow_dispatch:
  push:
    tags:
      - ciflow/b200-distributed/*
  schedule:
    - cron: 46 8 * * *  # about 1:46am PDT

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:

  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
    name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '10.0'
      test-matrix: |
        { include: [
          { config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
          { config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
        ]}
    secrets: inherit

  linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
    name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
    uses: ./.github/workflows/_linux-test.yml
    needs:
      - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
    with:
      timeout-minutes: 1200
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
    secrets: inherit
.github/workflows/build-vllm-wheel.yml (vendored, 19 lines changed)
@@ -27,8 +27,9 @@ jobs:
      fail-fast: false
      matrix:
        python-version: [ '3.12' ]
        # TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
        platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
        device: [ 'cu128', 'cu129', 'cu130' ]
        device: [ 'cu128', 'cu129' ]
        include:
          - platform: manylinux_2_28_x86_64
            device: cu128
@@ -38,10 +39,6 @@ jobs:
            device: cu129
            manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
            runner: linux.12xlarge.memory
          - platform: manylinux_2_28_x86_64
            device: cu130
            manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
            runner: linux.12xlarge.memory
          - platform: manylinux_2_28_aarch64
            device: cu128
            manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
@@ -50,11 +47,6 @@ jobs:
            device: cu129
            manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
            runner: linux.arm64.r7g.12xlarge.memory
        exclude:
          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
          # xformers is update to support 13.0
          - platform: manylinux_2_28_aarch64
            device: cu130
    name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
    runs-on: ${{ matrix.runner }}
    timeout-minutes: 480
@@ -177,12 +169,7 @@ jobs:
      fail-fast: false
      matrix:
        platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
        device: [ 'cu128', 'cu129', 'cu130' ]
        exclude:
          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
          # xformers is update to support 13.0
          - platform: manylinux_2_28_aarch64
            device: cu130
        device: [ 'cu128', 'cu129' ]
    env:
      PLATFORM: ${{ matrix.platform }}
      BUILD_DEVICE: ${{ matrix.device }}

@@ -1,132 +0,0 @@
name: inductor-perf-nightly-rocm-mi300

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm-mi300/*
  schedule:
    - cron: 15 0 * * *
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
    inputs:
      training:
        description: Run training (on by default)?
        required: false
        type: boolean
        default: true
      inference:
        description: Run inference (on by default)?
        required: false
        type: boolean
        default: true
      default:
        description: Run inductor_default?
        required: false
        type: boolean
        default: false
      dynamic:
        description: Run inductor_dynamic_shapes?
        required: false
        type: boolean
        default: false
      cppwrapper:
        description: Run inductor_cpp_wrapper?
        required: false
        type: boolean
        default: false
      cudagraphs:
        description: Run inductor_cudagraphs?
        required: false
        type: boolean
        default: true
      freezing_cudagraphs:
        description: Run inductor_cudagraphs with freezing for inference?
        required: false
        type: boolean
        default: false
      aotinductor:
        description: Run aot_inductor for inference?
        required: false
        type: boolean
        default: false
      maxautotune:
        description: Run inductor_max_autotune?
        required: false
        type: boolean
        default: false
      benchmark_configs:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm_mi300,inductor_timm_perf_rocm_mi300,inductor_torchbench_perf_rocm_mi300

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
      opt_out_experiments: lf

  linux-jammy-rocm-py3_10-inductor-benchmark-build:
    if: github.repository_owner == 'pytorch'
    name: rocm-py3_10-inductor-benchmark-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-rocm-py3_10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-inductor-benchmark-test:
    permissions:
      id-token: write
      contents: read
    name: rocm-py3_10-inductor-benchmark-test
    uses: ./.github/workflows/_rocm-test.yml
    needs: linux-jammy-rocm-py3_10-inductor-benchmark-build
    with:
      build-environment: linux-jammy-rocm-py3_10
      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }}
      timeout-minutes: 720
      # Disable monitor in perf tests for more investigation
      disable-monitor: true
      monitor-log-interval: 10
      monitor-data-collect-interval: 2
    secrets: inherit
@@ -1,11 +1,11 @@
name: inductor-perf-nightly-rocm-mi355
name: inductor-perf-nightly-rocm

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm-mi355/*
      - ciflow/inductor-perf-test-nightly-rocm/*
  schedule:
    - cron: 15 0 * * *
    - cron: 0 7 * * 0,3
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
@@ -59,7 +59,7 @@ on:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm_mi355,inductor_timm_perf_rocm_mi355,inductor_torchbench_perf_rocm_mi355
        default: inductor_huggingface_perf_rocm,inductor_timm_perf_rocm,inductor_torchbench_perf_rocm

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
@@ -88,27 +88,23 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

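On the schedule change above: `15 0 * * *` fires daily at 00:15 UTC, while the new `0 7 * * 0,3` fires at 07:00 UTC only on Sundays (day-of-week 0) and Wednesdays (3), so the job moves from nightly to twice weekly.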
.github/workflows/operator_benchmark.yml (vendored, 23 lines changed)
@@ -7,11 +7,9 @@ on:
  workflow_dispatch:
    inputs:
      test_mode:
        type: choice
        options:
          - 'short'
          - 'long'
          - 'all'
        required: false
        type: string
        default: 'short'
        description: tag filter for operator benchmarks, options from long, short, all
  schedule:
    # Run at 07:00 UTC every Sunday
@@ -39,7 +37,20 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  opbenchmark-on-demand-build:
    if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
    name: opbenchmark-on-demand-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit


@@ -256,7 +256,6 @@ endif()
IF(USE_FBGEMM_GENAI)
  set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
  set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)

  if(USE_CUDA)
    # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
    # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@@ -293,64 +292,58 @@ IF(USE_FBGEMM_GENAI)
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
    )

    target_include_directories(fbgemm_genai PRIVATE
    target_include_directories(fbgemm_genai PUBLIC
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${fbgemm_genai_mx8mx8bf16_grouped}
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )
  else()
    if(USE_ROCM)
      # Only include the kernels we want to build to avoid increasing binary size.
      file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
      set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  elseif(USE_ROCM)
    # Only include the kernels we want to build to avoid increasing binary size.
    file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
    set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
      # Add additional HIPCC compiler flags for performance
      set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
        -mllvm
        -amdgpu-coerce-illegal-types=1
        -mllvm
        -enable-post-misched=0
        -mllvm
        -greedy-reverse-local-assignment=1
        -fhip-new-launch-api)

    # Add additional HIPCC compiler flags for performance
    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
      -mllvm
      -amdgpu-coerce-illegal-types=1
      -mllvm
      -enable-post-misched=0
      -mllvm
      -greedy-reverse-local-assignment=1
      -fhip-new-launch-api)
      # Only compile for gfx942 for now.
      # This is rather hacky, I could not figure out a clean solution :(
      set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
      string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
      if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
        list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
      endif()
      set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    # Only compile for gfx942 for now.
    # This is rather hacky, I could not figure out a clean solution :(
    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
    if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
      list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
      hip_add_library(
        fbgemm_genai STATIC
        ${fbgemm_genai_native_rocm_hip}
        HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
      set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
      set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
      target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

      target_include_directories(fbgemm_genai PUBLIC
        # FBGEMM version of Composable Kernel is used due to some customizations
        ${FBGEMM_THIRD_PARTY}/composable_kernel/include
        ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
        ${FBGEMM_THIRD_PARTY}/cutlass/include
        ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
        ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
        ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
      )
    endif()
    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    hip_add_library(
      fbgemm_genai STATIC
      ${fbgemm_genai_native_rocm_hip}
      HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

    target_include_directories(fbgemm_genai PRIVATE
      # FBGEMM version of Composable Kernel is used due to some customizations
      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )

    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()
endif()

@@ -699,6 +692,12 @@ if(USE_CUDA AND NOT USE_ROCM)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)

  # Add FBGEMM_GENAI include directories for torch_ops.h
  if(USE_FBGEMM_GENAI)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()

  if($ENV{ATEN_STATIC_CUDA})
    if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
      list(APPEND ATen_CUDA_DEPENDENCY_LIBS

@@ -389,16 +389,37 @@ void fillVersion<DLManagedTensorVersioned>(
// constructed out of ATen tensor
template <class T>
T* toDLPackImpl(const Tensor& src) {
  auto view = src;

  // Detect whether there is need to normalize the strides
  // Background: gh-83069
  //
  // However, normalizing strides can come at a high-cost
  // to slow down toDLPack conversion 3x, so we
  // only normalize if needed.
  //
  // The following code detects whether the src follows
  // a continuous pattern. If the src follows such pattern (common-case)
  // then we do not need to normalize the strides.
  bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
  // less common case, try normalizing the strides
  if (need_normalize_strides) {
    // create a new tensor with possibly normalized strides
    // gh-83069
    auto shape = src.sizes();
    view = src.as_strided(shape, {1}, src.storage_offset());
  }

  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
  atDLMTensor->handle = src;
  atDLMTensor->handle = view;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter<T>;
  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
  atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  fillVersion(&atDLMTensor->tensor);

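A minimal Python sketch of the case this branch normalizes (gh-83069): a one-element, 1-D tensor whose stride is not 1. The tensor values below are illustrative only, and the normalized output stride assumes the patched exporter.

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

base = torch.arange(10.0)
v = base.as_strided((1,), (5,))   # size (1,), stride (5,): legal but unusual
print(v.stride())                 # (5,)

# For a single-element tensor the stride is semantically irrelevant, so the
# exporter can rewrite it to (1,) via as_strided instead of paying the general
# stride-normalization cost (reported as ~3x slower) on every conversion.
t = from_dlpack(to_dlpack(v))
print(t.item(), t.stride())       # 0.0 (1,) with the patch applied
```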
@@ -624,14 +624,7 @@ struct TORCH_API IValue final {
  IValue(const c10::SymBool& i) {
    if (auto mi = i.maybe_as_bool()) {
      tag = Tag::Bool;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
      payload.u.as_int = *mi;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
      /* due to byteorder if value assigned as_int, as_bool actually is not set correctly */
      payload.u.as_bool = *mi;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
    } else {
      tag = Tag::SymBool;
      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();

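The removed `#if` guarded a byte-order aliasing pitfall: `as_bool` occupies the first byte of the payload union, which is the low byte of `as_int` on little-endian targets but the high byte on big-endian ones. A small sketch of the two byte layouts (Python here, purely illustrative):

```python
import struct

le = struct.pack("<q", 1)  # little-endian int64: b'\x01\x00\x00\x00\x00\x00\x00\x00'
be = struct.pack(">q", 1)  # big-endian int64:    b'\x00\x00\x00\x00\x00\x00\x00\x01'

# A bool member aliasing byte 0 of the union would observe:
print(le[0])  # 1 -> writing as_int also sets as_bool on little-endian
print(be[0])  # 0 -> on big-endian, as_bool must be assigned directly
```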
@@ -13,7 +13,6 @@
#include <c10/core/ScalarType.h>

#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@@ -151,7 +150,6 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
      BLASType = "unknown";
  }
  return BLASType;

}

// Similar to Compute Type in GemmRocblas.h
@@ -246,25 +244,33 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio

namespace detail {

static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {

  if (!config.enabled) {
    return true; // skip when disabled
  }

static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
  // comparison done as 1D tensor
  at::Tensor ref = at::from_blob(c,       {size}, options);
  at::Tensor oth = at::from_blob(other_c, {size}, options);
  at::Tensor ref_float = ref.to(at::kFloat);
  at::Tensor oth_float = oth.to(at::kFloat);

  const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
  if (ok) {
    TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
  } else {
    TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  double last_succeed_atol = 1;
  double last_succeed_rtol = 1;
  for (auto& atol : atols) {
    for (auto& rtol : rtols) {
      if (at::allclose(ref_float, oth_float, rtol, atol)) {
        last_succeed_atol = atol;
        last_succeed_rtol = rtol;
      }
    }
  }
  return ok;
  if (last_succeed_atol == 1) {
    return false;
  }
  else {
    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
  }

  return true;
}

}
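For reference, `at::allclose(ref, oth, rtol, atol)` applies the usual elementwise criterion, so the nested loop above ends up reporting the tightest tolerance pair that still passes:

```latex
\lvert \mathrm{ref}_i - \mathrm{oth}_i \rvert \le \mathrm{atol} + \mathrm{rtol}\,\lvert \mathrm{oth}_i \rvert \quad \forall i
```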
@@ -349,10 +355,8 @@ struct GemmParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
  }

  char transa{};
@@ -445,10 +449,8 @@ struct GemmAndBiasParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
  }

  char transa{};
@@ -544,10 +546,8 @@ struct GemmStridedBatchedParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
  }

  char transa{};
@@ -663,9 +663,7 @@ struct ScaledGemmParams : OpParams {
  }

  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
  }

  char transa{};

@@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@@ -173,9 +173,10 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| write_file_on_exit(val: bool) -> None | Default is True. |
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None: -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |

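A hedged sketch of driving these switches from Python, using the `torch.cuda.tunable` module documented above (the filename and matrix shapes are illustrative, and a CUDA/ROCm device is assumed):

```python
import os
# Environment variables must be set before the first tuned op runs,
# since the settings become fixed (see the note above the table).
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"  # 0/1 form from this table

import torch
torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_filename("my_tunableop_results.csv")

a = torch.randn(512, 512, device="cuda")
b = torch.randn(512, 512, device="cuda")
c = a @ b   # first GEMM of this shape is tuned (and numerically checked)
```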
@@ -107,30 +107,14 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
}

void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
  bool is_new = false;
  ResultEntry inserted = ResultEntry::Null();
  std::scoped_lock l{lock_};

  // ---- mutate maps under results lock ----
  {
    std::scoped_lock l{lock_};
    auto& km = results_[op_signature];  // creates if missing
    is_new = (km.find(params_signature) == km.end());
    AddImpl(op_signature, params_signature, std::move(best), km);
    if (is_new) {
      inserted = km.at(params_signature);  // snapshot for I/O after unlocking
    }
  }
   if (!is_new) return;  // only write once per unique (op, params)

   TuningContext* ctx = getTuningContext();
  if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
    InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());

    if (is_new && realtime_out_ && realtime_out_->good()) {
      AppendResultLine(op_signature, params_signature, inserted);
    }
  auto it = results_.find(op_signature);
  if (it == results_.end()) {
    it = results_.insert({op_signature, {}}).first;
  }

  AddImpl(op_signature, params_signature, std::move(best), it->second);
}

void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
@@ -166,77 +150,6 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
  }
}

void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
  std::scoped_lock fl{realtime_file_mutex_};

  if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
    return;
  }

  if (realtime_out_ && realtime_filename_ != filename) {
    realtime_out_->flush();
    realtime_out_->close();
    realtime_out_.reset();
    validators_written_ = false;
  }

  bool file_exists = false;
  bool file_empty = true;

  {
    std::ifstream check_file(filename);
    if (check_file.good()) {
      file_exists = true;
      file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
    }
  }

  realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);

  if (!realtime_out_->good()) {
    TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
    realtime_out_.reset();
    return;
  }

  if(!file_exists || file_empty) {
    for(const auto& [key, val] : validators) {
      (*realtime_out_) << "Validator," << key << "," << val << std::endl;
      realtime_out_->flush();
    }
    validators_written_ = true;

    TUNABLE_LOG2("Wrote validators to realtime output file");
  }

  realtime_filename_ = filename;
}

void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
  std::scoped_lock fl{realtime_file_mutex_};

  if(!realtime_out_ || !realtime_out_->good()) {
    return;
  }

  (*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
  realtime_out_->flush(); //ensure immediate write to disk

  TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
}

void TuningResultsManager::CloseRealtimeAppend() {
  std::scoped_lock fl{realtime_file_mutex_};


  if(realtime_out_) {
    realtime_out_->flush();
    realtime_out_->close();
    realtime_out_.reset();
    TUNABLE_LOG2("Closed realtime output file");
  }
}

void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
  std::scoped_lock l{lock_};

@@ -483,6 +396,7 @@ TuningContext::TuningContext() :
    tuning_enable_{true},
    record_untuned_enable_{false},
    manager_initialized_{false},
    write_file_on_exit_{true},
    numerics_check_enable_{false},
    max_tuning_duration_ms_{30},
    max_tuning_iterations_{100},
@@ -503,8 +417,20 @@ TuningContext::~TuningContext() {
    // but doesn't do any computation itself.
    return;
  }
  TUNABLE_LOG1("Closing File");
  GetTuningResultsManager().CloseRealtimeAppend(); // Since, we do instant logging by default now.
  auto filename = GetFilename();
  if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
    if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
      if (results_count_from_input_file_ > 0) {
        TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
      }
      else {
        TUNABLE_LOG1("writing file ", filename);
      }
      if (!WriteFile(filename)) {
        TUNABLE_LOG1("failed to write file ", filename);
      }
    }
  }

  if (untuned_file_.good()) {
    untuned_file_.close();
@@ -585,54 +511,20 @@ std::ofstream& TuningContext::GetUntunedFile(){
  return untuned_file_;
}

void TuningContext::WriteFileOnExit(bool value) {
  write_file_on_exit_ = value;
}

void TuningContext::EnableNumericsCheck(bool value) {
  numerics_check_enable_ = value;
}

NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
  const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");

  if (!env_opt.has_value()) {
    return numerics_cfg_;
  }

  const std::string& env = env_opt.value();

  if (env == "0") {
    return NumericalCheckConfig(false, 1e-5, 1e-5);
  }

  const size_t underscore = env.find('_');

  TORCH_CHECK(
      underscore != std::string::npos,
      "Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
      "Expected 'atol_rtol', got: ",
      env);

  double atol = 0.0;
  double rtol = 0.0;

  try {
    atol = std::stod(env.substr(0, underscore));
    rtol = std::stod(env.substr(underscore + 1));
  } catch (const std::exception& e) {
    TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
  }

  TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
  return NumericalCheckConfig(true, atol, rtol);
}

void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
  TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
  numerics_cfg_ = {enabled, atol, rtol};
}

bool TuningContext::IsNumericsCheckEnabled() const {
  const auto cfg = GetNumericalCheckConfig();
  return cfg.enabled || numerics_check_enable_;
  const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
  if (env == "1") {
    return true;
  }
  return numerics_check_enable_;
}

void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
@@ -742,6 +634,11 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
    auto filename = GetFilename();
    if (!filename.empty() && !IsRecordUntunedEnabled()) {
      ReadFile(filename);
      // attempt immediately to open file for writing to catch errors early
      std::ofstream file(filename, std::ios::out | std::ios::app);
      if (!file.good()) {
        TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
      }
    }
  });
  return manager_;
@@ -847,6 +744,27 @@ bool TuningContext::ReadFile(const std::string& filename_) {
  return true;
}

bool TuningContext::WriteFile(const std::string& filename_) {
  std::string filename = filename_.empty() ? GetFilename() : filename_;
  std::ofstream file(filename, std::ios::out | std::ios::trunc);
  if (!file.good()) {
    TUNABLE_LOG1("error opening tuning results file for writing ", filename);
    return false;
  }
  auto validators = GetTuningResultsValidator().GetAllValidators();
  for (const auto& [key, val] : validators) {
    file << "Validator," << key << "," << val << std::endl;
  }
  auto results = GetTuningResultsManager().Dump();
  for (const auto& [op_sig, kernelmap] : results) {
    for (const auto& [param_sig, result] : kernelmap) {
      file << op_sig << "," << param_sig << "," << result << std::endl;
    }
  }
  file.close();
  return true;
}

namespace {

struct MaybeDelete {

| @ -103,24 +103,10 @@ class TORCH_CUDA_CPP_API TuningResultsManager { | ||||
|  | ||||
|     void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature, | ||||
|       const std::string& params_signature, const std::string& blas_signature); | ||||
|  | ||||
|     void InitRealtimeAppend( | ||||
|         const std::string& filename, | ||||
|         const std::unordered_map<std::string, std::string>& validators); | ||||
|  | ||||
|     void AppendResultLine(const std::string& op_sig, | ||||
|                          const std::string& param_sig, | ||||
|                          const ResultEntry& result); | ||||
|  | ||||
|     void CloseRealtimeAppend();  // For clean shutdown | ||||
|   private: | ||||
|     std::mutex lock_; | ||||
|     std::mutex realtime_file_mutex_; | ||||
|     std::unique_ptr<std::ofstream> realtime_out_; | ||||
|     std::string realtime_filename_; | ||||
|     ResultsMap results_; | ||||
|     UntunedMap untuned_results_; | ||||
|     bool validators_written_ = false; | ||||
|  | ||||
| }; | ||||
|  | ||||
| @ -148,16 +134,6 @@ class TORCH_CUDA_CPP_API TuningResultsValidator { | ||||
|     GetValidateFuncs validators_; | ||||
| }; | ||||
|  | ||||
| struct NumericalCheckConfig { | ||||
|   bool   enabled{false}; | ||||
|   double atol{1e-5}; | ||||
|   double rtol{1e-5}; | ||||
|  | ||||
|   NumericalCheckConfig() = default; | ||||
|   NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {} | ||||
| }; | ||||
|  | ||||
|  | ||||
| class TORCH_CUDA_CPP_API TuningContext { | ||||
|   public: | ||||
|     TuningContext(); | ||||
| @ -179,8 +155,6 @@ class TORCH_CUDA_CPP_API TuningContext { | ||||
|  | ||||
|     void EnableNumericsCheck(bool value); | ||||
|     bool IsNumericsCheckEnabled() const; | ||||
|     void SetNumericalCheckConfig(bool enabled, double atol, double rtol); | ||||
|     NumericalCheckConfig GetNumericalCheckConfig() const; | ||||
|  | ||||
|     void SetMaxTuningDurationMs(int max_duration_ms); | ||||
|     int GetMaxTuningDurationMs() const; | ||||
| @ -211,7 +185,10 @@ class TORCH_CUDA_CPP_API TuningContext { | ||||
|     void SetFilename(const std::string& filename, bool insert_device_ordinal=false); | ||||
|     std::string GetFilename() const; | ||||
|  | ||||
|     void WriteFileOnExit(bool value); | ||||
|  | ||||
|     bool ReadFile(const std::string& filename={}); | ||||
|     bool WriteFile(const std::string& filename={}); | ||||
|  | ||||
|     template<class... Types> | ||||
|     void Log(int level, Types... args) { | ||||
| @ -230,6 +207,7 @@ class TORCH_CUDA_CPP_API TuningContext { | ||||
|     bool tuning_enable_; | ||||
|     bool record_untuned_enable_; | ||||
|     bool manager_initialized_; | ||||
|     bool write_file_on_exit_; | ||||
|     bool numerics_check_enable_; | ||||
|     int max_tuning_duration_ms_; | ||||
|     int max_tuning_iterations_; | ||||
| @ -244,8 +222,6 @@ class TORCH_CUDA_CPP_API TuningContext { | ||||
|     std::ofstream untuned_file_; | ||||
|     size_t results_count_from_input_file_; | ||||
|     bool is_shutting_down_; | ||||
|  | ||||
|     NumericalCheckConfig numerics_cfg_{}; | ||||
| }; | ||||
|  | ||||
| TORCH_CUDA_CPP_API TuningContext* getTuningContext(); | ||||
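|  | ||||
| // Usage sketch for the file round-trip restored above; the filename is | ||||
| // illustrative and error handling is minimal. | ||||
| //   auto* ctx = at::cuda::tunable::getTuningContext(); | ||||
| //   ctx->SetFilename("my_tunableop_results.csv", /*insert_device_ordinal=*/true); | ||||
| //   ctx->ReadFile();    // load prior results if the file exists | ||||
| //   ... run tuned GEMMs ... | ||||
| //   ctx->WriteFile();   // or defer with ctx->WriteFileOnExit(true) | ||||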
|  | ||||
| @ -267,10 +267,27 @@ class TunableOp { | ||||
|       for (size_t i = 0; i < op_names_.size(); i++) { | ||||
|         auto* candidate = ops_[op_names_[i]].get(); // borrow pointer | ||||
|  | ||||
|         auto status = candidate->Call(reusable_params[0]); | ||||
|         if (status != OK) { | ||||
|           TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); | ||||
|           continue; | ||||
|         if (do_numerics_check) { | ||||
|           ParamsT* numerical_params = params->DeepCopy(false); | ||||
|           auto status = candidate->Call(numerical_params); | ||||
|           if (status != OK) { | ||||
|             numerical_params->Delete(); | ||||
|             TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); | ||||
|             continue; | ||||
|           } | ||||
|           status = reference_params->NumericalCheck(numerical_params); | ||||
|           numerical_params->Delete(); | ||||
|           if (status != OK) { | ||||
|             TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); | ||||
|             continue; | ||||
|           } | ||||
|         } | ||||
|         else { | ||||
|           auto status = candidate->Call(reusable_params[0]); | ||||
|           if (status != OK) { | ||||
|             TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); | ||||
|             continue; | ||||
|           } | ||||
|         } | ||||
|  | ||||
|         // collect a small profile | ||||
| @ -293,22 +310,6 @@ class TunableOp { | ||||
|           continue; | ||||
|         } | ||||
|  | ||||
|         if (do_numerics_check) { | ||||
|           ParamsT* numerical_params = params->DeepCopy(false); | ||||
|           auto status = candidate->Call(numerical_params); | ||||
|           if (status != OK) { | ||||
|             numerical_params->Delete(); | ||||
|             TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); | ||||
|             continue; | ||||
|           } | ||||
|           status = reference_params->NumericalCheck(numerical_params); | ||||
|           numerical_params->Delete(); | ||||
|           if (status != OK) { | ||||
|             TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); | ||||
|             continue; | ||||
|           } | ||||
|         } | ||||
|  | ||||
|         // for warmup does user set max duration, max iters, or both? | ||||
|         // warmup is skipped by default, i.e. warmup_iter = 0 | ||||
|         // warmup will be set to the non-zero value of max_warmup_duration | ||||
|  | ||||
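| // The numerics check relocated above follows a copy/call/compare/delete | ||||
| // contract. A minimal sketch of that pattern (TuningStatus/OK and the ParamsT | ||||
| // methods come from this file; the helper itself is illustrative): | ||||
| // | ||||
| //   template <typename ParamsT, typename CandidateT> | ||||
| //   TuningStatus CheckCandidateNumerics(ParamsT* params, ParamsT* reference_params, CandidateT* candidate) { | ||||
| //     ParamsT* numerical_params = params->DeepCopy(false); | ||||
| //     TuningStatus status = candidate->Call(numerical_params); | ||||
| //     if (status == OK) { | ||||
| //       // compare the candidate's output against the reference within atol/rtol | ||||
| //       status = reference_params->NumericalCheck(numerical_params); | ||||
| //     } | ||||
| //     numerical_params->Delete(); | ||||
| //     return status; | ||||
| //   } | ||||
|  | ||||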
| @ -213,22 +213,40 @@ static cudnn_grid_sample_backward_batch_rule( | ||||
|   return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size); | ||||
| } | ||||
|  | ||||
| // Uses a functional formulation of one_hot under vmap so that it stays | ||||
| // compatible with FakeTensor/dynamic shapes and compiled functorch transforms. | ||||
| // It mirrors the meta path in aten/src/ATen/native/Onehot.cpp, but requires an | ||||
| // explicit positive num_classes under vmap to avoid data-dependent output shapes. | ||||
| // TODO: replace with targetable functionalization | ||||
| static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) { | ||||
|     TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor."); | ||||
|     auto shape = self.sym_sizes().vec(); | ||||
|  | ||||
|     // empty tensor could be converted to one hot representation, | ||||
|     // but shape inference is not possible. | ||||
|     if (self.sym_numel() == 0) { | ||||
|         if (num_classes <= 0) { | ||||
|             TORCH_CHECK(false, "Can not infer total number of classes from empty tensor."); | ||||
|         } else { | ||||
|             shape.emplace_back(num_classes); | ||||
|             return at::empty_symint(shape, self.options()); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Disallow implicit inference of num_classes under vmap; it would be | ||||
|     // data-dependent, and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py. | ||||
|     TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please " | ||||
|         "provide an explicit positive num_classes argument."); | ||||
|  | ||||
|     const auto options = self.options(); | ||||
|     at::Tensor index = at::arange(num_classes, options); | ||||
|     return at::eq(self.unsqueeze(-1), index).to(at::kLong); | ||||
|     // Disabling all of the following checks. This is OK because scatter has checks too. | ||||
|     // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this. | ||||
|     // // non-empty tensor | ||||
|     // if (self.device().type() != at::kCUDA) { | ||||
|     //   //for cuda, rely on device assert thrown by scatter | ||||
|     //   TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative."); | ||||
|     // } | ||||
|     // if (self.device().type() != at::kCUDA) { | ||||
|     //   //rely on device asserts from scatter to avoid sync here | ||||
|     //   TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes."); | ||||
|     // } | ||||
|  | ||||
|     shape.emplace_back(num_classes); | ||||
|     Tensor ret = at::zeros_symint(shape, self.options()); | ||||
|     return ret.scatter(-1, self.unsqueeze(-1), 1); | ||||
| } | ||||
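|  | ||||
| // Both formulations above agree for valid indices; a standalone check (assumes | ||||
| // a linked libtorch, not part of this change): | ||||
| //   #include <torch/torch.h> | ||||
| //   auto idx = torch::tensor({0, 2, 1}, torch::kLong); | ||||
| //   auto via_scatter = torch::zeros({3, 3}, torch::kLong).scatter(-1, idx.unsqueeze(-1), 1); | ||||
| //   auto classes = torch::arange(3, idx.options()); | ||||
| //   auto via_eq = torch::eq(idx.unsqueeze(-1), classes).to(torch::kLong); | ||||
| //   TORCH_CHECK(via_scatter.equal(via_eq)); | ||||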
|  | ||||
| template <typename A, A a, typename C> | ||||
|  | ||||
| @ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     auto shape = self.sym_sizes().vec(); | ||||
|     auto shape = self.sizes().vec(); | ||||
|  | ||||
|     // empty tensor could be converted to one hot representation, | ||||
|     // but shape inference is not possible. | ||||
|     if (self.sym_numel() == 0) { | ||||
|     if (self.numel() == 0) { | ||||
|         if (num_classes <= 0) { | ||||
|             TORCH_CHECK(false, "Can not infer total number of classes from empty tensor."); | ||||
|         } else { | ||||
|             shape.emplace_back(num_classes); | ||||
|             return at::empty_symint(shape, self.options()); | ||||
|             shape.push_back(num_classes); | ||||
|             return at::empty(shape, self.options()); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     shape.emplace_back(num_classes); | ||||
|     Tensor ret = at::zeros_symint(shape, self.options()); | ||||
|     shape.push_back(num_classes); | ||||
|     Tensor ret = at::zeros(shape, self.options()); | ||||
|     ret.scatter_(-1, self.unsqueeze(-1), 1); | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| @ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel( | ||||
|   } else if (dtype == ScalarType::Half) { | ||||
|     [&]() { | ||||
|       using scalar_t = | ||||
|           c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>; | ||||
|           decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t); | ||||
|       const auto exp = exp_scalar.to<scalar_t>(); | ||||
|       using Vec = Vectorized<scalar_t>; | ||||
|       cpu_kernel_vec(iter, | ||||
|  | ||||
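| // The two spellings in the hunk above name the same C++ type; a compile-time | ||||
| // sketch of the equivalence (exact c10 header assumed): | ||||
| //   #include <type_traits> | ||||
| //   #include <c10/core/ScalarType.h> | ||||
| //   using ViaAlias    = c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::Half>; | ||||
| //   using ViaDecltype = decltype(c10::impl::ScalarTypeToCPPType<c10::ScalarType::Half>::t); | ||||
| //   static_assert(std::is_same_v<ViaAlias, ViaDecltype>, "both resolve to at::Half"); | ||||
|  | ||||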
| @ -1230,205 +1230,8 @@ std::pair<ScalingType, ScalingType> get_joint_scaling( | ||||
|   ); | ||||
| } | ||||
|  | ||||
| Tensor& | ||||
| _tunable_scaled_gemm_rocm( | ||||
|           cublasCommonArgs& args, | ||||
|           const Tensor& mat1, const Tensor& mat2, | ||||
|           const Tensor& scale_a, const Tensor& scale_b, | ||||
|           const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const bool use_fast_accum, | ||||
|           const at::ScalarType out_dtype, | ||||
|           Tensor& out) { | ||||
| #ifdef USE_ROCM | ||||
| #define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B)                            \ | ||||
|       if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) {        \ | ||||
|         if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \ | ||||
|           static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|               at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t,     \ | ||||
|               BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|           scaledgemm(¶ms);                                        \ | ||||
|         }                                                             \ | ||||
|         else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ | ||||
|           static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|               at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t,     \ | ||||
|               BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|           scaledgemm(¶ms);                                        \ | ||||
|         }                                                             \ | ||||
|       }                                                               \ | ||||
|       else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) {   \ | ||||
|         if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \ | ||||
|           static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|               at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t,     \ | ||||
|               BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|           scaledgemm(¶ms);                                        \ | ||||
|         }                                                             \ | ||||
|         else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ | ||||
|           static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|               at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t,     \ | ||||
|               BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|           scaledgemm(¶ms);                                        \ | ||||
|         }                                                             \ | ||||
|       }                                                               \ | ||||
|       else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) {     \ | ||||
|         if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \ | ||||
|           static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|               at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t,         \ | ||||
|               BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|           scaledgemm(¶ms);                                        \ | ||||
|         }                                                             \ | ||||
|         else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \ | ||||
|           static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|               at::Float8_e4m3fn, at::Float8_e5m2, scalar_t,           \ | ||||
|               BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|           scaledgemm(¶ms);                                        \ | ||||
|         }                                                             \ | ||||
|       }                                                               \ | ||||
|       else if (mat1.scalar_type() == ScalarType::Float8_e5m2) {       \ | ||||
|         if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \ | ||||
|           static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|               at::Float8_e5m2, at::Float8_e4m3fn, scalar_t,           \ | ||||
|               BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|           scaledgemm(¶ms);                                        \ | ||||
|         }                                                             \ | ||||
|         else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \ | ||||
|           static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|               at::Float8_e5m2, at::Float8_e5m2, scalar_t,             \ | ||||
|               BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|           scaledgemm(¶ms);                                        \ | ||||
|         }                                                             \ | ||||
|       } | ||||
|   AT_DISPATCH_V2(out_dtype, "_tunable_scaled_gemm", AT_WRAP([&] { | ||||
|     bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); | ||||
|     bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); | ||||
|     at::cuda::tunable::ScaledGemmParams<scalar_t> params; | ||||
|     params.transa = args.transa; | ||||
|     params.transb = args.transb; | ||||
|     params.m = args.m; | ||||
|     params.n = args.n; | ||||
|     params.k = args.k; | ||||
|     params.a = args.mata->data_ptr(); | ||||
|     params.a_scale_ptr = args.scale_mata_ptr; | ||||
|     params.a_scale_dtype = args.scale_mata_dtype.value(); | ||||
|     params.lda = args.lda; | ||||
|     params.a_dtype = args.mata->scalar_type(); | ||||
|     params.a_scaling_type = args.scaling_mata_type.value(); | ||||
|     params.b = args.matb->data_ptr(); | ||||
|     params.b_scale_ptr = args.scale_matb_ptr; | ||||
|     params.b_scale_dtype = args.scale_matb_dtype.value(); | ||||
|     params.ldb = args.ldb; | ||||
|     params.b_dtype = args.matb->scalar_type(); | ||||
|     params.b_scaling_type = args.scaling_matb_type.value(); | ||||
|     params.bias_ptr = bias ? bias->data_ptr(): nullptr; | ||||
|     params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype) ? at::ScalarType::Half : out_dtype; | ||||
|     params.c = args.result->data_ptr(); | ||||
|     params.c_scale_ptr = args.scale_result_ptr; | ||||
|     params.ldc = args.result_ld; | ||||
|     params.c_dtype = out_dtype; | ||||
|     params.use_fast_accum = use_fast_accum; | ||||
|     if (transa_ && transb_) { | ||||
|       TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) | ||||
|     } | ||||
|     else if (transa_ && !transb_) { | ||||
|       TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N) | ||||
|     } | ||||
|     else if (!transa_ && transb_) { | ||||
|       TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T) | ||||
|     } | ||||
|     else if (!transa_ && !transb_) { | ||||
|       TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N) | ||||
|     } | ||||
|     else { | ||||
|       TORCH_CHECK(false, "unreachable"); | ||||
|     } | ||||
|   }), | ||||
|   kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES)); | ||||
| #undef TUNABLE_DISPATCH | ||||
|   return out; | ||||
| #else | ||||
|   TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_gemm_rocm only callable on ROCM devices"); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| Tensor& | ||||
| _scaled_gemm( | ||||
|           const Tensor& mat1, const Tensor& mat2, | ||||
|           const Tensor& scale_a, const Tensor& scale_b, | ||||
|           const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const bool use_fast_accum, | ||||
|           Tensor& out) { | ||||
|   cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b); | ||||
|   const auto out_dtype_ = args.result->scalar_type(); | ||||
|   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); | ||||
|  | ||||
| // ROCM enables the TunableOp path only | ||||
| // but can fallback to at::cuda::blas::scaled_gemm | ||||
| #ifdef USE_ROCM | ||||
|   auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|   bool tunable_op_enabled = tuning_ctx->IsTunableOpEnabled(); | ||||
| #else | ||||
|   bool tunable_op_enabled = false; | ||||
| #endif | ||||
|   if (tunable_op_enabled) { | ||||
|       // Only available on ROCM | ||||
|       return _tunable_scaled_gemm_rocm( | ||||
|           args, | ||||
|           mat1, mat2, | ||||
|           scale_a, scale_b, | ||||
|           scaling_choice_a, scaling_choice_b, | ||||
|           bias, | ||||
|           use_fast_accum, | ||||
|           out_dtype_, | ||||
|           out); | ||||
|   } | ||||
|   else | ||||
|   { | ||||
|       at::cuda::blas::scaled_gemm( | ||||
|           args.transa, | ||||
|           args.transb, | ||||
|           args.m, | ||||
|           args.n, | ||||
|           args.k, | ||||
|           args.mata->data_ptr(), | ||||
|           args.scale_mata_ptr, | ||||
|           args.lda, | ||||
|           args.mata->scalar_type(), | ||||
|           args.scale_mata_dtype.value(), | ||||
|           args.scaling_mata_type.value(), | ||||
|           args.matb->data_ptr(), | ||||
|           args.scale_matb_ptr, | ||||
|           args.ldb, | ||||
|           args.matb->scalar_type(), | ||||
|           args.scale_matb_dtype.value(), | ||||
|           args.scaling_matb_type.value(), | ||||
|           bias ? bias->data_ptr(): nullptr, | ||||
|           bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, | ||||
|           args.result->data_ptr(), | ||||
|           args.scale_result_ptr, | ||||
|           args.result_ld, | ||||
|           out_dtype_, | ||||
|           use_fast_accum); | ||||
|       return out; | ||||
|   } | ||||
| } | ||||
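|  | ||||
| // Runtime steering sketch for the branch above (ROCm only); these TunableOp | ||||
| // environment variables are documented, the setenv calls are illustrative: | ||||
| //   setenv("PYTORCH_TUNABLEOP_ENABLED", "1", 1);  // IsTunableOpEnabled() -> true | ||||
| //   setenv("PYTORCH_TUNABLEOP_TUNING", "1", 1);   // allow tuning of unseen shapes | ||||
| // Subsequent scaled GEMMs on ROCm then route through ScaledGemmTunableOp. | ||||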
|  | ||||
| } // namespace | ||||
|  | ||||
| // NOTE(slayton58): This is defined as part of the _v2 code (way) below - declare the signature here | ||||
| //                  to help clean up the v1 call structure. | ||||
| Tensor& | ||||
| _scaled_rowwise_rowwise( | ||||
|           const Tensor&, const Tensor&, | ||||
|           const Tensor&, const Tensor&, | ||||
|           const std::optional<Tensor>&, | ||||
|           const c10::ScalarType, | ||||
|           bool, | ||||
|           Tensor&); | ||||
|  | ||||
|  | ||||
| // Computes matrix multiply + bias while applying scaling to input and output matrices | ||||
| // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default. | ||||
| // If output matrix type is 16 or 32-bit type, scale_result is not applied. | ||||
| @ -1470,10 +1273,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, | ||||
|   // by decreasing priority. We prefer "simpler" schemes as they are supported | ||||
|   // more broadly (more GPU archs, more CUDA versions) and because they are more | ||||
|   // efficient. This tends to matter only for small matmuls (e.g., 1x1x128). | ||||
|  | ||||
|   // List of supported BlockWise pairs for FP8: | ||||
|   // https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types | ||||
|  | ||||
|   auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling( | ||||
|     { | ||||
|       std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise), | ||||
| @ -1506,7 +1305,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, | ||||
|   TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type()); | ||||
| #ifndef USE_ROCM | ||||
|   // Type restrictions imposed by CuBLASLt as of CUDA-12.1 | ||||
|   TORCH_CHECK_VALUE(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2, | ||||
|   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2, | ||||
|         "Multiplication of two Float8_e5m2 matrices is not supported"); | ||||
| #endif | ||||
|   if (use_fast_accum) { | ||||
| @ -1572,44 +1371,41 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, | ||||
|  | ||||
|   // NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9, | ||||
|   // and only for compute capability 9.0+. In other cases we use CUTLASS. | ||||
|   // We are doing row-wise scaling | ||||
|   if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { | ||||
| #ifndef USE_ROCM | ||||
|     auto dprops = at::cuda::getCurrentDeviceProperties(); | ||||
|     if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) | ||||
|         // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales | ||||
|         ||  (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) { | ||||
|       TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); | ||||
|       return _scaled_rowwise_rowwise( | ||||
|           mat1, | ||||
|           mat2, | ||||
|           scale_a, | ||||
|           scale_b, | ||||
|           bias, | ||||
|           out.scalar_type(), | ||||
|           use_fast_accum, | ||||
|           out); | ||||
|     } | ||||
|   // We are doing row-wise scaling | ||||
|   auto dprops = at::cuda::getCurrentDeviceProperties(); | ||||
|   if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise | ||||
|       && ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) | ||||
|       // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales | ||||
|       ||  (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) { | ||||
|     TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); | ||||
|     at::cuda::detail::f8f8bf16_rowwise( | ||||
|         mat1, | ||||
|         mat2, | ||||
|         scale_a, | ||||
|         scale_b, | ||||
|         bias, | ||||
|         use_fast_accum, | ||||
|         out); | ||||
|     return out; | ||||
|   } | ||||
| #else | ||||
|   if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { | ||||
|     // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. | ||||
|     Tensor b = mat2; | ||||
|     if (_scaled_mm_is_fnuz()) { | ||||
|       TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fnuz, | ||||
|           "Expected b.dtype() == at::kFloat8_e4m3fnuz, got: ", b.dtype()); | ||||
|       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz); | ||||
|     } | ||||
|     else { | ||||
|       TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fn, | ||||
|           "Expected b.dtype() == at::kFloat8_e4m3fn, got: ", b.dtype()); | ||||
|       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn); | ||||
|     } | ||||
|     // Until more than bf16 is supported. | ||||
|     TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16, | ||||
|     TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16, | ||||
|          "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type()); | ||||
| #endif | ||||
|   } | ||||
|   else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) { | ||||
| #ifdef USE_ROCM | ||||
|     #if ROCM_VERSION >= 70000 | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), | ||||
|     TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), | ||||
|                 "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); | ||||
|  | ||||
|     int packed_factor = 1; | ||||
| @ -1618,20 +1414,163 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, | ||||
|       // effectively packing two elements into one byte. | ||||
|       packed_factor = 2; | ||||
|     } | ||||
|     TORCH_CHECK_VALUE(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 && | ||||
|     TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 && | ||||
|                 mat2.size(1) % 16 == 0, | ||||
|                 "M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling"); | ||||
|  | ||||
|     TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 || | ||||
|     TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 || | ||||
|                 out.scalar_type() == ScalarType::Half, | ||||
|                 "Block-wise scaling only supports BFloat16 or Half output types"); | ||||
| #else | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); | ||||
| #endif | ||||
|     TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); | ||||
| #endif | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|   cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b); | ||||
|   const auto out_dtype_ = args.result->scalar_type(); | ||||
|   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|   auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|   if (tuning_ctx->IsTunableOpEnabled()) { | ||||
| #define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B)                            \ | ||||
|         if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) {        \ | ||||
|           if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t,     \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|           else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t,     \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|         }                                                               \ | ||||
|         else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) {   \ | ||||
|           if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t,     \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|           else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t,     \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|         }                                                               \ | ||||
|         else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) {     \ | ||||
|           if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t,         \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|           else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e4m3fn, at::Float8_e5m2, scalar_t,           \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|         }                                                               \ | ||||
|         else if (mat1.scalar_type() == ScalarType::Float8_e5m2) {       \ | ||||
|           if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e5m2, at::Float8_e4m3fn, scalar_t,           \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|           else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e5m2, at::Float8_e5m2, scalar_t,             \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|         } | ||||
|     AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] { | ||||
|       bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); | ||||
|       bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); | ||||
|       at::cuda::tunable::ScaledGemmParams<scalar_t> params; | ||||
|       params.transa = args.transa; | ||||
|       params.transb = args.transb; | ||||
|       params.m = args.m; | ||||
|       params.n = args.n; | ||||
|       params.k = args.k; | ||||
|       params.a = args.mata->data_ptr(); | ||||
|       params.a_scale_ptr = args.scale_mata_ptr; | ||||
|       params.a_scale_dtype = args.scale_mata_dtype.value(); | ||||
|       params.lda = args.lda; | ||||
|       params.a_dtype = args.mata->scalar_type(); | ||||
|       params.a_scaling_type = args.scaling_mata_type.value(); | ||||
|       params.b = args.matb->data_ptr(); | ||||
|       params.b_scale_ptr = args.scale_matb_ptr; | ||||
|       params.b_scale_dtype = args.scale_matb_dtype.value(); | ||||
|       params.ldb = args.ldb; | ||||
|       params.b_dtype = args.matb->scalar_type(); | ||||
|       params.b_scaling_type = args.scaling_matb_type.value(); | ||||
|       params.bias_ptr = bias ? bias->data_ptr(): nullptr; | ||||
|       params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; | ||||
|       params.c = args.result->data_ptr(); | ||||
|       params.c_scale_ptr = args.scale_result_ptr; | ||||
|       params.ldc = args.result_ld; | ||||
|       params.c_dtype = out_dtype_; | ||||
|       params.use_fast_accum = use_fast_accum; | ||||
|       if (transa_ && transb_) { | ||||
|         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) | ||||
|       } | ||||
|       else if (transa_ && !transb_) { | ||||
|         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N) | ||||
|       } | ||||
|       else if (!transa_ && transb_) { | ||||
|         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T) | ||||
|       } | ||||
|       else if (!transa_ && !transb_) { | ||||
|         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N) | ||||
|       } | ||||
|       else { | ||||
|         TORCH_CHECK(false, "unreachable"); | ||||
|       } | ||||
|     }), | ||||
|     kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES)); | ||||
| #undef TUNABLE_DISPATCH | ||||
|   } | ||||
|   else | ||||
| #endif | ||||
|  { | ||||
|     at::cuda::blas::scaled_gemm( | ||||
|         args.transa, | ||||
|         args.transb, | ||||
|         args.m, | ||||
|         args.n, | ||||
|         args.k, | ||||
|         args.mata->data_ptr(), | ||||
|         args.scale_mata_ptr, | ||||
|         args.lda, | ||||
|         args.mata->scalar_type(), | ||||
|         args.scale_mata_dtype.value(), | ||||
|         args.scaling_mata_type.value(), | ||||
|         args.matb->data_ptr(), | ||||
|         args.scale_matb_ptr, | ||||
|         args.ldb, | ||||
|         args.matb->scalar_type(), | ||||
|         args.scale_matb_dtype.value(), | ||||
|         args.scaling_matb_type.value(), | ||||
|         bias ? bias->data_ptr(): nullptr, | ||||
|         bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, | ||||
|         args.result->data_ptr(), | ||||
|         args.scale_result_ptr, | ||||
|         args.result_ld, | ||||
|         out_dtype_, | ||||
|         use_fast_accum); | ||||
|   } | ||||
|  | ||||
|   return out; | ||||
| } | ||||
|  | ||||
| namespace { | ||||
| @ -1971,6 +1910,159 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> | ||||
|   { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE }, | ||||
|   { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}}; | ||||
|  | ||||
| Tensor& | ||||
| _cutlass_scaled_gemm( | ||||
|           const Tensor& mat1, const Tensor& mat2, | ||||
|           const Tensor& scale_a, const Tensor& scale_b, | ||||
|           const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const bool use_fast_accum, | ||||
|           Tensor& out) { | ||||
|   cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b); | ||||
|   const auto out_dtype_ = args.result->scalar_type(); | ||||
|   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|   auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|   if (tuning_ctx->IsTunableOpEnabled()) { | ||||
| #define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B)                            \ | ||||
|         if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) {        \ | ||||
|           if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t,     \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|           else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t,     \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|         }                                                               \ | ||||
|         else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) {   \ | ||||
|           if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t,     \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|           else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t,     \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|         }                                                               \ | ||||
|         else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) {     \ | ||||
|           if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t,         \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|           else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e4m3fn, at::Float8_e5m2, scalar_t,           \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|         }                                                               \ | ||||
|         else if (mat1.scalar_type() == ScalarType::Float8_e5m2) {       \ | ||||
|           if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) {        \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e5m2, at::Float8_e4m3fn, scalar_t,           \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|           else if (mat2.scalar_type() == ScalarType::Float8_e5m2) {     \ | ||||
|             static at::cuda::tunable::ScaledGemmTunableOp<              \ | ||||
|                 at::Float8_e5m2, at::Float8_e5m2, scalar_t,             \ | ||||
|                 BLASOP_A, BLASOP_B> scaledgemm{};                       \ | ||||
|             scaledgemm(¶ms);                                        \ | ||||
|           }                                                             \ | ||||
|         } | ||||
|     AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] { | ||||
|       bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); | ||||
|       bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); | ||||
|       at::cuda::tunable::ScaledGemmParams<scalar_t> params; | ||||
|       params.transa = args.transa; | ||||
|       params.transb = args.transb; | ||||
|       params.m = args.m; | ||||
|       params.n = args.n; | ||||
|       params.k = args.k; | ||||
|       params.a = args.mata->data_ptr(); | ||||
|       params.a_scale_ptr = args.scale_mata_ptr; | ||||
|       params.a_scale_dtype = args.scale_mata_dtype.value(); | ||||
|       params.lda = args.lda; | ||||
|       params.a_dtype = args.mata->scalar_type(); | ||||
|       params.a_scaling_type = args.scaling_mata_type.value(); | ||||
|       params.b = args.matb->data_ptr(); | ||||
|       params.b_scale_ptr = args.scale_matb_ptr; | ||||
|       params.b_scale_dtype = args.scale_matb_dtype.value(); | ||||
|       params.ldb = args.ldb; | ||||
|       params.b_dtype = args.matb->scalar_type(); | ||||
|       params.b_scaling_type = args.scaling_matb_type.value(); | ||||
|       params.bias_ptr = bias ? bias->data_ptr(): nullptr; | ||||
|       params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; | ||||
|       params.c = args.result->data_ptr(); | ||||
|       params.c_scale_ptr = args.scale_result_ptr; | ||||
|       params.ldc = args.result_ld; | ||||
|       params.c_dtype = out_dtype_; | ||||
|       params.use_fast_accum = use_fast_accum; | ||||
|       if (transa_ && transb_) { | ||||
|         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) | ||||
|       } | ||||
|       else if (transa_ && !transb_) { | ||||
|         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N) | ||||
|       } | ||||
|       else if (!transa_ && transb_) { | ||||
|         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T) | ||||
|       } | ||||
|       else if (!transa_ && !transb_) { | ||||
|         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N) | ||||
|       } | ||||
|       else { | ||||
|         TORCH_CHECK(false, "unreachable"); | ||||
|       } | ||||
|     }), | ||||
|     kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES)); | ||||
| #undef TUNABLE_DISPATCH | ||||
|   } | ||||
|   else | ||||
| #endif | ||||
|  { | ||||
|     at::cuda::blas::scaled_gemm( | ||||
|         args.transa, | ||||
|         args.transb, | ||||
|         args.m, | ||||
|         args.n, | ||||
|         args.k, | ||||
|         args.mata->data_ptr(), | ||||
|         args.scale_mata_ptr, | ||||
|         args.lda, | ||||
|         args.mata->scalar_type(), | ||||
|         args.scale_mata_dtype.value(), | ||||
|         args.scaling_mata_type.value(), | ||||
|         args.matb->data_ptr(), | ||||
|         args.scale_matb_ptr, | ||||
|         args.ldb, | ||||
|         args.matb->scalar_type(), | ||||
|         args.scale_matb_dtype.value(), | ||||
|         args.scaling_matb_type.value(), | ||||
|         bias ? bias->data_ptr(): nullptr, | ||||
|         bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, | ||||
|         args.result->data_ptr(), | ||||
|         args.scale_result_ptr, | ||||
|         args.result_ld, | ||||
|         out_dtype_, | ||||
|         use_fast_accum); | ||||
|   } | ||||
|   return out; | ||||
| } | ||||
|  | ||||
| Tensor& | ||||
| _scaled_tensorwise_tensorwise( | ||||
|           const Tensor& mat_a, const Tensor& mat_b, | ||||
| @ -1990,7 +2082,7 @@ _scaled_tensorwise_tensorwise( | ||||
|   auto scaling_choice_a = ScalingType::TensorWise; | ||||
|   auto scaling_choice_b = ScalingType::TensorWise; | ||||
|  | ||||
|   _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|   _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|  | ||||
|   return out; | ||||
| } | ||||
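|  | ||||
| // Orientation sketch: what reaches the tensor-wise recipe from the public entry | ||||
| // point. The at::_scaled_mm schema is assumed from the v1 path in this file; | ||||
| // shapes and dtypes are illustrative. | ||||
| //   auto opts = torch::TensorOptions().device(torch::kCUDA); | ||||
| //   auto a = torch::randn({64, 128}, opts).to(at::kFloat8_e4m3fn);      // row-major | ||||
| //   auto b = torch::randn({32, 128}, opts).to(at::kFloat8_e4m3fn).t();  // column-major view | ||||
| //   auto scale_a = torch::ones({}, opts.dtype(torch::kFloat));          // single fp32 scale | ||||
| //   auto scale_b = torch::ones({}, opts.dtype(torch::kFloat)); | ||||
| //   auto out = at::_scaled_mm(a, b, scale_a, scale_b, /*bias=*/std::nullopt, | ||||
| //                             /*scale_result=*/std::nullopt, at::kBFloat16, | ||||
| //                             /*use_fast_accum=*/false); | ||||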
| @ -2026,7 +2118,7 @@ _scaled_rowwise_rowwise( | ||||
|   if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) | ||||
|       // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales | ||||
|       ||  (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) { | ||||
|     TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); | ||||
|     TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); | ||||
|     at::cuda::detail::f8f8bf16_rowwise( | ||||
|         mat_a, | ||||
|         mat_b, | ||||
| @ -2052,38 +2144,11 @@ _scaled_rowwise_rowwise( | ||||
|        "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type()); | ||||
| #endif | ||||
|  | ||||
|   _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|   _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|  | ||||
|   return out; | ||||
| } | ||||
|  | ||||
| // Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling. | ||||
| // Wraps check_size_stride for easier integration; correctly handles cases where a | ||||
| // dimension of the scale == 1, making the corresponding strides somewhat meaningless. | ||||
| void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) { | ||||
|   if (scale_type == ScalingType::BlockWise1x128) { | ||||
|     TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1), | ||||
|         "at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ", | ||||
|         "shape=", scale.sizes(), ", stride=", scale.strides()); | ||||
|     auto expected_size = ceil_div<int64_t>(t.size(1), 128); | ||||
|     TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)), | ||||
|         "at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ", | ||||
|         "shape=", scale.sizes(), ", stride=", scale.strides()); | ||||
|   } else if (scale_type == ScalingType::BlockWise128x128) { | ||||
|       TORCH_CHECK_VALUE(check_size_stride( | ||||
|           scale, | ||||
|           0, | ||||
|           ceil_div<int64_t>(t.size(0), 128), | ||||
|           ceil_div<int64_t>(t.size(1), 128)), | ||||
|         "at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ", | ||||
|         "shape=", scale.sizes(), ", stride=", scale.strides()); | ||||
|       TORCH_CHECK_VALUE(check_size_stride( | ||||
|           scale, 1, ceil_div<int64_t>(t.size(1), 128), 1), | ||||
|         "at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), " elements and stride(1) ", 1, " if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ", | ||||
|         "shape=", scale.sizes(), ", stride=", scale.strides()); | ||||
|   } | ||||
| } | ||||
|  | ||||
| Tensor& | ||||
| _scaled_block1x128_block1x128( | ||||
|           const Tensor& mat_a, const Tensor& mat_b, | ||||
| @ -2101,14 +2166,15 @@ _scaled_block1x128_block1x128( | ||||
|   TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat, | ||||
|       "scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes()) | ||||
|  | ||||
|   TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0)); | ||||
|   TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1)); | ||||
|   TORCH_CHECK(scale_b.stride(0) == scale_b.size(1), | ||||
|       "expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.size(1)); | ||||
|  | ||||
|   auto scaling_choice_a = ScalingType::BlockWise1x128; | ||||
|   auto scaling_choice_b = ScalingType::BlockWise1x128; | ||||
|  | ||||
|   // Check scale strides (including stride=1 small cases) | ||||
|   _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a); | ||||
|   _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b); | ||||
|  | ||||
|   _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|   _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|  | ||||
|   return out; | ||||
| } | ||||
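|  | ||||
| // Scale-shape rule encoded by the checks above (helper illustrative; matches | ||||
| // ceil_div<int64_t>(k, 128) as used in this file): | ||||
| //   // For mat_a [M, K] with BlockWise1x128 scaling: scale_a is [M, ceil_div(K, 128)] fp32. | ||||
| //   // For mat_b [K, N] with BlockWise1x128 scaling: scale_b is [ceil_div(K, 128), N] fp32. | ||||
| //   inline int64_t num_scale_blocks(int64_t k) { return (k + 127) / 128; } | ||||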
| @ -2123,8 +2189,6 @@ _scaled_block128x128_block1x128( | ||||
|           Tensor& out) { | ||||
|   // Restrictions: | ||||
|   // A, B are FP8, scales are fp32, shape K//128 | ||||
|   std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl; | ||||
|   std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl; | ||||
|   TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ", | ||||
|       mat_a.scalar_type(), mat_b.scalar_type()); | ||||
|   TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat, | ||||
| @ -2132,14 +2196,15 @@ _scaled_block128x128_block1x128( | ||||
|   TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat, | ||||
|       "scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes()) | ||||
|  | ||||
|   TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)); | ||||
|   TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1)); | ||||
|   TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1), | ||||
|       "expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0)); | ||||
|  | ||||
|   auto scaling_choice_a = ScalingType::BlockWise128x128; | ||||
|   auto scaling_choice_b = ScalingType::BlockWise1x128; | ||||
|  | ||||
|   // Check scale strides (including stride=1 small cases) | ||||
|   _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a); | ||||
|   _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b); | ||||
|  | ||||
|   _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|   _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|  | ||||
|   return out; | ||||
| } | ||||
| @ -2161,14 +2226,15 @@ _scaled_block1x128_block128x128( | ||||
|   TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat, | ||||
|       "scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes()) | ||||
|  | ||||
|   TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0)); | ||||
|   TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0)); | ||||
|   TORCH_CHECK_VALUE(scale_b.stride(1) == scale_b.size(0), | ||||
|       "expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1)); | ||||
|  | ||||
|   auto scaling_choice_a = ScalingType::BlockWise1x128; | ||||
|   auto scaling_choice_b = ScalingType::BlockWise128x128; | ||||
|  | ||||
|   // Check scale strides (including stride=1 small cases) | ||||
|   _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a); | ||||
|   _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b); | ||||
|  | ||||
|   _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|   _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); | ||||
|  | ||||
|   return out; | ||||
| } | ||||
| @ -2222,7 +2288,7 @@ _scaled_mxfp8_mxfp8( | ||||
| #endif | ||||
| #endif | ||||
|  | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
|   return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
| } | ||||
|  | ||||
| Tensor& | ||||
| @ -2259,7 +2325,7 @@ _scaled_nvfp4_nvfp4( | ||||
|  | ||||
|   auto scaling_choice_a = ScalingType::BlockWise1x16; | ||||
|   auto scaling_choice_b = ScalingType::BlockWise1x16; | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
|   return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -2508,9 +2574,7 @@ _mx8_mx8_bf16_grouped_mm_fbgemm( | ||||
|         const Tensor& mat_a, | ||||
|         const Tensor& mat_b, | ||||
|         const Tensor& scale_a, | ||||
|         const SwizzleType& swizzle_a, | ||||
|         const Tensor& scale_b, | ||||
|         const SwizzleType& swizzle_b, | ||||
|         const std::optional<at::Tensor>& offs, | ||||
|         Tensor& out) { | ||||
|     const bool a_is_2d = mat_a.dim() == 2; | ||||
| @ -2521,16 +2585,6 @@ _mx8_mx8_bf16_grouped_mm_fbgemm( | ||||
|     TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases"); | ||||
|     TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs require offsets"); | ||||
|     TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm"); | ||||
|     // MXFP8 expects float8_e8m0fnu scales. | ||||
|     TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu, | ||||
|         "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors."); | ||||
| #ifdef USE_ROCM | ||||
|     TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE, | ||||
|         "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE"); | ||||
| #else | ||||
|     TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4, | ||||
|         "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4"); | ||||
| #endif | ||||
|  | ||||
| #if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM) | ||||
|     fbgemm_gpu::mx8mx8bf16_grouped_mm( | ||||
| @ -2615,9 +2669,6 @@ _f8_f8_bf16_rowwise_grouped_mm( | ||||
|       const std::optional<Tensor>& bias, | ||||
|       bool use_fast_accum, | ||||
|       Tensor& out) { | ||||
|   // FP8 per-tensor and per-row scaling expect fp32 scales. | ||||
|   TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, | ||||
|       "For grouped FP8 rowwise, both scales must be float32 tensors"); | ||||
| #ifndef USE_ROCM | ||||
|   return _f8_f8_bf16_rowwise_grouped_mm_cuda( | ||||
|       mat_a, | ||||
| @ -2717,15 +2768,11 @@ _scaled_grouped_mm_cuda( | ||||
| #endif | ||||
|  | ||||
|   if (is_mx8mx8bf16) { | ||||
|     // Note: Passing implied SwizzleType here, correctness of scale previously checked | ||||
|     //       in `check_scale` call | ||||
|     return _mx8_mx8_bf16_grouped_mm_fbgemm( | ||||
|         mat_a, | ||||
|         mat_b, | ||||
|         scale_a, | ||||
|         SwizzleType::SWIZZLE_32_4_4, | ||||
|         scale_b, | ||||
|         SwizzleType::SWIZZLE_32_4_4, | ||||
|         offs.value(), | ||||
|         out); | ||||
|   } | ||||
| @ -2742,140 +2789,6 @@ _scaled_grouped_mm_cuda( | ||||
|       out); | ||||
| } | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{ | ||||
|   { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE}, | ||||
|   { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}}; | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| Tensor | ||||
| _scaled_grouped_mm_cuda_v2( | ||||
|           const Tensor& mat_a, const Tensor& mat_b, | ||||
|           ArrayRef<Tensor> scale_a, | ||||
|           IntArrayRef scale_recipe_a, | ||||
|           IntArrayRef swizzle_a, | ||||
|           ArrayRef<Tensor> scale_b, | ||||
|           IntArrayRef scale_recipe_b, | ||||
|           IntArrayRef swizzle_b, | ||||
|           const std::optional<Tensor>& offs, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const std::optional<c10::ScalarType> out_dtype, | ||||
|           IntArrayRef contraction_dim, | ||||
|           bool use_fast_accum) { | ||||
|   bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true); | ||||
|   TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+"); | ||||
|  | ||||
|   TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); | ||||
|   TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); | ||||
|   TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); | ||||
|   TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); | ||||
|   const bool a_is_2d = mat_a.dim() == 2; | ||||
|   const bool b_is_2d = mat_b.dim() == 2; | ||||
|  | ||||
|   // NOTE(slayton): For sub-1B formats, do we want a contraction_dim argument? | ||||
|   if (!a_is_2d || !b_is_2d) { | ||||
|     if (contraction_dim.size() > 0) { | ||||
|       const int dim_a = contraction_dim[0], dim_b = contraction_dim[1]; | ||||
|       TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b), | ||||
|           "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ", | ||||
|           mat_b.size(dim_b)); | ||||
|       // Note: only (-1, -2) is currently supported | ||||
|       TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only"); | ||||
|     } else { | ||||
|       TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); | ||||
|     } | ||||
|   } | ||||
|   TORCH_CHECK_VALUE( | ||||
|     mat_a.size(-1) % 16 == 0, | ||||
|     "Expected trailing dimension of mat_a to be divisible by 16 ", | ||||
|     "but got mat1 shape: (", | ||||
|     mat_a.sizes(), | ||||
|     ")."); | ||||
|   TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0, | ||||
|     "Expected mat_b shape to be divisible by 16 ", | ||||
|     "but got mat_b shape: (", | ||||
|     mat_b.sizes(), | ||||
|     ")."); | ||||
|  | ||||
|   TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet"); | ||||
|   TORCH_CHECK_VALUE(offs.has_value() ==  (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix"); | ||||
|  | ||||
|   // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present. | ||||
|   //       for rowwise, no offsets implies 3d-3d and is handled by lower-level | ||||
|   //       routines | ||||
|   if (offs.has_value()) { | ||||
|     TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D"); | ||||
|     TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32"); | ||||
|   } | ||||
|  | ||||
|   const auto out_dtype_ = out_dtype.value_or(kBFloat16); | ||||
|   TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); | ||||
|  | ||||
|   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); | ||||
|  | ||||
|   // Conversion of implicitly-defined enums to explicit | ||||
|   auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a); | ||||
|   auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a); | ||||
|   auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b); | ||||
|   auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b); | ||||
|  | ||||
|   // At this point we can start working out what we want to be doing; | ||||
|   // try to do as few steps as possible. | ||||
|   // NOTE: support is deliberately sparse; all allowed combinations are enumerated explicitly | ||||
|   // via a list of defined (name, acceptance_fn, concrete_impl) tuples. | ||||
|   ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE; | ||||
|   for (const auto& fn_entry : scale_grouped_kernel_dispatch) { | ||||
|     const auto [name, accept_fn, scaled_gemm_impl] = fn_entry; | ||||
|     bool ok = accept_fn(mat_a.scalar_type(), | ||||
|                         scale_recipe_a_enum, | ||||
|                         scale_a, | ||||
|                         mat_b.scalar_type(), | ||||
|                         scale_recipe_b_enum, | ||||
|                         scale_b); | ||||
|     if (ok) { | ||||
|       gemm_impl = scaled_gemm_impl; | ||||
|       break; | ||||
|     } | ||||
|   } | ||||
|   TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE, | ||||
|       "No gemm implementation was found"); | ||||
|  | ||||
|   switch (gemm_impl) { | ||||
|     case ScaledGemmImplementation::ROWWISE_ROWWISE: { | ||||
|       const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1; | ||||
|       _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier); | ||||
|       _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier); | ||||
|       return _f8_f8_bf16_rowwise_grouped_mm( | ||||
|           mat_a, | ||||
|           mat_b, | ||||
|           scale_a[0], | ||||
|           scale_b[0], | ||||
|           offs, | ||||
|           bias, | ||||
|           use_fast_accum, | ||||
|           out); | ||||
|     } | ||||
|     case ScaledGemmImplementation::MXFP8_MXFP8: { | ||||
|       _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */); | ||||
|       _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */); | ||||
|       return _mx8_mx8_bf16_grouped_mm_fbgemm( | ||||
|           mat_a, | ||||
|           mat_b, | ||||
|           scale_a[0], | ||||
|           swizzle_a_enum[0], | ||||
|           scale_b[0], | ||||
|           swizzle_b_enum[0], | ||||
|           offs.value(), | ||||
|           out); | ||||
|     } | ||||
|     default: | ||||
|       TORCH_CHECK_NOT_IMPLEMENTED(false, | ||||
|           "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here"); | ||||
|   } | ||||
| } | ||||
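The function above resolves an implementation by walking a (name, acceptance_fn, concrete_impl) table until a predicate accepts. A self-contained sketch of that pattern; every name below is illustrative, not the real ATen machinery:

#include <array>
#include <tuple>

enum class Impl { NONE, ROWWISE, MXFP8 };
using AcceptFn = bool (*)(int recipe);  // stand-in for the real acceptance signature

bool check_rowwise(int recipe) { return recipe == 0; }
bool check_mxfp8(int recipe)   { return recipe == 1; }

const std::array<std::tuple<const char*, AcceptFn, Impl>, 2> dispatch = {{
    {"rowwise_rowwise", check_rowwise, Impl::ROWWISE},
    {"mxfp8_mxfp8",     check_mxfp8,   Impl::MXFP8},
}};

Impl pick(int recipe) {
  for (const auto& [name, accept, impl] : dispatch) {
    if (accept(recipe)) return impl;  // first matching entry wins
  }
  return Impl::NONE;  // the caller turns this into a TORCH_CHECK failure
}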
|  | ||||
| Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b, | ||||
| const std::optional<at::Tensor>& offs, | ||||
| const std::optional<at::Tensor>& bias, | ||||
|  | ||||
| @ -856,13 +856,9 @@ struct type_specialized_kernel_launcher { | ||||
|       out_calc_t output_offset_calculator, | ||||
|       loader_t loader, | ||||
|       storer_t storer) { | ||||
|     constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0]; | ||||
|     constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1]; | ||||
|     constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2]; | ||||
|     if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) { | ||||
|       using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>; | ||||
|       using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>; | ||||
|       using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>; | ||||
|     if (ret_t == rt_binary_specializations[arg_index][0] && | ||||
|         arg0_t == rt_binary_specializations[arg_index][1] && | ||||
|         arg1_t == rt_binary_specializations[arg_index][2]) | ||||
|       launch_vectorized_templated_kernel< | ||||
|           func_t, | ||||
|           array_t, | ||||
| @ -870,9 +866,12 @@ struct type_specialized_kernel_launcher { | ||||
|           out_calc_t, | ||||
|           loader_t, | ||||
|           storer_t, | ||||
|           cret_t, | ||||
|           carg0_t, | ||||
|           carg1_t>( | ||||
|           decltype(c10::impl::ScalarTypeToCPPType< | ||||
|                    rt_binary_specializations[arg_index][0]>::t), | ||||
|           decltype(c10::impl::ScalarTypeToCPPType< | ||||
|                    rt_binary_specializations[arg_index][1]>::t), | ||||
|           decltype(c10::impl::ScalarTypeToCPPType< | ||||
|                    rt_binary_specializations[arg_index][2]>::t)>( | ||||
|           numel, | ||||
|           f, | ||||
|           data, | ||||
| @ -880,7 +879,6 @@ struct type_specialized_kernel_launcher { | ||||
|           output_offset_calculator, | ||||
|           loader, | ||||
|           storer); | ||||
|     } | ||||
|   } | ||||
| }; | ||||
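This launcher turns a runtime dtype triple into compile-time template parameters by comparing against a constexpr table. A toy standalone sketch of the same trick (none of these names are the real c10 types):

#include <cstdint>

enum class ScalarType { Float, Int };

template <ScalarType S> struct ToCpp;                      // ScalarType -> C++ type
template <> struct ToCpp<ScalarType::Float> { using type = float;   };
template <> struct ToCpp<ScalarType::Int>   { using type = int32_t; };

constexpr ScalarType specializations[] = {ScalarType::Float, ScalarType::Int};

template <typename T> void launch_kernel() { /* templated kernel launch */ }

// One instantiation per table entry; only the entry matching the runtime
// dtype actually launches, everything else compiles to a dead branch.
template <int index>
void maybe_launch(ScalarType runtime_type) {
  constexpr ScalarType s = specializations[index];
  if (runtime_type == s) {
    using T = typename ToCpp<s>::type;
    launch_kernel<T>();
  }
}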
|  | ||||
|  | ||||
| @ -655,14 +655,8 @@ struct ReduceOp { | ||||
|     } | ||||
|  | ||||
|     __syncthreads(); | ||||
|     // Intra-warp reduction. On CUDA the shuffle offset decreases each step, | ||||
|     // which gives better numerics and matches Triton, etc. | ||||
|     // TODO: do the same for AMD. | ||||
|     #ifdef USE_ROCM | ||||
|  | ||||
|     for (int offset = 1; offset < dim_x; offset <<= 1) { | ||||
|     #else | ||||
|     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { | ||||
|     #endif | ||||
|       #pragma unroll | ||||
|       for (int i = 0; i < output_vec_size; i++) { | ||||
|         arg_t other = ops.warp_shfl_down(value[i], offset); | ||||
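For reference, a standalone CUDA sketch of the decreasing-offset warp reduction the non-ROCm branch performs; the function name and the full-warp mask are assumptions:

// Tree reduction within one warp: offsets 16, 8, 4, 2, 1. Pairing distant
// lanes first tends to give better-conditioned sums and matches Triton.
__device__ float warp_sum_decreasing(float v) {
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    v += __shfl_down_sync(0xffffffffu, v, offset);
  }
  return v;  // lane 0 ends up holding the full warp sum
}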
|  | ||||
| @ -77,8 +77,8 @@ struct nansum_functor_complex { | ||||
| #if AT_USE_JITERATOR() | ||||
|   void operator()(TensorIterator& iter) { | ||||
|     std::string func = jiterator_stringify( | ||||
|         arg_t combine(arg_t a, arg_t b) { | ||||
|           return a + (std::isnan(b) ? arg_t{0.} : b); | ||||
|         arg_t combine(arg_t a, scalar_t b) { | ||||
|           return a + (std::isnan(b) ? arg_t{0.} : arg_t{b}); | ||||
|         } | ||||
|     ); | ||||
|     jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>( | ||||
|  | ||||
| @ -464,7 +464,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|     } | ||||
| #endif | ||||
|     int32_t trailingSize; | ||||
|     int nDimsLocal = nDims; | ||||
|     TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam; | ||||
|     if (isInOutAligned) { | ||||
|       // in this case we can and should flatten the tensors after the cat dim | ||||
| @ -478,7 +477,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|       // and divide all strides except last by elems_per_vec (last stride is 1 always) | ||||
|       // for input, we will fix up the sizes and strides in the kernel directly | ||||
|       kernelOutputParam = outputParam; | ||||
|       nDimsLocal = dimension + 1; | ||||
|       nDims = dimension + 1; | ||||
|       constexpr auto elems_per_vec = alignment / sizeof(scalar_t); | ||||
|       auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1]; | ||||
|       kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec; | ||||
| @ -495,7 +494,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|       case 0: | ||||
|         break; | ||||
|       case 1: | ||||
|         cat_dim = nDimsLocal - cat_dim; | ||||
|         cat_dim = nDims - cat_dim; | ||||
|         break; | ||||
|       default: | ||||
|         cat_dim--; | ||||
| @ -526,7 +525,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|               data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\ | ||||
|     }\ | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|     switch (nDimsLocal) { | ||||
|     switch (nDims) { | ||||
|       case 1: | ||||
|         HANDLE_CASE(1); | ||||
|         break; | ||||
|  | ||||
| @ -21,15 +21,9 @@ namespace { | ||||
| struct offset_t { | ||||
|   int stride; | ||||
|   int begin; | ||||
|   __device__ int operator[](int i) const { | ||||
|   __device__ int operator[](int i) { | ||||
|     return stride * (begin + i); | ||||
|   } | ||||
| #if CCCL_VERSION >= 3001000 | ||||
|   __device__ offset_t& operator+=(int i) { | ||||
|     begin += i; | ||||
|     return *this; | ||||
|   } | ||||
| #endif | ||||
| }; | ||||
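A device-side usage sketch of this functor with illustrative values; the guarded operator+= presumably exists because newer CCCL/CUB releases require the "fancy iterators" they are handed to be advanceable:

__device__ void offset_t_demo() {
  offset_t it{/*stride=*/3, /*begin=*/10};
  int a = it[0];   // 3 * (10 + 0) == 30
  int b = it[2];   // 3 * (10 + 2) == 36
#if CCCL_VERSION >= 3001000
  it += 2;         // begin becomes 12, so it[0] now reads 36
#endif
  (void)a; (void)b;
}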
| // Segmented sort by full sort algorithm: | ||||
| // Say we are sorting a (2, 3) tensor. We have in flattened form: | ||||
|  | ||||
| @ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame( | ||||
|   } | ||||
| } | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| // Helper function to compute output pixel range that can contribute to input pixel | ||||
| template <typename accscalar_t> | ||||
| __device__ __forceinline__ void compute_output_range( | ||||
|     int input_pos, | ||||
|     accscalar_t scale, | ||||
|     int output_size, | ||||
|     bool align_corners, | ||||
|     int& min_output, | ||||
|     int& max_output) { | ||||
|   accscalar_t lo, hi; | ||||
|   if (align_corners) { | ||||
|       lo = static_cast<accscalar_t>(input_pos - 1) / scale; | ||||
|       hi = static_cast<accscalar_t>(input_pos + 1) / scale; | ||||
|   } else { | ||||
|       lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5); | ||||
|       hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5); | ||||
|   } | ||||
|   min_output = max(0, static_cast<int>(ceil(lo))); | ||||
|   max_output = min(output_size - 1, static_cast<int>(floor(hi))); | ||||
| } | ||||
| #endif | ||||
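The (lo, hi) bounds above invert the forward sampling map; reconstructing the align_corners=false case (this derivation is mine, not from the source): output index $o$ reads input coordinate $x(o) = (o + 0.5)\,s - 0.5$, where $s$ is the input/output scale. Bilinear interpolation touches input pixels $\lfloor x \rfloor$ and $\lfloor x \rfloor + 1$, so input pixel $p$ receives gradient from $o$ exactly when $x(o) \in (p - 1,\, p + 1)$, i.e.

$$\frac{p - 0.5}{s} - 0.5 \;<\; o \;<\; \frac{p + 1.5}{s} - 0.5,$$

which is the (lo, hi) pair computed above; clamping $\lceil lo \rceil$ and $\lfloor hi \rfloor$ to $[0, \mathrm{output\_size} - 1]$ yields min_output and max_output.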
|  | ||||
| // Backward (adjoint) operation 1 <- 2 (accumulates) | ||||
| template <typename scalar_t, typename accscalar_t> | ||||
| C10_LAUNCH_BOUNDS_1(1024) | ||||
| @ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame( | ||||
|     const bool align_corners, | ||||
|     scalar_t* __restrict__ idata, | ||||
|     const scalar_t* __restrict__ odata) { | ||||
|   // Integer multiplication is commutative, so the order of these factors does not matter. | ||||
|   const size_t i_numel = nc * width1 * height1; | ||||
| #ifdef USE_ROCM | ||||
|   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel; | ||||
|        index += blockDim.x * gridDim.x) { | ||||
|     // Decode input pixel coordinates | ||||
|     size_t index_temp = index; | ||||
|     const int w1 = index_temp % width1; | ||||
|     index_temp /= width1; | ||||
|     const int h1 = index_temp % height1; | ||||
|     const size_t nc_idx = index_temp / height1; | ||||
|  | ||||
|     accscalar_t grad_sum = 0; | ||||
|  | ||||
|     // Find range of output pixels that could interpolate from this input pixel | ||||
|     int h2_min, h2_max, w2_min, w2_max; | ||||
|     compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max); | ||||
|     compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max); | ||||
|  | ||||
|     // Iterate over potential output pixels | ||||
|     for (int h2 = h2_min; h2 <= h2_max; h2++) { | ||||
|       for (int w2 = w2_min; w2 <= w2_max; w2++) { | ||||
|         // Compute source coordinates for this output pixel | ||||
|         const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>( | ||||
|             rheight, h2, align_corners, /*cubic=*/false); | ||||
|         const int h1_base = (int)h1r; | ||||
|         const int h1p = (h1_base < height1 - 1) ? 1 : 0; | ||||
|         const accscalar_t h1lambda = h1r - h1_base; | ||||
|         const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda; | ||||
|  | ||||
|         const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>( | ||||
|             rwidth, w2, align_corners, /*cubic=*/false); | ||||
|         const int w1_base = (int)w1r; | ||||
|         const int w1p = (w1_base < width1 - 1) ? 1 : 0; | ||||
|         const accscalar_t w1lambda = w1r - w1_base; | ||||
|         const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda; | ||||
|  | ||||
|         // Check if our input pixel participates in this interpolation and accumulate all weights | ||||
|         // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse | ||||
|         // to the same pixel, so we need to accumulate weights from all matching positions | ||||
|         accscalar_t weight = 0; | ||||
|  | ||||
|         // Check all four interpolation positions and accumulate weights | ||||
|         if (h1 == h1_base && w1 == w1_base) { | ||||
|           weight += h0lambda * w0lambda;  // top-left | ||||
|         } | ||||
|         if (h1 == h1_base && w1 == w1_base + w1p) { | ||||
|           weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0) | ||||
|         } | ||||
|         if (h1 == h1_base + h1p && w1 == w1_base) { | ||||
|           weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0) | ||||
|         } | ||||
|         if (h1 == h1_base + h1p && w1 == w1_base + w1p) { | ||||
|           weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions) | ||||
|         } | ||||
|  | ||||
|         if (weight > 0) { | ||||
|           const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2; | ||||
|           grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     // Write accumulated gradient (no atomics needed) | ||||
|     idata[index] = static_cast<scalar_t>(grad_sum); | ||||
|   } | ||||
| #else | ||||
|   const size_t o_numel = nc * width2 * height2; | ||||
|   const size_t i_numel = nc * width1 * height1; | ||||
|   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; | ||||
|        index += blockDim.x * gridDim.x) { | ||||
|     size_t index_temp = index; | ||||
| @ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame( | ||||
|         static_cast<scalar_t>(h1lambda * w1lambda * d2val), | ||||
|         true); | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename accscalar_t> | ||||
| @ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template( | ||||
|   // threads are not covering the whole input tensor. | ||||
|   grad_input.zero_(); | ||||
|  | ||||
|   const size_t num_kernels = nbatch * channels * output_height * output_width; | ||||
|   const int num_threads = std::min( | ||||
|       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
| @ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( | ||||
|     return; | ||||
|   } | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|   constexpr bool use_input = true; | ||||
| #else | ||||
|   constexpr bool use_input = false; | ||||
| #endif | ||||
|  | ||||
|   AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|       at::ScalarType::Half, at::ScalarType::BFloat16, | ||||
|       grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] { | ||||
| @ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( | ||||
|       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( | ||||
|           input_width, output_width, align_corners, scales_w); | ||||
|  | ||||
|       const size_t num_kernels = nbatch * channels * output_height * output_width; | ||||
|  | ||||
|       upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t> | ||||
|           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>( | ||||
|               input_height, | ||||
| @ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( | ||||
|       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( | ||||
|           input_width, output_width, align_corners, scales_w); | ||||
|  | ||||
|       const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width); | ||||
|  | ||||
|       upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t> | ||||
|           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), | ||||
|              num_threads, | ||||
|  | ||||
| @ -466,11 +466,7 @@ struct ReduceJitOp { | ||||
|  | ||||
|     __syncthreads(); | ||||
|  | ||||
|     #ifdef USE_ROCM | ||||
|     for (int offset = 1; offset < dim_x; offset <<= 1) { | ||||
|     #else | ||||
|     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { | ||||
|     #endif | ||||
|       #pragma unroll | ||||
|       for (int i = 0; i < output_vec_size; i++) { | ||||
|         arg_t other = reducer::warp_shfl_down(value[i], offset); | ||||
|  | ||||
| @ -512,7 +512,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps) | ||||
| } | ||||
|  | ||||
| static MPSGraphTensor* normcdf(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { | ||||
|   // (1.0f + erf(x*SQRT1_2)) * 0.5f; | ||||
|   // (1.0f + erf(x*SQRT1_2)) * 0.5f * x; | ||||
|   auto dataType = [inputTensor dataType]; | ||||
|   const float SQRT1_2 = 0.707106781186547524400844362104849039f; | ||||
|   MPSGraphTensor* sqrt1_2 = [mpsGraph constantWithScalar:SQRT1_2 shape:@[ @1 ] dataType:dataType]; | ||||
|  | ||||
| @ -54,10 +54,6 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) { | ||||
|   using namespace mps; | ||||
|   using CachedGraph = MPSBinaryCachedGraph; | ||||
|  | ||||
|   if (self.numel() == 0 && other.numel() == 0) { | ||||
|     return zeros({}, self.options()); | ||||
|   } | ||||
|  | ||||
|   dot_check(self, other); | ||||
|  | ||||
|   auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt); | ||||
|  | ||||
| @ -7183,12 +7183,6 @@ | ||||
|     CUDA: _scaled_grouped_mm_cuda | ||||
|   tags: needs_exact_strides | ||||
|  | ||||
| - func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor | ||||
|   variants: function | ||||
|   dispatch: | ||||
|     CUDA: _scaled_grouped_mm_cuda_v2 | ||||
|   tags: needs_exact_strides | ||||
|  | ||||
| - func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor | ||||
|   variants: function | ||||
|   dispatch: | ||||
|  | ||||
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,fail_accuracy,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,fail_accuracy,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,0 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,0 | ||||
|  | ||||
| 
 | 
| @ -10,18 +10,10 @@ beit_base_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_base_distilled_patch16_224,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| deit_tiny_patch16_224.fb_in1k,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| dm_nfnet_f0,pass,6 | ||||
|  | ||||
|  | ||||
| @ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6 | ||||
|  | ||||
|  | ||||
| visformer_small,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch14_dinov2.lvd142m,fail_accuracy,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| vit_base_patch16_siglip_256,pass,7 | ||||
|  | ||||
| 
 | 
| @ -1060,8 +1060,6 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): | ||||
|             frozen_model_iter_fn = export_nativert(model, example_inputs) | ||||
|         elif args.torchscript_jit_trace: | ||||
|             frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs) | ||||
|         elif args.aot_precompile: | ||||
|             frozen_model_iter_fn = aot_precompile(model, example_inputs) | ||||
|         else: | ||||
|             if kwargs["hf_llm"]: | ||||
|                 # If it's an llm, we want to optimize model.forward, and use | ||||
| @ -1497,37 +1495,6 @@ def export(model, example_inputs): | ||||
|     return opt_export | ||||
|  | ||||
|  | ||||
| def aot_precompile(model, example_inputs): | ||||
|     example_args, example_kwargs = _normalize_bench_inputs(example_inputs) | ||||
|  | ||||
|     with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: | ||||
|         save_path = f.name | ||||
|  | ||||
|     with fresh_cache(), torch._dynamo.config.patch("enable_aot_compile", True): | ||||
|         compiled_fn = torch.compile( | ||||
|             model, | ||||
|             fullgraph=True, | ||||
|             options={"guard_filter_fn": lambda guards: [False for _ in guards]}, | ||||
|         ).forward.aot_compile((example_args, example_kwargs)) | ||||
|  | ||||
|         compiled_fn.save_compiled_function(save_path) | ||||
|  | ||||
|         torch._dynamo.reset() | ||||
|         with open(save_path, "rb") as f: | ||||
|             load_start_time = time.perf_counter() | ||||
|             loaded_fn = torch.compiler.load_compiled_function(f) | ||||
|             load_end_time = time.perf_counter() | ||||
|             print( | ||||
|                 f"AOT Precompile loading time: {load_end_time - load_start_time} seconds" | ||||
|             ) | ||||
|  | ||||
|             def opt_aot_precompile(_, example_inputs, collect_outputs=False): | ||||
|                 example_args, example_kwargs = _normalize_bench_inputs(example_inputs) | ||||
|                 return loaded_fn(model, *example_args, **example_kwargs) | ||||
|  | ||||
|             return opt_aot_precompile | ||||
|  | ||||
|  | ||||
| def export_nativert(model, example_inputs): | ||||
|     optimized = NativeRTCache.load(model, example_inputs) | ||||
|  | ||||
| @ -2307,7 +2274,6 @@ class BenchmarkRunner: | ||||
|                     or self.args.export_aot_inductor | ||||
|                     or self.args.export_nativert | ||||
|                     or self.args.torchscript_jit_trace | ||||
|                     or self.args.aot_precompile | ||||
|                 ): | ||||
|                     # apply export on module directly | ||||
|                     # no need for n iterations | ||||
| @ -2763,7 +2729,6 @@ class BenchmarkRunner: | ||||
|                 self.args.export_aot_inductor | ||||
|                 or self.args.export_nativert | ||||
|                 or self.args.torchscript_jit_trace | ||||
|                 or self.args.aot_precompile | ||||
|             ): | ||||
|                 optimized_model_iter_fn = optimize_ctx | ||||
|             else: | ||||
| @ -3540,11 +3505,6 @@ def parse_args(args=None): | ||||
|         action="store_true", | ||||
|         help="Measure pass rate with Export+AOTInductor", | ||||
|     ) | ||||
|     group.add_argument( | ||||
|         "--aot-precompile", | ||||
|         action="store_true", | ||||
|         help="Measure pass rate with AOT Precompile", | ||||
|     ) | ||||
|     group.add_argument( | ||||
|         "--export-nativert", | ||||
|         action="store_true", | ||||
| @ -3975,10 +3935,6 @@ def run(runner, args, original_dir=None): | ||||
|         optimize_ctx = export | ||||
|         experiment = speedup_experiment | ||||
|         output_filename = "export.csv" | ||||
|     elif args.aot_precompile: | ||||
|         optimize_ctx = aot_precompile | ||||
|         experiment = speedup_experiment | ||||
|         output_filename = "aot_precompile.csv" | ||||
|     elif args.export_nativert: | ||||
|         optimize_ctx = export_nativert | ||||
|         experiment = speedup_experiment | ||||
|  | ||||
| @ -271,6 +271,8 @@ class TimmRunner(BenchmarkRunner): | ||||
|             memory_format=torch.channels_last if channels_last else None, | ||||
|         ) | ||||
|  | ||||
|         self.num_classes = model.num_classes | ||||
|  | ||||
|         data_config = resolve_data_config( | ||||
|             vars(self._args) if timmversion >= "0.8.0" else self._args, | ||||
|             model=model, | ||||
| @ -300,6 +302,7 @@ class TimmRunner(BenchmarkRunner): | ||||
|         example_inputs = [ | ||||
|             example_inputs, | ||||
|         ] | ||||
|         self.target = self._gen_target(batch_size, device) | ||||
|  | ||||
|         self.loss = torch.nn.CrossEntropyLoss().to(device) | ||||
|  | ||||
| @ -367,6 +370,11 @@ class TimmRunner(BenchmarkRunner): | ||||
|                 tolerance = 1e-2 | ||||
|         return tolerance, cosine | ||||
|  | ||||
|     def _gen_target(self, batch_size, device): | ||||
|         return torch.empty((batch_size,), device=device, dtype=torch.long).random_( | ||||
|             self.num_classes | ||||
|         ) | ||||
|  | ||||
|     def compute_loss(self, pred): | ||||
|         # High loss values make gradient checking harder, as small changes in | ||||
|         # accumulation order upsets accuracy checks. | ||||
|  | ||||
| @ -1,8 +1,6 @@ | ||||
| adv_inception_v3 128 | ||||
| beit_base_patch16_224 128 | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k 128 | ||||
| deit_base_distilled_patch16_224 128 | ||||
| deit_tiny_patch16_224.fb_in1k 128 | ||||
| dm_nfnet_f0 128 | ||||
| ghostnet_100 512 | ||||
| inception_v3 128 | ||||
| @ -14,5 +12,3 @@ repvgg_a2 128 | ||||
| swin_base_patch4_window7_224 128 | ||||
| tf_efficientnet_b0 128 | ||||
| visformer_small 128 | ||||
| vit_base_patch14_dinov2.lvd142m 128 | ||||
| vit_base_patch16_siglip_256 128 | ||||
| @ -1,8 +1,6 @@ | ||||
| adv_inception_v3,128 | ||||
| beit_base_patch16_224,64 | ||||
| convnextv2_nano.fcmae_ft_in22k_in1k,128 | ||||
| deit_base_distilled_patch16_224,64 | ||||
| deit_tiny_patch16_224.fb_in1k,128 | ||||
| dm_nfnet_f0,128 | ||||
| ghostnet_100,128 | ||||
| inception_v3,128 | ||||
| @ -14,5 +12,3 @@ repvgg_a2,128 | ||||
| swin_base_patch4_window7_224,64 | ||||
| tf_efficientnet_b0,128 | ||||
| visformer_small,128 | ||||
| vit_base_patch14_dinov2.lvd142m,128 | ||||
| ViT-B-16-SigLIP-i18n-256,128 | ||||
| @ -28,8 +28,101 @@ | ||||
|  | ||||
| namespace c10 { | ||||
|  | ||||
| // See [dtype Macros note] in torch/headeronly/core/ScalarType.h | ||||
| // regarding macros. | ||||
| // [dtype Macros note] For the macros below: | ||||
| // | ||||
| // For users: If you want to macro some code for all non-QInt scalar types | ||||
| // (i.e. types with complete information), you probably want one of the | ||||
| // AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are | ||||
| // designed to behave similarly to the Dispatch macros with the same name. | ||||
| // | ||||
| // For adding a new dtype: In the beginning, we had an idea that there was a | ||||
| // list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to | ||||
| // iterate over them.  But over the years we added weird types which couldn't | ||||
| // be handled uniformly everywhere, and so we ended up with a mish-mosh of | ||||
| // helper macros, with most use sites making their own call about which | ||||
| // dtypes they can or can't support.  So if you want to add a new dtype, | ||||
| // the preferred resolution is to find a dtype similar to what you want, | ||||
| // grep for it and edit all the sites you find this way.  If you need to add | ||||
| // a completely new kind of dtype, you're going to have to laboriously audit | ||||
| // all of the sites everywhere to figure out how it should work.  Consulting | ||||
| // some old PRs where we added new dtypes (check history of this file) can | ||||
| // help give you an idea where to start. | ||||
|  | ||||
| // If you want to support ComplexHalf for real, add ComplexHalf | ||||
| // into this macro (and change the name).  But beware: convert() | ||||
| // doesn't work for all the conversions you need... | ||||
| // | ||||
| // TODO: To add unsigned int types here, we must define an accumulate type. | ||||
| // But uint8 currently accumulates into int64, so we would have to make | ||||
| // an inconsistent choice for the larger types.  Difficult. | ||||
| #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \ | ||||
|   _(uint8_t, Byte)                                                      \ | ||||
|   _(int8_t, Char)                                                       \ | ||||
|   _(int16_t, Short)                                                     \ | ||||
|   _(int, Int)                                                           \ | ||||
|   _(int64_t, Long)                                                      \ | ||||
|   _(at::Half, Half)                                                     \ | ||||
|   _(float, Float)                                                       \ | ||||
|   _(double, Double)                                                     \ | ||||
|   _(c10::complex<float>, ComplexFloat)                                  \ | ||||
|   _(c10::complex<double>, ComplexDouble)                                \ | ||||
|   _(bool, Bool)                                                         \ | ||||
|   _(at::BFloat16, BFloat16)                                             \ | ||||
|   _(at::Float8_e5m2, Float8_e5m2)                                       \ | ||||
|   _(at::Float8_e4m3fn, Float8_e4m3fn) | ||||
|  | ||||
| // This macro controls many of our C++ APIs, including constructors | ||||
| // for Scalar as well as the data() and item() accessors on Tensor | ||||
| #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ | ||||
|   _(uint8_t, Byte)                             \ | ||||
|   _(int8_t, Char)                              \ | ||||
|   _(int16_t, Short)                            \ | ||||
|   _(int, Int)                                  \ | ||||
|   _(int64_t, Long)                             \ | ||||
|   _(at::Half, Half)                            \ | ||||
|   _(float, Float)                              \ | ||||
|   _(double, Double)                            \ | ||||
|   _(c10::complex<c10::Half>, ComplexHalf)      \ | ||||
|   _(c10::complex<float>, ComplexFloat)         \ | ||||
|   _(c10::complex<double>, ComplexDouble)       \ | ||||
|   _(bool, Bool)                                \ | ||||
|   _(at::BFloat16, BFloat16)                    \ | ||||
|   _(at::Float8_e5m2, Float8_e5m2)              \ | ||||
|   _(at::Float8_e4m3fn, Float8_e4m3fn)          \ | ||||
|   _(at::Float8_e5m2fnuz, Float8_e5m2fnuz)      \ | ||||
|   _(at::Float8_e4m3fnuz, Float8_e4m3fnuz)      \ | ||||
|   _(at::Float8_e8m0fnu, Float8_e8m0fnu) | ||||
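A usage sketch (not from this header) of how such a FORALL X-macro is typically consumed; each _(cpp_type, Name) entry stamps out one switch case:

const char* scalar_type_name(c10::ScalarType t) {
  switch (t) {
#define DEFINE_CASE(cpp_type, name) \
  case c10::ScalarType::name:       \
    return #name;
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      return "Unknown";
  }
}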
|  | ||||
| namespace impl { | ||||
|  | ||||
| // These are used to map ScalarTypes to C++ types. | ||||
|  | ||||
| template <c10::ScalarType N> | ||||
| struct ScalarTypeToCPPType; | ||||
|  | ||||
| #define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type)                \ | ||||
|   template <>                                                                \ | ||||
|   struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> {                 \ | ||||
|     using type = cpp_type;                                                   \ | ||||
|                                                                              \ | ||||
|     /* This is a workaround for the CUDA bug which prevents */               \ | ||||
|     /* ::detail::ScalarTypeToCType<T>::type being used directly due to */    \ | ||||
|     /* ambiguous reference which can't be resolved. For some reason it */    \ | ||||
|     /* can't pick between at::detail and at::cuda::detail. */                \ | ||||
|     /* For repro example, please see: */                                     \ | ||||
|     /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */    \ | ||||
|     /* TODO: remove once the bug is fixed. */                                \ | ||||
|     static type t;                                                           \ | ||||
|   }; | ||||
|  | ||||
| AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) | ||||
|  | ||||
| #undef SPECIALIZE_ScalarTypeToCPPType | ||||
|  | ||||
| template <c10::ScalarType N> | ||||
| using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type; | ||||
|  | ||||
| } // namespace impl | ||||
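Illustrative checks of what the mapping yields, assuming <type_traits> is included; both follow directly from the specializations stamped out above:

static_assert(std::is_same_v<
    c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::Float>, float>);
static_assert(std::is_same_v<
    c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::ComplexFloat>,
    c10::complex<float>>);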
|  | ||||
| template <typename T> | ||||
| struct CppTypeToScalarType; | ||||
| @ -45,6 +138,130 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) | ||||
|  | ||||
| #undef SPECIALIZE_CppTypeToScalarType | ||||
|  | ||||
| // NB: despite their generic-sounding names, the macros that don't take _AND | ||||
| // are mostly only used by tensorexpr | ||||
| #define AT_FORALL_INT_TYPES(_) \ | ||||
|   _(uint8_t, Byte)             \ | ||||
|   _(int8_t, Char)              \ | ||||
|   _(int16_t, Short)            \ | ||||
|   _(int, Int)                  \ | ||||
|   _(int64_t, Long) | ||||
|  | ||||
| #define AT_FORALL_SCALAR_TYPES(_) \ | ||||
|   _(uint8_t, Byte)                \ | ||||
|   _(int8_t, Char)                 \ | ||||
|   _(int16_t, Short)               \ | ||||
|   _(int, Int)                     \ | ||||
|   _(int64_t, Long)                \ | ||||
|   _(float, Float)                 \ | ||||
|   _(double, Double) | ||||
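|  | ||||
| // Consumption sketch (illustrative, not from the original header): each of | ||||
| // these X-macros applies _(ctype, Name) once per dtype, so a caller can | ||||
| // stamp out one switch case (or one template instantiation) per type. The | ||||
| // helper name below is made up; the real elementSize() further down follows | ||||
| // exactly this pattern. | ||||
| // | ||||
| //   #define CASE_SIZEOF(ctype, name) \ | ||||
| //     case ScalarType::name:         \ | ||||
| //       return sizeof(ctype); | ||||
| //   inline size_t sizeOfScalarType(ScalarType t) { | ||||
| //     switch (t) { | ||||
| //       AT_FORALL_SCALAR_TYPES(CASE_SIZEOF) | ||||
| //       default: | ||||
| //         return 0; | ||||
| //     } | ||||
| //   } | ||||
| //   #undef CASE_SIZEOF | ||||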
|  | ||||
| // These macros often control how many template instantiations we create | ||||
| // for kernels.  It is typically inappropriate to add new dtypes here; | ||||
| // instead, new types should be added at use sites on a case-by-case basis. | ||||
| // We generally do not accept new dtypes, due to binary size concerns. | ||||
|  | ||||
| #define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ | ||||
|   _(uint8_t, Byte)                                \ | ||||
|   _(int8_t, Char)                                 \ | ||||
|   _(int16_t, Short)                               \ | ||||
|   _(int, Int)                                     \ | ||||
|   _(int64_t, Long)                                \ | ||||
|   _(float, Float)                                 \ | ||||
|   _(double, Double)                               \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<    \ | ||||
|              ::c10::ScalarType::SCALARTYPE>::t),  \ | ||||
|     SCALARTYPE) | ||||
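|  | ||||
| // Note (illustrative): the decltype(...::t) indirection recovers the C++ | ||||
| // type for SCALARTYPE via the static member declared in | ||||
| // SPECIALIZE_ScalarTypeToCPPType above, so call sites can extend the seven | ||||
| // base types by name only, e.g. (reusing the CASE_SIZEOF sketch): | ||||
| // | ||||
| //   AT_FORALL_SCALAR_TYPES_AND(Half, CASE_SIZEOF)  // 7 base types + Half | ||||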
|  | ||||
| #define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ | ||||
|   _(uint8_t, Byte)                                               \ | ||||
|   _(int8_t, Char)                                                \ | ||||
|   _(int16_t, Short)                                              \ | ||||
|   _(int, Int)                                                    \ | ||||
|   _(int64_t, Long)                                               \ | ||||
|   _(float, Float)                                                \ | ||||
|   _(double, Double)                                              \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<                   \ | ||||
|              ::c10::ScalarType::SCALARTYPE1>::t),                \ | ||||
|     SCALARTYPE1)                                                 \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<                   \ | ||||
|              ::c10::ScalarType::SCALARTYPE2>::t),                \ | ||||
|     SCALARTYPE2) | ||||
|  | ||||
| #define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ | ||||
|   _(uint8_t, Byte)                                                            \ | ||||
|   _(int8_t, Char)                                                             \ | ||||
|   _(int16_t, Short)                                                           \ | ||||
|   _(int, Int)                                                                 \ | ||||
|   _(int64_t, Long)                                                            \ | ||||
|   _(float, Float)                                                             \ | ||||
|   _(double, Double)                                                           \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<                                \ | ||||
|              ::c10::ScalarType::SCALARTYPE1>::t),                             \ | ||||
|     SCALARTYPE1)                                                              \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<                                \ | ||||
|              ::c10::ScalarType::SCALARTYPE2>::t),                             \ | ||||
|     SCALARTYPE2)                                                              \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<                                \ | ||||
|              ::c10::ScalarType::SCALARTYPE3>::t),                             \ | ||||
|     SCALARTYPE3) | ||||
|  | ||||
| #define AT_FORALL_SCALAR_TYPES_AND7(              \ | ||||
|     SCALARTYPE1,                                  \ | ||||
|     SCALARTYPE2,                                  \ | ||||
|     SCALARTYPE3,                                  \ | ||||
|     SCALARTYPE4,                                  \ | ||||
|     SCALARTYPE5,                                  \ | ||||
|     SCALARTYPE6,                                  \ | ||||
|     SCALARTYPE7,                                  \ | ||||
|     _)                                            \ | ||||
|   _(uint8_t, Byte)                                \ | ||||
|   _(int8_t, Char)                                 \ | ||||
|   _(int16_t, Short)                               \ | ||||
|   _(int, Int)                                     \ | ||||
|   _(int64_t, Long)                                \ | ||||
|   _(float, Float)                                 \ | ||||
|   _(double, Double)                               \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<    \ | ||||
|              ::c10::ScalarType::SCALARTYPE1>::t), \ | ||||
|     SCALARTYPE1)                                  \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<    \ | ||||
|              ::c10::ScalarType::SCALARTYPE2>::t), \ | ||||
|     SCALARTYPE2)                                  \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<    \ | ||||
|              ::c10::ScalarType::SCALARTYPE3>::t), \ | ||||
|     SCALARTYPE3)                                  \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<    \ | ||||
|              ::c10::ScalarType::SCALARTYPE4>::t), \ | ||||
|     SCALARTYPE4)                                  \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<    \ | ||||
|              ::c10::ScalarType::SCALARTYPE5>::t), \ | ||||
|     SCALARTYPE5)                                  \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<    \ | ||||
|              ::c10::ScalarType::SCALARTYPE6>::t), \ | ||||
|     SCALARTYPE6)                                  \ | ||||
|   _(decltype(::c10::impl::ScalarTypeToCPPType<    \ | ||||
|              ::c10::ScalarType::SCALARTYPE7>::t), \ | ||||
|     SCALARTYPE7) | ||||
|  | ||||
| #define AT_FORALL_QINT_TYPES(_) \ | ||||
|   _(c10::qint8, QInt8)          \ | ||||
|   _(c10::quint8, QUInt8)        \ | ||||
|   _(c10::qint32, QInt32)        \ | ||||
|   _(c10::quint4x2, QUInt4x2)    \ | ||||
|   _(c10::quint2x4, QUInt2x4) | ||||
|  | ||||
| #define AT_FORALL_FLOAT8_TYPES(_)         \ | ||||
|   _(at::Float8_e5m2, Float8_e5m2)         \ | ||||
|   _(at::Float8_e4m3fn, Float8_e4m3fn)     \ | ||||
|   _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \ | ||||
|   _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \ | ||||
|   _(at::Float8_e8m0fnu, Float8_e8m0fnu) | ||||
|  | ||||
| #define AT_FORALL_COMPLEX_TYPES(_)     \ | ||||
|   _(c10::complex<float>, ComplexFloat) \ | ||||
|   _(c10::complex<double>, ComplexDouble) | ||||
|  | ||||
| #define DEFINE_CONSTANT(_, name) \ | ||||
|   constexpr ScalarType k##name = ScalarType::name; | ||||
|  | ||||
| @ -52,6 +269,19 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) | ||||
| AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) | ||||
| #undef DEFINE_CONSTANT | ||||
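|  | ||||
| // After expansion, this provides the familiar dtype shorthands, e.g. | ||||
| // (sketch) constexpr ScalarType kFloat = ScalarType::Float; and likewise | ||||
| // kBool, kBFloat16, and so on for every entry in the list above. | ||||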
|  | ||||
| inline const char* toString(ScalarType t) { | ||||
| #define DEFINE_CASE(_, name) \ | ||||
|   case ScalarType::name:     \ | ||||
|     return #name; | ||||
|  | ||||
|   switch (t) { | ||||
|     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) | ||||
|     default: | ||||
|       return "UNKNOWN_SCALAR"; | ||||
|   } | ||||
| #undef DEFINE_CASE | ||||
| } | ||||
|  | ||||
| inline size_t elementSize(ScalarType t) { | ||||
| #define CASE_ELEMENTSIZE_CASE(ctype, name) \ | ||||
|   case ScalarType::name:                   \ | ||||
| @ -295,6 +525,12 @@ inline bool canCast(const ScalarType from, const ScalarType to) { | ||||
|  | ||||
| C10_API ScalarType promoteTypes(ScalarType a, ScalarType b); | ||||
|  | ||||
| inline std::ostream& operator<<( | ||||
|     std::ostream& stream, | ||||
|     at::ScalarType scalar_type) { | ||||
|   return stream << toString(scalar_type); | ||||
| } | ||||
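|  | ||||
| // Usage sketch (illustrative): toString() and operator<< above make | ||||
| // ScalarType printable, e.g. | ||||
| // | ||||
| //   std::cout << ScalarType::BFloat16;            // prints "BFloat16" | ||||
| //   const char* s = toString(ScalarType::Float);  // yields "Float" | ||||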
|  | ||||
| // Returns a pair of strings representing the names for each dtype. | ||||
| // The returned pair is (name, legacy_name_if_applicable) | ||||
| C10_API std::pair<std::string, std::string> getDtypeNames( | ||||
|  | ||||
| @ -65,7 +65,7 @@ struct default_constructible | ||||
|  | ||||
| namespace impl { | ||||
|   template <typename T> | ||||
|   constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>* /*unused*/) | ||||
|   constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>*) | ||||
|   { | ||||
|     return true; | ||||
|   } | ||||
| @ -76,7 +76,7 @@ class type : public modifier<M, type<T, Tag, M...>>... | ||||
| { | ||||
| public: | ||||
|   template <typename TT = T, typename = std::enable_if_t<std::is_trivially_constructible<TT>{}>> | ||||
|   explicit type(uninitialized_t /*unused*/) | ||||
|   explicit type(uninitialized_t) | ||||
|     noexcept | ||||
|   { | ||||
|   } | ||||
| @ -138,7 +138,7 @@ private: | ||||
|  | ||||
| namespace impl { | ||||
|   template <typename T, typename Tag, typename ... Ms> | ||||
|   constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>* /*unused*/) { return true;} | ||||
|   constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>*) { return true;} | ||||
|   constexpr bool is_strong_type_func(...) { return false;} | ||||
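|  | ||||
|   // Detection-idiom sketch (illustrative, not part of the library): overload | ||||
|   // resolution picks the typed pointer overload for strong types and the | ||||
|   // C-style variadic fallback otherwise, so a trait can be built as, e.g., | ||||
|   // | ||||
|   //   template <typename T> | ||||
|   //   constexpr bool is_strong_type_v =  // hypothetical trait name | ||||
|   //       is_strong_type_func(static_cast<T*>(nullptr)); | ||||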
|   template <typename T, typename Tag, typename ... Ms> | ||||
|   constexpr T underlying_type(strong::type<T, Tag, Ms...>*); | ||||
|  | ||||
| @ -68,6 +68,14 @@ | ||||
| .. autofunction:: get_validators | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. autofunction:: write_file_on_exit | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. autofunction:: write_file | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. autofunction:: read_file | ||||
| ``` | ||||
| @ -87,7 +95,3 @@ | ||||
| ```{eval-rst} | ||||
| .. autofunction:: get_rotating_buffer_size | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. autofunction:: set_numerical_check_tolerances | ||||
| ``` | ||||
| @ -123,7 +123,3 @@ The frontend API is `fully_shard` that can be called on a `module`: | ||||
| .. autoclass:: CPUOffloadPolicy | ||||
|     :members: | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. autofunction:: share_comm_ctx | ||||
| ``` | ||||
|  | ||||
| @ -51,7 +51,7 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it. | ||||
| +----------------+-----+-----+-----+-----+-----+-----+-----+-----+ | ||||
| | reduce_scatter | ✓   | ✓   | ✘   | ✘   | ✘   | ✓   | ✘   | ✓   | | ||||
| +----------------+-----+-----+-----+-----+-----+-----+-----+-----+ | ||||
| | all_to_all     | ✘   | ✘   | ✓   | ?   | ✘   | ✓   | ✘   | ✓   | | ||||
| | all_to_all     | ✓   | ✓   | ✓   | ?   | ✘   | ✓   | ✘   | ✓   | | ||||
| +----------------+-----+-----+-----+-----+-----+-----+-----+-----+ | ||||
| | barrier        | ✓   | ✘   | ✓   | ?   | ✘   | ✓   | ✘   | ✓   | | ||||
| +----------------+-----+-----+-----+-----+-----+-----+-----+-----+ | ||||
|  | ||||
| @ -23,7 +23,6 @@ Submodules | ||||
|     flex_attention | ||||
|     bias | ||||
|     experimental | ||||
|     varlen | ||||
|  | ||||
| .. toctree:: | ||||
|     :hidden: | ||||
| @ -31,4 +30,3 @@ Submodules | ||||
|     nn.attention.flex_attention | ||||
|     nn.attention.bias | ||||
|     nn.attention.experimental | ||||
|     nn.attention.varlen | ||||
|  | ||||
| @ -1,17 +0,0 @@ | ||||
| ```{eval-rst} | ||||
| .. role:: hidden | ||||
|     :class: hidden-section | ||||
| ``` | ||||
|  | ||||
| # torch.nn.attention.varlen | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. automodule:: torch.nn.attention.varlen | ||||
| .. currentmodule:: torch.nn.attention.varlen | ||||
| ``` | ||||
| ```{eval-rst} | ||||
| .. autofunction:: varlen_attn | ||||
| ``` | ||||
| ```{eval-rst} | ||||
| .. autoclass:: AuxRequest | ||||
| ``` | ||||
| @ -228,4 +228,3 @@ Low-Precision functions | ||||
|     ScalingType | ||||
|     SwizzleType | ||||
|     scaled_mm | ||||
|     scaled_grouped_mm | ||||
|  | ||||
| @ -1,12 +1,14 @@ | ||||
| ```{eval-rst} | ||||
| .. currentmodule:: torch.compiler.config | ||||
|  | ||||
| ``` | ||||
|  | ||||
| # torch.compiler.config | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. automodule:: torch.compiler.config | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. autodata:: torch.compiler.config.job_id | ||||
| ``` | ||||
|  | ||||
| @ -816,10 +816,6 @@ Operator Tags | ||||
| .. py:module:: torch.types | ||||
| .. py:module:: torch.version | ||||
|  | ||||
| .. Compiler configuration module - documented in torch.compiler.config.md | ||||
| .. py:module:: torch.compiler.config | ||||
|    :noindex: | ||||
|  | ||||
| .. Hidden aliases (e.g. torch.functional.broadcast_tensors()). We want `torch.broadcast_tensors()` to | ||||
|    be visible only. | ||||
| .. toctree:: | ||||
|  | ||||
| @ -10,7 +10,6 @@ set(AOTI_ABI_CHECK_TEST_SRCS | ||||
|   ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp | ||||
|   ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp | ||||
|   ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp | ||||
|   ${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp | ||||
|   ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp | ||||
|   ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp | ||||
| ) | ||||
|  | ||||
| @ -1,76 +0,0 @@ | ||||
| #include <gtest/gtest.h> | ||||
|  | ||||
| #include <torch/headeronly/core/ScalarType.h> | ||||
|  | ||||
| TEST(TestScalarType, ScalarTypeToCPPTypeT) { | ||||
|   using torch::headeronly::ScalarType; | ||||
|   using torch::headeronly::impl::ScalarTypeToCPPTypeT; | ||||
|  | ||||
| #define DEFINE_CHECK(TYPE, SCALARTYPE) \ | ||||
|   EXPECT_EQ(typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE)); | ||||
|  | ||||
|   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK); | ||||
| #undef DEFINE_CHECK | ||||
| } | ||||
|  | ||||
| #define DEFINE_CHECK(TYPE, SCALARTYPE)                                       \ | ||||
|   {                                                                          \ | ||||
|     EXPECT_EQ(                                                               \ | ||||
|         typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE)); \ | ||||
|     count++;                                                                 \ | ||||
|   } | ||||
|  | ||||
| #define TEST_FORALL(M, EXPECTEDCOUNT, ...)               \ | ||||
|   TEST(TestScalarType, M) {                              \ | ||||
|     using torch::headeronly::ScalarType;                 \ | ||||
|     using torch::headeronly::impl::ScalarTypeToCPPTypeT; \ | ||||
|     int8_t count = 0;                                    \ | ||||
|     M(__VA_ARGS__ DEFINE_CHECK);                         \ | ||||
|     EXPECT_EQ(count, EXPECTEDCOUNT);                     \ | ||||
|   } | ||||
|  | ||||
| TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ, 14) | ||||
| TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, 18) | ||||
| TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, 46) | ||||
| TEST_FORALL(AT_FORALL_INT_TYPES, 5) | ||||
| TEST_FORALL(AT_FORALL_SCALAR_TYPES, 7) | ||||
| TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND, 8, Bool, ) | ||||
| TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND2, 9, Bool, Half, ) | ||||
| TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND3, 10, Bool, Half, ComplexFloat, ) | ||||
| TEST_FORALL( | ||||
|     AT_FORALL_SCALAR_TYPES_AND7, | ||||
|     14, | ||||
|     Bool, | ||||
|     Half, | ||||
|     ComplexHalf, | ||||
|     ComplexFloat, | ||||
|     ComplexDouble, | ||||
|     UInt16, | ||||
|     UInt32, ) | ||||
| TEST_FORALL(AT_FORALL_QINT_TYPES, 5) | ||||
| TEST_FORALL(AT_FORALL_FLOAT8_TYPES, 5) | ||||
| TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2) | ||||
|  | ||||
| #undef DEFINE_CHECK | ||||
| #undef TEST_FORALL | ||||
|  | ||||
| TEST(TestScalarType, toString) { | ||||
|   using torch::headeronly::ScalarType; | ||||
|  | ||||
| #define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name); | ||||
|   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK); | ||||
| #undef DEFINE_CHECK | ||||
| } | ||||
|  | ||||
| TEST(TestScalarType, operator_left_shift) { | ||||
|   using torch::headeronly::ScalarType; | ||||
|  | ||||
| #define DEFINE_CHECK(_, name)   \ | ||||
|   {                             \ | ||||
|     std::stringstream ss;       \ | ||||
|     ss << ScalarType::name;     \ | ||||
|     EXPECT_EQ(ss.str(), #name); \ | ||||
|   } | ||||
|   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK); | ||||
| #undef DEFINE_CHECK | ||||
| } | ||||
| @ -6,7 +6,7 @@ import functools | ||||
| import itertools | ||||
| import unittest | ||||
| from collections import defaultdict | ||||
| from collections.abc import Callable, Iterable | ||||
| from collections.abc import Iterable | ||||
| from typing import Any, Optional, Union | ||||
|  | ||||
| import torch | ||||
| @ -24,11 +24,6 @@ from torch.distributed.fsdp import ( | ||||
|     fully_shard, | ||||
|     OffloadPolicy, | ||||
|     register_fsdp_forward_method, | ||||
|     share_comm_ctx, | ||||
| ) | ||||
| from torch.distributed.fsdp._fully_shard._fsdp_collectives import ( | ||||
|     foreach_all_gather, | ||||
|     foreach_reduce, | ||||
| ) | ||||
| from torch.distributed.tensor import DTensor, init_device_mesh, Shard | ||||
| from torch.distributed.tensor.debug import CommDebugMode | ||||
| @ -44,8 +39,6 @@ from torch.testing._internal.common_fsdp import ( | ||||
|     MLP, | ||||
|     MLPStack, | ||||
|     patch_all_gather, | ||||
|     patch_foreach_all_gather, | ||||
|     patch_foreach_reduce, | ||||
|     patch_reduce_scatter, | ||||
| ) | ||||
| from torch.testing._internal.common_utils import ( | ||||
| @ -1494,116 +1487,6 @@ class TestFullyShardCustomForwardMethod(FSDPTest): | ||||
|         check_sharded_parity(self, ref_model, model) | ||||
|  | ||||
|  | ||||
| class TestFullyShardShareCommContext(FSDPTest): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|         return min(torch.get_device_module(device_type).device_count(), 2) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_share_comm_context(self): | ||||
|         torch.manual_seed(42) | ||||
|         n_layers = 3 | ||||
|         lin_dim = 16 | ||||
|         model = nn.Sequential( | ||||
|             *[MLP(lin_dim, torch.device("cpu")) for _ in range(n_layers)] | ||||
|         ) | ||||
|         ref_model = copy.deepcopy(model).to(device_type) | ||||
|         for layer in model: | ||||
|             fully_shard(layer) | ||||
|             layer._get_fsdp_state()._lazy_init() | ||||
|         share_comm_ctx(list(model)) | ||||
|  | ||||
|         torch.manual_seed(42 + self.rank + 1) | ||||
|         inp = torch.randn(4, 3, lin_dim, device=device_type.type) | ||||
|         ref_loss = ref_model(inp).sum() | ||||
|  | ||||
|         all_gather_streams = set() | ||||
|         reduce_scatter_streams = set() | ||||
|  | ||||
|         from torch.distributed.fsdp._fully_shard._fsdp_api import ( | ||||
|             AllGather, | ||||
|             ReduceScatter, | ||||
|         ) | ||||
|         from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam | ||||
|  | ||||
|         orig_foreach_all_gather = foreach_all_gather | ||||
|  | ||||
|         def foreach_all_gather_with_assert( | ||||
|             fsdp_params: list[FSDPParam], | ||||
|             group: dist.ProcessGroup, | ||||
|             async_op: bool, | ||||
|             all_gather_copy_in_stream: torch.Stream, | ||||
|             all_gather_stream: torch.Stream, | ||||
|             device: torch.device, | ||||
|             all_gather_comm: AllGather, | ||||
|         ): | ||||
|             nonlocal all_gather_streams | ||||
|             all_gather_streams.add(all_gather_stream) | ||||
|             return orig_foreach_all_gather( | ||||
|                 fsdp_params, | ||||
|                 group, | ||||
|                 async_op, | ||||
|                 all_gather_copy_in_stream, | ||||
|                 all_gather_stream, | ||||
|                 device, | ||||
|                 all_gather_comm, | ||||
|             ) | ||||
|  | ||||
|         orig_foreach_reduce = foreach_reduce | ||||
|  | ||||
|         @torch.no_grad() | ||||
|         def foreach_reduce_with_assert( | ||||
|             fsdp_params: list[FSDPParam], | ||||
|             unsharded_grads: list[torch.Tensor], | ||||
|             reduce_scatter_group: dist.ProcessGroup, | ||||
|             reduce_scatter_stream: torch.Stream, | ||||
|             reduce_scatter_comm: ReduceScatter, | ||||
|             orig_dtype: Optional[torch.dtype], | ||||
|             reduce_dtype: Optional[torch.dtype], | ||||
|             device: torch.device, | ||||
|             gradient_divide_factor: Optional[float], | ||||
|             all_reduce_group: Optional[dist.ProcessGroup],  # not `None` iff HSDP | ||||
|             all_reduce_stream: torch.Stream, | ||||
|             all_reduce_grads: bool, | ||||
|             partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP | ||||
|             all_reduce_hook: Optional[Callable[[torch.Tensor], None]], | ||||
|             force_sum_reduction_for_comms: bool = False, | ||||
|         ): | ||||
|             nonlocal reduce_scatter_streams | ||||
|             reduce_scatter_streams.add(reduce_scatter_stream) | ||||
|             return orig_foreach_reduce( | ||||
|                 fsdp_params, | ||||
|                 unsharded_grads, | ||||
|                 reduce_scatter_group, | ||||
|                 reduce_scatter_stream, | ||||
|                 reduce_scatter_comm, | ||||
|                 orig_dtype, | ||||
|                 reduce_dtype, | ||||
|                 device, | ||||
|                 gradient_divide_factor, | ||||
|                 all_reduce_group, | ||||
|                 all_reduce_stream, | ||||
|                 all_reduce_grads, | ||||
|                 partial_reduce_output, | ||||
|                 all_reduce_hook, | ||||
|                 force_sum_reduction_for_comms, | ||||
|             ) | ||||
|  | ||||
|         with ( | ||||
|             patch_foreach_all_gather(foreach_all_gather_with_assert), | ||||
|             patch_foreach_reduce(foreach_reduce_with_assert), | ||||
|         ): | ||||
|             loss = model(inp).sum() | ||||
|             self.assertEqual(ref_loss, loss) | ||||
|             ref_loss.backward() | ||||
|             loss.backward() | ||||
|             for param in ref_model.parameters(): | ||||
|                 dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) | ||||
|         self.assertEqual(len(all_gather_streams), 1) | ||||
|         self.assertEqual(len(reduce_scatter_streams), 1) | ||||
|         check_sharded_parity(self, ref_model, model) | ||||
|  | ||||
|  | ||||
| class TestFullyShardWorldSize1(FSDPTest): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|  | ||||
| @ -559,7 +559,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|             FAIL = 138 | ||||
|             pc = start_processes( | ||||
|                 name="echo", | ||||
|                 entrypoint=bin("echo4.py"), | ||||
|                 entrypoint=bin("echo1.py"), | ||||
|                 args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")}, | ||||
|                 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, | ||||
|                 logs_specs=DefaultLogsSpecs( | ||||
|  | ||||
| @ -9,6 +9,7 @@ | ||||
| import argparse | ||||
| import os | ||||
| import sys | ||||
| import time | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| @ -23,5 +24,6 @@ if __name__ == "__main__": | ||||
|         print(f"exit {exitcode} from {rank}", file=sys.stderr) | ||||
|         sys.exit(exitcode) | ||||
|     else: | ||||
|         time.sleep(1000) | ||||
|         print(f"{args.msg} stdout from {rank}") | ||||
|         print(f"{args.msg} stderr from {rank}", file=sys.stderr) | ||||
|  | ||||
| @ -1,29 +0,0 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the BSD-style license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import argparse | ||||
| import os | ||||
| import sys | ||||
| import time | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="test binary, exits with exitcode") | ||||
|     parser.add_argument("--exitcode", type=int, default=0) | ||||
|     parser.add_argument("msg", type=str) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     rank = int(os.environ["RANK"]) | ||||
|     exitcode = args.exitcode | ||||
|     if exitcode != 0: | ||||
|         print(f"exit {exitcode} from {rank}", file=sys.stderr) | ||||
|         sys.exit(exitcode) | ||||
|     else: | ||||
|         time.sleep(1000) | ||||
|         print(f"{args.msg} stdout from {rank}") | ||||
|         print(f"{args.msg} stderr from {rank}", file=sys.stderr) | ||||
| @ -536,23 +536,6 @@ class TestScheduleLowering(TestCase): | ||||
|                 "compute": ["0F0", "0F1", "   ", "0B0", "0B1"], | ||||
|                 "comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"], | ||||
|             }, | ||||
|             { | ||||
|                 "compute": ["0F0", "0F1", "1F0", "1F1", "1B0", "1B1", "0B0", "0B1"], | ||||
|                 "comms": [ | ||||
|                     "0UNSHARD", | ||||
|                     "1UNSHARD", | ||||
|                     "0F0", | ||||
|                     "0F1", | ||||
|                     "1F0", | ||||
|                     "1F1", | ||||
|                     "1B0", | ||||
|                     "1B1", | ||||
|                     "1RESHARD", | ||||
|                     "0B0", | ||||
|                     "0B1", | ||||
|                     "0RESHARD", | ||||
|                 ], | ||||
|             }, | ||||
|         ], | ||||
|     ) | ||||
|     def test_unshard_reshard(self, test_info): | ||||
|  | ||||
| @ -1020,19 +1020,6 @@ class DTensorMeshTest(DTensorTestBase): | ||||
|             self.fail("Unexpected ValueError raised with run_check=False") | ||||
|  | ||||
|  | ||||
| DTensorMeshTestWithLocalTensor = create_local_tensor_test_class( | ||||
|     DTensorMeshTest, | ||||
|     skipped_tests=[ | ||||
|         # Submeshes are not supported by local tensor mode | ||||
|         "test_from_local_sub_mesh", | ||||
|         "test_default_value_sub_mesh", | ||||
|         "test_redistribute_sub_mesh", | ||||
|         # Local tensor mode doesn't support tensors of different types on different ranks | ||||
|         "test_metadata_consistency_check", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
|  | ||||
| class TestDTensorPlacementTypes(DTensorTestBase): | ||||
|     @property | ||||
|     def world_size(self): | ||||
| @ -1099,11 +1086,6 @@ class TestDTensorPlacementTypes(DTensorTestBase): | ||||
|                 assert_array_equal(expected_is_tensor_empty, is_tensor_empty) | ||||
|  | ||||
|  | ||||
| TestDTensorPlacementTypesWithLocalTensor = create_local_tensor_test_class( | ||||
|     TestDTensorPlacementTypes, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class TestDTensorSpec(DTensorTestBase): | ||||
|     @property | ||||
|     def world_size(self): | ||||
| @ -1283,9 +1265,5 @@ class TestDTensorSpec(DTensorTestBase): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| TestDTensorSpecWithLocalTensor = create_local_tensor_test_class( | ||||
|     TestDTensorSpec, | ||||
| ) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -239,7 +239,9 @@ class DTensorExportTest(TestCase): | ||||
|                 "view_9", | ||||
|                 "t_15", | ||||
|                 "detach", | ||||
|                 "detach_3", | ||||
|                 "detach_1", | ||||
|                 "detach_6", | ||||
|                 "detach_7", | ||||
|                 "threshold_backward_1", | ||||
|                 "t_16", | ||||
|                 "mm_6", | ||||
| @ -257,8 +259,10 @@ class DTensorExportTest(TestCase): | ||||
|                 "sum_1", | ||||
|                 "view_7", | ||||
|                 "t_7", | ||||
|                 "detach_1", | ||||
|                 "detach_2", | ||||
|                 "detach_3", | ||||
|                 "detach_4", | ||||
|                 "detach_5", | ||||
|                 "threshold_backward", | ||||
|                 "mm_2", | ||||
|                 "t_9", | ||||
|  | ||||
| @ -20,7 +20,6 @@ from torch.distributed.tensor._collective_utils import shard_dim_alltoall | ||||
| from torch.distributed.tensor._dtensor_spec import ShardOrderEntry | ||||
| from torch.distributed.tensor._redistribute import redistribute_local_tensor | ||||
| from torch.distributed.tensor.debug import CommDebugMode | ||||
| from torch.distributed.tensor.placement_types import _StridedShard | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
|     parametrize, | ||||
| @ -1146,22 +1145,6 @@ class DistributeWithDeviceOrderTest(DTensorTestBase): | ||||
|             sharded_dt, mesh, tgt_placement, shard_order=None | ||||
|         ) | ||||
|  | ||||
|     @with_comms | ||||
|     def test_shard_order_same_data_as_strided_shard(self): | ||||
|         device_mesh = init_device_mesh(self.device_type, (4, 2)) | ||||
|         x = torch.randn(8, 4, device=self.device_type) | ||||
|         # specify right-to-left order using _StridedShard | ||||
|         strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)] | ||||
|         x_strided_dt = distribute_tensor(x, device_mesh, strided_placement) | ||||
|         # specify right-to-left order using an ordered shard | ||||
|         x_ordered_dt = self.distribute_tensor( | ||||
|             x, | ||||
|             device_mesh, | ||||
|             placements=[Shard(0), Shard(0)], | ||||
|             shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),), | ||||
|         ) | ||||
|         self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local()) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -70,8 +70,6 @@ def get_patches(): | ||||
|         "force_disable_caches": True, | ||||
|         # Messes up existing test strings | ||||
|         "test_configs.aten_fx_overlap_insert_overlap_deps": False, | ||||
|         # interferes with testing, / custom estimation | ||||
|         "test_configs.assume_bucketing_reduces_latency": False, | ||||
|     } | ||||
|  | ||||
|  | ||||
| @ -366,8 +364,6 @@ def get_bucket_patches(compute_multiplier=1.0): | ||||
|         "force_disable_caches": True, | ||||
|         # messes up test strings | ||||
|         "test_configs.aten_fx_overlap_insert_overlap_deps": False, | ||||
|         # interferes with testing, / custom estimation | ||||
|         "test_configs.assume_bucketing_reduces_latency": False, | ||||
|     } | ||||
|  | ||||
|  | ||||
| @ -583,7 +579,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc): | ||||
|  | ||||
|     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") | ||||
|     @torch._inductor.config.patch(get_bucket_patches(2.0)) | ||||
|     def test_bucketing_split_for_overlap_blocking_no_deps(self): | ||||
|     def test_bucketing_split_for_overlap_blocking(self): | ||||
|         """Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute.""" | ||||
|  | ||||
|         def func(a, b, c, d, *, ranks): | ||||
|  | ||||
| @ -7,13 +7,8 @@ from dataclasses import dataclass | ||||
|  | ||||
| import torch | ||||
| from torch.multiprocessing.reductions import reduce_tensor | ||||
| from torch.testing._internal.common_cuda import SM100OrLater | ||||
| from torch.testing._internal.common_distributed import MultiProcContinuousTest | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     requires_cuda_p2p_access, | ||||
|     run_tests, | ||||
|     skip_but_pass_in_sandcastle_if, | ||||
| ) | ||||
| from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests | ||||
|  | ||||
|  | ||||
| # So that tests are written in device-agnostic way | ||||
| @ -64,10 +59,6 @@ class CupyAsTensorTest(MultiProcContinuousTest): | ||||
|     def device(self) -> torch.device: | ||||
|         return torch.device(device_type, self.rank) | ||||
|  | ||||
|     @skip_but_pass_in_sandcastle_if( | ||||
|         SM100OrLater, | ||||
|         "Fails if ran in docker environment without privileged access (https://github.com/pytorch/pytorch/issues/165170)", | ||||
|     ) | ||||
|     def test_cupy_as_tensor(self) -> None: | ||||
|         """ | ||||
|         Test that torch.as_tensor works for cupy array interface | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
| # Owner(s): ["oncall: distributed"] | ||||
| import os | ||||
| import unittest | ||||
| from datetime import timedelta | ||||
|  | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| @ -41,13 +40,6 @@ from torch.utils._typing_utils import not_none | ||||
| device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" | ||||
| device_count = torch.accelerator.device_count() | ||||
|  | ||||
| try: | ||||
|     import torch._C._distributed_c10d.ProcessGroupNCCL | ||||
|  | ||||
|     _NCCL_AVAILABLE = True | ||||
| except ImportError: | ||||
|     _NCCL_AVAILABLE = False | ||||
|  | ||||
|  | ||||
| def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1): | ||||
|     os.environ["MASTER_ADDR"] = addr | ||||
| @ -970,85 +962,6 @@ class TestDeviceMeshGetItem(DTensorTestBase): | ||||
|         # check flattened mesh dependency | ||||
|         self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_4d) | ||||
|  | ||||
|     @with_comms | ||||
|     def test_unflatten_mesh_2d(self): | ||||
|         mesh_shape = (4, 2) | ||||
|         mesh_dim_names = ("dp", "tp") | ||||
|         mesh_2d = init_device_mesh( | ||||
|             self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names | ||||
|         ) | ||||
|         unflatten_mesh = mesh_2d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate")) | ||||
|         self.assertEqual( | ||||
|             unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"] | ||||
|         ) | ||||
|         self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh) | ||||
|         self.assertEqual(mesh_2d["tp"].get_group(), unflatten_mesh["tp"].get_group()) | ||||
|  | ||||
|         # Slicing an unflattened dim name out of the root mesh is not supported. | ||||
|         with self.assertRaises(KeyError): | ||||
|             self.assertEqual(mesh_2d["dp_shard"].mesh, unflatten_mesh["dp_shard"].mesh) | ||||
|  | ||||
|     @with_comms | ||||
|     def test_unflatten_mesh_3d(self): | ||||
|         # Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism (EP). | ||||
|         global_mesh = init_device_mesh( | ||||
|             self.device_type, | ||||
|             (8,), | ||||
|             mesh_dim_names=("world",), | ||||
|         ) | ||||
|         non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp")) | ||||
|         ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp")) | ||||
|         self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh) | ||||
|         self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh) | ||||
|         mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp")) | ||||
|         unflatten_mesh = mesh_3d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate")) | ||||
|         self.assertEqual( | ||||
|             unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "cp", "tp"] | ||||
|         ) | ||||
|         self.assertEqual(mesh_3d["tp"].mesh, unflatten_mesh["tp"].mesh) | ||||
|         self.assertEqual(mesh_3d["tp"].get_group(), unflatten_mesh["tp"].get_group()) | ||||
|         self.assertEqual(mesh_3d["cp"].mesh, unflatten_mesh["cp"].mesh) | ||||
|         self.assertEqual(mesh_3d["cp"].get_group(), unflatten_mesh["cp"].get_group()) | ||||
|  | ||||
|         # Test unflatten with backend override set. | ||||
|         if not _NCCL_AVAILABLE: | ||||
|             return | ||||
|         opts = dist.ProcessGroupNCCL.Options() | ||||
|         opts._timeout = timedelta(seconds=30) | ||||
|         mesh_2d = global_mesh._unflatten( | ||||
|             0, | ||||
|             (1, 8), | ||||
|             ("pp", "spmd"), | ||||
|             backend_override={"pp": "fake", "spmd": ("nccl", opts)}, | ||||
|         ) | ||||
|         opts = dist.ProcessGroupNCCL.Options() | ||||
|         opts._timeout = timedelta(seconds=60) | ||||
|         mesh_4d = mesh_2d._unflatten( | ||||
|             1, | ||||
|             (2, 2, 2), | ||||
|             ("dp", "cp", "tp"), | ||||
|             backend_override={"dp": "nccl", "cp": "nccl", "tp": ("nccl", opts)}, | ||||
|         ) | ||||
|         self.assertEqual(mesh_4d["pp"].get_group()._get_backend_name(), "custom") | ||||
|         spmd_pg = mesh_2d["spmd"].get_group() | ||||
|         self.assertEqual(spmd_pg._get_backend_name(), "nccl") | ||||
|         w = spmd_pg.allreduce(torch.rand(10).cuda(self.rank)) | ||||
|         self.assertTrue( | ||||
|             spmd_pg._get_backend( | ||||
|                 torch.device(f"cuda:{self.rank}") | ||||
|             )._verify_work_timeout(w, timedelta(seconds=30)) | ||||
|         ) | ||||
|         w.wait() | ||||
|         tp_pg = mesh_4d["tp"].get_group() | ||||
|         self.assertEqual(tp_pg._get_backend_name(), "nccl") | ||||
|         w = tp_pg.allreduce(torch.rand(10).cuda(self.rank)) | ||||
|         self.assertTrue( | ||||
|             tp_pg._get_backend(torch.device(f"cuda:{self.rank}"))._verify_work_timeout( | ||||
|                 w, timedelta(seconds=60) | ||||
|             ) | ||||
|         ) | ||||
|         w.wait() | ||||
|  | ||||
|     @with_comms | ||||
|     def test_reconstruct_mesh_with_flatten_dim(self): | ||||
|         mesh_3d = init_device_mesh( | ||||
|  | ||||
| @ -273,7 +273,12 @@ class TestFakePG(TestCase): | ||||
|                     kwargs = {} | ||||
|                 return func(*args, **kwargs) | ||||
|  | ||||
|         with self.assertRaisesRegex(TypeError, r"No constructor defined"): | ||||
|         with self.assertRaisesRegex( | ||||
|             RuntimeError, | ||||
|             r"FakeProcessGroup cannot be constructed directly\. " | ||||
|             r"Use torch\.distributed\.init_process_group\(backend='fake'\) instead to ensure " | ||||
|             r"proper dispatch system integration\.", | ||||
|         ): | ||||
|             fake_pg = FakeProcessGroup(rank=0, world_size=3) | ||||
|  | ||||
|             with SimpleTensorMode(): | ||||
|  | ||||
| @ -12,7 +12,6 @@ import torch.distributed._symmetric_memory as symm_mem | ||||
| import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem | ||||
| from torch._inductor.runtime.triton_compat import triton | ||||
| from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem | ||||
| from torch.testing._internal.common_cuda import SM100OrLater | ||||
| from torch.testing._internal.common_distributed import MultiProcContinuousTest | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
| @ -265,10 +264,6 @@ def my_reduce_kernel( | ||||
|     nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation) | ||||
|  | ||||
|  | ||||
| @skip_but_pass_in_sandcastle_if( | ||||
|     SM100OrLater, | ||||
|     "Skipping all NVSHMEM Triton tests due to https://github.com/pytorch/pytorch/issues/162897", | ||||
| ) | ||||
| @instantiate_parametrized_tests | ||||
| class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def _init_device(self) -> None: | ||||
|  | ||||
| @ -52,9 +52,6 @@ from torch.testing._internal.common_utils import ( | ||||
|  | ||||
| test_contexts = [nullcontext, _test_mode] | ||||
|  | ||||
| # Set environment variable to disable multicast for all tests in this module | ||||
| os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1" | ||||
|  | ||||
| # So that tests are written in device-agnostic way | ||||
| device_type = "cuda" | ||||
| device_module = torch.get_device_module(device_type) | ||||
| @ -549,10 +546,6 @@ class AsyncTPTest(MultiProcContinuousTest): | ||||
|     @skipUnless(SM89OrLater, "Requires compute capability >= 8.9") | ||||
|     @parametrize("scatter_dim", [0, 1]) | ||||
|     @parametrize("rowwise", [True, False]) | ||||
|     @skipIf( | ||||
|         SM100OrLater, | ||||
|         "https://github.com/pytorch/pytorch/issues/162940", | ||||
|     ) | ||||
|     def test_fused_scaled_matmul_reduce_scatter( | ||||
|         self, scatter_dim: int, rowwise: bool | ||||
|     ) -> None: | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py | ||||
| index e599b02c17d..057b6ec01b9 100644 | ||||
| index e599b02c17d..750d7a84fb4 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_baseexception.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_baseexception.py | ||||
| @@ -1,10 +1,64 @@ | ||||
| @ -78,27 +78,7 @@ index e599b02c17d..057b6ec01b9 100644 | ||||
|          self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set) | ||||
|   | ||||
|      interface_tests = ("length", "args", "str", "repr") | ||||
| @@ -122,12 +173,13 @@ class ExceptionClassTests(unittest.TestCase): | ||||
|          # in PyObject_SetAttr. | ||||
|          import gc | ||||
|          d = {} | ||||
| -        class HashThisKeyWillClearTheDict(str): | ||||
| -            def __hash__(self) -> int: | ||||
| -                d.clear() | ||||
| -                return super().__hash__() | ||||
| -        class Value(str): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class HashThisKeyWillClearTheDict(str): | ||||
| +                def __hash__(self) -> int: | ||||
| +                    d.clear() | ||||
| +                    return super().__hash__() | ||||
| +            class Value(str): | ||||
| +                pass | ||||
|          exc = Exception() | ||||
|   | ||||
|          d[HashThisKeyWillClearTheDict()] = Value()  # refcount of Value() is 1 now | ||||
| @@ -142,7 +194,7 @@ class ExceptionClassTests(unittest.TestCase): | ||||
| @@ -142,7 +193,7 @@ class ExceptionClassTests(unittest.TestCase): | ||||
|          gc.collect() | ||||
|   | ||||
|   | ||||
| @ -107,31 +87,7 @@ index e599b02c17d..057b6ec01b9 100644 | ||||
|   | ||||
|      """Test usage of exceptions""" | ||||
|   | ||||
| @@ -182,8 +234,9 @@ class UsageTests(unittest.TestCase): | ||||
|          # BaseException; the ability was not possible until BaseException's | ||||
|          # introduction so no need to support new-style objects that do not | ||||
|          # inherit from it. | ||||
| -        class NewStyleClass(object): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class NewStyleClass(object): | ||||
| +                pass | ||||
|          self.raise_fails(NewStyleClass) | ||||
|          self.raise_fails(NewStyleClass()) | ||||
|   | ||||
| @@ -194,8 +247,9 @@ class UsageTests(unittest.TestCase): | ||||
|      def test_catch_non_BaseException(self): | ||||
|          # Trying to catch an object that does not inherit from BaseException | ||||
|          # is not allowed. | ||||
| -        class NonBaseException(object): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class NonBaseException(object): | ||||
| +                pass | ||||
|          self.catch_fails(NonBaseException) | ||||
|          self.catch_fails(NonBaseException()) | ||||
|   | ||||
| @@ -208,5 +262,5 @@ class UsageTests(unittest.TestCase): | ||||
| @@ -208,5 +259,5 @@ class UsageTests(unittest.TestCase): | ||||
|          self.catch_fails("spam") | ||||
|   | ||||
|   | ||||
|  | ||||
| @ -173,13 +173,12 @@ class ExceptionClassTests(__TestCase): | ||||
|         # in PyObject_SetAttr. | ||||
|         import gc | ||||
|         d = {} | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class HashThisKeyWillClearTheDict(str): | ||||
|                 def __hash__(self) -> int: | ||||
|                     d.clear() | ||||
|                     return super().__hash__() | ||||
|             class Value(str): | ||||
|                 pass | ||||
|         class HashThisKeyWillClearTheDict(str): | ||||
|             def __hash__(self) -> int: | ||||
|                 d.clear() | ||||
|                 return super().__hash__() | ||||
|         class Value(str): | ||||
|             pass | ||||
|         exc = Exception() | ||||
|  | ||||
|         d[HashThisKeyWillClearTheDict()] = Value()  # refcount of Value() is 1 now | ||||
| @ -234,9 +233,8 @@ class UsageTests(__TestCase): | ||||
|         # BaseException; the ability was not possible until BaseException's | ||||
|         # introduction so no need to support new-style objects that do not | ||||
|         # inherit from it. | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class NewStyleClass(object): | ||||
|                 pass | ||||
|         class NewStyleClass(object): | ||||
|             pass | ||||
|         self.raise_fails(NewStyleClass) | ||||
|         self.raise_fails(NewStyleClass()) | ||||
|  | ||||
| @ -247,9 +245,8 @@ class UsageTests(__TestCase): | ||||
|     def test_catch_non_BaseException(self): | ||||
|         # Trying to catch an object that does not inherit from BaseException | ||||
|         # is not allowed. | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class NonBaseException(object): | ||||
|                 pass | ||||
|         class NonBaseException(object): | ||||
|             pass | ||||
|         self.catch_fails(NonBaseException) | ||||
|         self.catch_fails(NonBaseException()) | ||||
|  | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py | ||||
| index c91f6662948..3a62dec411c 100644 | ||||
| index c91f6662948..0ded70db3c7 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_exceptions.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_exceptions.py | ||||
| @@ -1,3 +1,59 @@ | ||||
| @ -71,305 +71,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|   | ||||
|      def raise_catch(self, exc, excname): | ||||
|          with self.subTest(exc=exc, excname=excname): | ||||
| @@ -343,12 +399,13 @@ class ExceptionTests(unittest.TestCase): | ||||
|          # test that setting an exception at the C level works even if the | ||||
|          # exception object can't be constructed. | ||||
|   | ||||
| -        class BadException(Exception): | ||||
| -            def __init__(self_): | ||||
| -                raise RuntimeError("can't instantiate BadException") | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class BadException(Exception): | ||||
| +                def __init__(self_): | ||||
| +                    raise RuntimeError("can't instantiate BadException") | ||||
|   | ||||
| -        class InvalidException: | ||||
| -            pass | ||||
| +            class InvalidException: | ||||
| +                pass | ||||
|   | ||||
|          @unittest.skipIf(_testcapi is None, "requires _testcapi") | ||||
|          def test_capi1(): | ||||
| @@ -636,8 +693,9 @@ class ExceptionTests(unittest.TestCase): | ||||
|          self.assertIsInstance(e, IndexError) | ||||
|          self.assertEqual(e.__traceback__, tb) | ||||
|   | ||||
| -        class MyException(Exception): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MyException(Exception): | ||||
| +                pass | ||||
|   | ||||
|          e = MyException().with_traceback(tb) | ||||
|          self.assertIsInstance(e, MyException) | ||||
| @@ -696,8 +754,9 @@ class ExceptionTests(unittest.TestCase): | ||||
|          self.assertIsNone(e.__context__) | ||||
|          self.assertIsNone(e.__cause__) | ||||
|   | ||||
| -        class MyException(OSError): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MyException(OSError): | ||||
| +                pass | ||||
|   | ||||
|          e = MyException() | ||||
|          self.assertIsNone(e.__context__) | ||||
| @@ -726,10 +785,11 @@ class ExceptionTests(unittest.TestCase): | ||||
|          # but user-defined subclasses can if they want | ||||
|          self.assertRaises(TypeError, BaseException, a=1) | ||||
|   | ||||
| -        class DerivedException(BaseException): | ||||
| -            def __init__(self, fancy_arg): | ||||
| -                BaseException.__init__(self) | ||||
| -                self.fancy_arg = fancy_arg | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class DerivedException(BaseException): | ||||
| +                def __init__(self, fancy_arg): | ||||
| +                    BaseException.__init__(self) | ||||
| +                    self.fancy_arg = fancy_arg | ||||
|   | ||||
|          x = DerivedException(fancy_arg=42) | ||||
|          self.assertEqual(x.fancy_arg, 42) | ||||
| @@ -779,11 +839,12 @@ class ExceptionTests(unittest.TestCase): | ||||
|          # Make sure exception state is cleaned up as soon as the except | ||||
|          # block is left. See #2507 | ||||
|   | ||||
| -        class MyException(Exception): | ||||
| -            def __init__(self, obj): | ||||
| -                self.obj = obj | ||||
| -        class MyObj: | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MyException(Exception): | ||||
| +                def __init__(self, obj): | ||||
| +                    self.obj = obj | ||||
| +            class MyObj: | ||||
| +                pass | ||||
|   | ||||
|          def inner_raising_func(): | ||||
|              # Create some references in exception value and traceback | ||||
| @@ -881,11 +942,12 @@ class ExceptionTests(unittest.TestCase): | ||||
|          self.assertIsNone(obj) | ||||
|   | ||||
|          # Inside an exception-silencing "with" block | ||||
| -        class Context: | ||||
| -            def __enter__(self): | ||||
| -                return self | ||||
| -            def __exit__ (self, exc_type, exc_value, exc_tb): | ||||
| -                return True | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class Context: | ||||
| +                def __enter__(self): | ||||
| +                    return self | ||||
| +                def __exit__ (self, exc_type, exc_value, exc_tb): | ||||
| +                    return True | ||||
|          obj = MyObj() | ||||
|          wr = weakref.ref(obj) | ||||
|          with Context(): | ||||
| @@ -1027,11 +1089,12 @@ class ExceptionTests(unittest.TestCase): | ||||
|      def _check_generator_cleanup_exc_state(self, testfunc): | ||||
|          # Issue #12791: exception state is cleaned up as soon as a generator | ||||
|          # is closed (reference cycles are broken). | ||||
| -        class MyException(Exception): | ||||
| -            def __init__(self, obj): | ||||
| -                self.obj = obj | ||||
| -        class MyObj: | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MyException(Exception): | ||||
| +                def __init__(self, obj): | ||||
| +                    self.obj = obj | ||||
| +            class MyObj: | ||||
| +                pass | ||||
|   | ||||
|          def raising_gen(): | ||||
|              try: | ||||
| @@ -1090,10 +1153,11 @@ class ExceptionTests(unittest.TestCase): | ||||
|      def test_3114(self): | ||||
|          # Bug #3114: in its destructor, MyObject retrieves a pointer to | ||||
|          # obsolete and/or deallocated objects. | ||||
| -        class MyObject: | ||||
| -            def __del__(self): | ||||
| -                nonlocal e | ||||
| -                e = sys.exception() | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MyObject: | ||||
| +                def __del__(self): | ||||
| +                    nonlocal e | ||||
| +                    e = sys.exception() | ||||
|          e = () | ||||
|          try: | ||||
|              raise Exception(MyObject()) | ||||
| @@ -1103,12 +1167,13 @@ class ExceptionTests(unittest.TestCase): | ||||
|          self.assertIsNone(e) | ||||
|   | ||||
|      def test_raise_does_not_create_context_chain_cycle(self): | ||||
| -        class A(Exception): | ||||
| -            pass | ||||
| -        class B(Exception): | ||||
| -            pass | ||||
| -        class C(Exception): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class A(Exception): | ||||
| +                pass | ||||
| +            class B(Exception): | ||||
| +                pass | ||||
| +            class C(Exception): | ||||
| +                pass | ||||
|   | ||||
|          # Create a context chain: | ||||
|          # C -> B -> A | ||||
| @@ -1164,12 +1229,13 @@ class ExceptionTests(unittest.TestCase): | ||||
|      def test_no_hang_on_context_chain_cycle2(self): | ||||
|          # See issue 25782. Cycle at head of context chain. | ||||
|   | ||||
| -        class A(Exception): | ||||
| -            pass | ||||
| -        class B(Exception): | ||||
| -            pass | ||||
| -        class C(Exception): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class A(Exception): | ||||
| +                pass | ||||
| +            class B(Exception): | ||||
| +                pass | ||||
| +            class C(Exception): | ||||
| +                pass | ||||
|   | ||||
|          # Context cycle: | ||||
|          # +-----------+ | ||||
| @@ -1200,16 +1266,17 @@ class ExceptionTests(unittest.TestCase): | ||||
|      def test_no_hang_on_context_chain_cycle3(self): | ||||
|          # See issue 25782. Longer context chain with cycle. | ||||
|   | ||||
| -        class A(Exception): | ||||
| -            pass | ||||
| -        class B(Exception): | ||||
| -            pass | ||||
| -        class C(Exception): | ||||
| -            pass | ||||
| -        class D(Exception): | ||||
| -            pass | ||||
| -        class E(Exception): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class A(Exception): | ||||
| +                pass | ||||
| +            class B(Exception): | ||||
| +                pass | ||||
| +            class C(Exception): | ||||
| +                pass | ||||
| +            class D(Exception): | ||||
| +                pass | ||||
| +            class E(Exception): | ||||
| +                pass | ||||
|   | ||||
|          # Context cycle: | ||||
|          #             +-----------+ | ||||
| @@ -1364,11 +1431,12 @@ class ExceptionTests(unittest.TestCase): | ||||
|      def test_badisinstance(self): | ||||
|          # Bug #2542: if issubclass(e, MyException) raises an exception, | ||||
|          # it should be ignored | ||||
| -        class Meta(type): | ||||
| -            def __subclasscheck__(cls, subclass): | ||||
| -                raise ValueError() | ||||
| -        class MyException(Exception, metaclass=Meta): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class Meta(type): | ||||
| +                def __subclasscheck__(cls, subclass): | ||||
| +                    raise ValueError() | ||||
| +            class MyException(Exception, metaclass=Meta): | ||||
| +                pass | ||||
|   | ||||
|          with captured_stderr() as stderr: | ||||
|              try: | ||||
| @@ -1602,8 +1670,9 @@ class ExceptionTests(unittest.TestCase): | ||||
|          self.assertTrue(issubclass(error3, error2)) | ||||
|   | ||||
|          # test with explicit base tuple | ||||
| -        class C(object): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class C(object): | ||||
| +                pass | ||||
|          error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4, | ||||
|                                                     (error3, C)) | ||||
|          self.assertTrue(issubclass(error4, error3)) | ||||
| @@ -1623,8 +1692,9 @@ class ExceptionTests(unittest.TestCase): | ||||
|          # Issue #5437: preallocated MemoryError instances should not keep | ||||
|          # traceback objects alive. | ||||
|          from _testcapi import raise_memoryerror | ||||
| -        class C: | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class C: | ||||
| +                pass | ||||
|          wr = None | ||||
|          def inner(): | ||||
|              nonlocal wr | ||||
| @@ -1644,8 +1714,9 @@ class ExceptionTests(unittest.TestCase): | ||||
|      @no_tracing | ||||
|      def test_recursion_error_cleanup(self): | ||||
|          # Same test as above, but with "recursion exceeded" errors | ||||
| -        class C: | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class C: | ||||
| +                pass | ||||
|          wr = None | ||||
|          def inner(): | ||||
|              nonlocal wr | ||||
| @@ -1670,11 +1741,12 @@ class ExceptionTests(unittest.TestCase): | ||||
|   | ||||
|      def test_unraisable(self): | ||||
|          # Issue #22836: PyErr_WriteUnraisable() should give sensible reports | ||||
| -        class BrokenDel: | ||||
| -            def __del__(self): | ||||
| -                exc = ValueError("del is broken") | ||||
| -                # The following line is included in the traceback report: | ||||
| -                raise exc | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class BrokenDel: | ||||
| +                def __del__(self): | ||||
| +                    exc = ValueError("del is broken") | ||||
| +                    # The following line is included in the traceback report: | ||||
| +                    raise exc | ||||
|   | ||||
|          obj = BrokenDel() | ||||
|          with support.catch_unraisable_exception() as cm: | ||||
| @@ -1728,11 +1800,12 @@ class ExceptionTests(unittest.TestCase): | ||||
|   | ||||
|      def test_yield_in_nested_try_excepts(self): | ||||
|          #Issue #25612 | ||||
| -        class MainError(Exception): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MainError(Exception): | ||||
| +                pass | ||||
|   | ||||
| -        class SubError(Exception): | ||||
| -            pass | ||||
| +            class SubError(Exception): | ||||
| +                pass | ||||
|   | ||||
|          def main(): | ||||
|              try: | ||||
| @@ -1807,8 +1880,9 @@ class ExceptionTests(unittest.TestCase): | ||||
|          # subclass object. Finally, it checks that creating a new MemoryError | ||||
|          # succeeds, proving that the freelist is not corrupted. | ||||
|   | ||||
| -        class TestException(MemoryError): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class TestException(MemoryError): | ||||
| +                pass | ||||
|   | ||||
|          try: | ||||
|              raise MemoryError | ||||
| -@@ -1844,7 +1918,7 @@ class ExceptionTests(unittest.TestCase): | ||||
| +@@ -1844,7 +1900,7 @@ class ExceptionTests(unittest.TestCase): | ||||
|          self.assertIn(b'MemoryError', err) | ||||
|   | ||||
|   | ||||
| @ -378,18 +80,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|      def test_name_error_has_name(self): | ||||
|          try: | ||||
|              bluch | ||||
| @@ -1886,15 +1960,16 @@ class NameErrorTests(unittest.TestCase): | ||||
|   | ||||
|      def test_gh_111654(self): | ||||
|          def f(): | ||||
| -            class TestClass: | ||||
| -                TestClass | ||||
| +            with torch._dynamo.error_on_graph_break(False): | ||||
| +                class TestClass: | ||||
| +                    TestClass | ||||
|   | ||||
|          self.assertRaises(NameError, f) | ||||
|   | ||||
| @@ -1894,7 +1950,7 @@ class NameErrorTests(unittest.TestCase): | ||||
|      # Note: name suggestion tests live in `test_traceback`. | ||||
|   | ||||
|   | ||||
| @ -398,33 +89,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|      def test_attributes(self): | ||||
|          # Setting 'attr' should not be a problem. | ||||
|          exc = AttributeError('Ouch!') | ||||
| @@ -1907,8 +1982,9 @@ class AttributeErrorTests(unittest.TestCase): | ||||
|          self.assertIs(exc.obj, sentinel) | ||||
|   | ||||
|      def test_getattr_has_name_and_obj(self): | ||||
| -        class A: | ||||
| -            blech = None | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class A: | ||||
| +                blech = None | ||||
|   | ||||
|          obj = A() | ||||
|          try: | ||||
| @@ -1923,9 +1999,10 @@ class AttributeErrorTests(unittest.TestCase): | ||||
|              self.assertEqual(obj, exc.obj) | ||||
|   | ||||
|      def test_getattr_has_name_and_obj_for_method(self): | ||||
| -        class A: | ||||
| -            def blech(self): | ||||
| -                return | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class A: | ||||
| +                def blech(self): | ||||
| +                    return | ||||
|   | ||||
|          obj = A() | ||||
|          try: | ||||
| @@ -1937,7 +2014,7 @@ class AttributeErrorTests(unittest.TestCase): | ||||
| @@ -1937,7 +1993,7 @@ class AttributeErrorTests(unittest.TestCase): | ||||
|      # Note: name suggestion tests live in `test_traceback`. | ||||
|   | ||||
|   | ||||
| @ -433,7 +98,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|   | ||||
|      def test_attributes(self): | ||||
|          # Setting 'name' and 'path' should not be a problem. | ||||
| @@ -2024,7 +2101,7 @@ def run_script(source): | ||||
| @@ -2024,7 +2080,7 @@ def run_script(source): | ||||
|      _rc, _out, err = script_helper.assert_python_failure('-Wd', '-X', 'utf8', TESTFN) | ||||
|      return err.decode('utf-8').splitlines() | ||||
|   | ||||
| @ -442,7 +107,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|      def tearDown(self): | ||||
|          unlink(TESTFN) | ||||
|   | ||||
| @@ -2159,7 +2236,7 @@ class AssertionErrorTests(unittest.TestCase): | ||||
| @@ -2159,7 +2215,7 @@ class AssertionErrorTests(unittest.TestCase): | ||||
|   | ||||
|   | ||||
|  @support.force_not_colorized_test_class | ||||
| @ -451,19 +116,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|      maxDiff = None | ||||
|   | ||||
|      @force_not_colorized | ||||
| @@ -2254,8 +2331,9 @@ class SyntaxErrorTests(unittest.TestCase): | ||||
|                      the_exception = exc | ||||
|   | ||||
|      def test_subclass(self): | ||||
| -        class MySyntaxError(SyntaxError): | ||||
| -            pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MySyntaxError(SyntaxError): | ||||
| +                pass | ||||
|   | ||||
|          try: | ||||
|              raise MySyntaxError("bad bad", ("bad.py", 1, 2, "abcdefg", 1, 7)) | ||||
| @@ -2290,6 +2368,7 @@ class SyntaxErrorTests(unittest.TestCase): | ||||
| @@ -2290,6 +2346,7 @@ class SyntaxErrorTests(unittest.TestCase): | ||||
|          err = run_script(b"\x89") | ||||
|          self.assertIn("SyntaxError: Non-UTF-8 code starting with '\\x89' in file", err[-1]) | ||||
|   | ||||
| @ -471,7 +124,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|      def test_string_source(self): | ||||
|          def try_compile(source): | ||||
|              with self.assertRaises(SyntaxError) as cm: | ||||
| @@ -2405,7 +2484,7 @@ class SyntaxErrorTests(unittest.TestCase): | ||||
| @@ -2405,7 +2462,7 @@ class SyntaxErrorTests(unittest.TestCase): | ||||
|          self.assertRaises(TypeError, SyntaxError, "bad bad", args) | ||||
|   | ||||
|   | ||||
| @ -480,7 +133,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|      def test_except_star_invalid_exception_type(self): | ||||
|          with self.assertRaises(TypeError): | ||||
|              try: | ||||
| @@ -2420,7 +2499,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase): | ||||
| @@ -2420,7 +2477,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase): | ||||
|                  pass | ||||
|   | ||||
|   | ||||
| @ -489,42 +142,7 @@ index c91f6662948..3a62dec411c 100644 | ||||
|   | ||||
|      def lineno_after_raise(self, f, *expected): | ||||
|          try: | ||||
| @@ -2499,11 +2578,12 @@ class PEP626Tests(unittest.TestCase): | ||||
|          self.lineno_after_raise(in_finally_except, 4) | ||||
|   | ||||
|      def test_lineno_after_with(self): | ||||
| -        class Noop: | ||||
| -            def __enter__(self): | ||||
| -                return self | ||||
| -            def __exit__(self, *args): | ||||
| -                pass | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class Noop: | ||||
| +                def __enter__(self): | ||||
| +                    return self | ||||
| +                def __exit__(self, *args): | ||||
| +                    pass | ||||
|          def after_with(): | ||||
|              with Noop(): | ||||
|                  1/0 | ||||
| @@ -2518,16 +2598,17 @@ class PEP626Tests(unittest.TestCase): | ||||
|          self.lineno_after_raise(f, None) | ||||
|   | ||||
|      def test_lineno_after_raise_in_with_exit(self): | ||||
| -        class ExitFails: | ||||
| -            def __enter__(self): | ||||
| -                return self | ||||
| -            def __exit__(self, *args): | ||||
| -                raise ValueError | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class ExitFails: | ||||
| +                def __enter__(self): | ||||
| +                    return self | ||||
| +                def __exit__(self, *args): | ||||
| +                    raise ValueError | ||||
|   | ||||
|          def after_with(): | ||||
|              with ExitFails(): | ||||
| @@ -2529,5 +2586,5 @@ class PEP626Tests(unittest.TestCase): | ||||
|                  1/0 | ||||
|          self.lineno_after_raise(after_with, 1, 1) | ||||
|   | ||||
|  | ||||
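The hunks above all revolve around one wrapper: a `with torch._dynamo.error_on_graph_break(False):` block placed around class definitions in these CPython-derived tests. As a minimal sketch of the pattern being edited — assuming only what the diff itself shows, namely that the API is a context manager and that passing False relaxes graph-break errors for the wrapped region — the usage looks like this (the function and class names here are hypothetical):

import torch

def make_and_catch(x):
    # Defining a class inside a traced region is a classic graph-break source.
    # The wrapper, as used throughout the hunks above, keeps that break from
    # being treated as a hard error (reading implied by the name and usage).
    with torch._dynamo.error_on_graph_break(False):
        class LocalError(Exception):  # hypothetical stand-in for A/B/C above
            pass
    try:
        raise LocalError
    except LocalError:
        return x + 1

compiled = torch.compile(make_and_catch, backend="eager")
print(compiled(torch.ones(2)))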
| @ -399,13 +399,12 @@ class ExceptionTests(__TestCase): | ||||
|         # test that setting an exception at the C level works even if the | ||||
|         # exception object can't be constructed. | ||||
|  | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class BadException(Exception): | ||||
|                 def __init__(self_): | ||||
|                     raise RuntimeError("can't instantiate BadException") | ||||
|         class BadException(Exception): | ||||
|             def __init__(self_): | ||||
|                 raise RuntimeError("can't instantiate BadException") | ||||
|  | ||||
|             class InvalidException: | ||||
|                 pass | ||||
|         class InvalidException: | ||||
|             pass | ||||
|  | ||||
|         @unittest.skipIf(_testcapi is None, "requires _testcapi") | ||||
|         def test_capi1(): | ||||
| @ -693,9 +692,8 @@ class ExceptionTests(__TestCase): | ||||
|         self.assertIsInstance(e, IndexError) | ||||
|         self.assertEqual(e.__traceback__, tb) | ||||
|  | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MyException(Exception): | ||||
|                 pass | ||||
|         class MyException(Exception): | ||||
|             pass | ||||
|  | ||||
|         e = MyException().with_traceback(tb) | ||||
|         self.assertIsInstance(e, MyException) | ||||
| @ -754,9 +752,8 @@ class ExceptionTests(__TestCase): | ||||
|         self.assertIsNone(e.__context__) | ||||
|         self.assertIsNone(e.__cause__) | ||||
|  | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MyException(OSError): | ||||
|                 pass | ||||
|         class MyException(OSError): | ||||
|             pass | ||||
|  | ||||
|         e = MyException() | ||||
|         self.assertIsNone(e.__context__) | ||||
| @ -785,11 +782,10 @@ class ExceptionTests(__TestCase): | ||||
|         # but user-defined subclasses can if they want | ||||
|         self.assertRaises(TypeError, BaseException, a=1) | ||||
|  | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class DerivedException(BaseException): | ||||
|                 def __init__(self, fancy_arg): | ||||
|                     BaseException.__init__(self) | ||||
|                     self.fancy_arg = fancy_arg | ||||
|         class DerivedException(BaseException): | ||||
|             def __init__(self, fancy_arg): | ||||
|                 BaseException.__init__(self) | ||||
|                 self.fancy_arg = fancy_arg | ||||
|  | ||||
|         x = DerivedException(fancy_arg=42) | ||||
|         self.assertEqual(x.fancy_arg, 42) | ||||
| @ -839,12 +835,11 @@ class ExceptionTests(__TestCase): | ||||
|         # Make sure exception state is cleaned up as soon as the except | ||||
|         # block is left. See #2507 | ||||
|  | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MyException(Exception): | ||||
|                 def __init__(self, obj): | ||||
|                     self.obj = obj | ||||
|             class MyObj: | ||||
|                 pass | ||||
|         class MyException(Exception): | ||||
|             def __init__(self, obj): | ||||
|                 self.obj = obj | ||||
|         class MyObj: | ||||
|             pass | ||||
|  | ||||
|         def inner_raising_func(): | ||||
|             # Create some references in exception value and traceback | ||||
| @ -942,12 +937,11 @@ class ExceptionTests(__TestCase): | ||||
|         self.assertIsNone(obj) | ||||
|  | ||||
|         # Inside an exception-silencing "with" block | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class Context: | ||||
|                 def __enter__(self): | ||||
|                     return self | ||||
|                 def __exit__ (self, exc_type, exc_value, exc_tb): | ||||
|                     return True | ||||
|         class Context: | ||||
|             def __enter__(self): | ||||
|                 return self | ||||
|             def __exit__ (self, exc_type, exc_value, exc_tb): | ||||
|                 return True | ||||
|         obj = MyObj() | ||||
|         wr = weakref.ref(obj) | ||||
|         with Context(): | ||||
| @ -1089,12 +1083,11 @@ class ExceptionTests(__TestCase): | ||||
|     def _check_generator_cleanup_exc_state(self, testfunc): | ||||
|         # Issue #12791: exception state is cleaned up as soon as a generator | ||||
|         # is closed (reference cycles are broken). | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MyException(Exception): | ||||
|                 def __init__(self, obj): | ||||
|                     self.obj = obj | ||||
|             class MyObj: | ||||
|                 pass | ||||
|         class MyException(Exception): | ||||
|             def __init__(self, obj): | ||||
|                 self.obj = obj | ||||
|         class MyObj: | ||||
|             pass | ||||
|  | ||||
|         def raising_gen(): | ||||
|             try: | ||||
| @ -1153,11 +1146,10 @@ class ExceptionTests(__TestCase): | ||||
|     def test_3114(self): | ||||
|         # Bug #3114: in its destructor, MyObject retrieves a pointer to | ||||
|         # obsolete and/or deallocated objects. | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MyObject: | ||||
|                 def __del__(self): | ||||
|                     nonlocal e | ||||
|                     e = sys.exception() | ||||
|         class MyObject: | ||||
|             def __del__(self): | ||||
|                 nonlocal e | ||||
|                 e = sys.exception() | ||||
|         e = () | ||||
|         try: | ||||
|             raise Exception(MyObject()) | ||||
| @ -1167,13 +1159,12 @@ class ExceptionTests(__TestCase): | ||||
|         self.assertIsNone(e) | ||||
|  | ||||
|     def test_raise_does_not_create_context_chain_cycle(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class A(Exception): | ||||
|                 pass | ||||
|             class B(Exception): | ||||
|                 pass | ||||
|             class C(Exception): | ||||
|                 pass | ||||
|         class A(Exception): | ||||
|             pass | ||||
|         class B(Exception): | ||||
|             pass | ||||
|         class C(Exception): | ||||
|             pass | ||||
|  | ||||
|         # Create a context chain: | ||||
|         # C -> B -> A | ||||
| @ -1229,13 +1220,12 @@ class ExceptionTests(__TestCase): | ||||
|     def test_no_hang_on_context_chain_cycle2(self): | ||||
|         # See issue 25782. Cycle at head of context chain. | ||||
|  | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class A(Exception): | ||||
|                 pass | ||||
|             class B(Exception): | ||||
|                 pass | ||||
|             class C(Exception): | ||||
|                 pass | ||||
|         class A(Exception): | ||||
|             pass | ||||
|         class B(Exception): | ||||
|             pass | ||||
|         class C(Exception): | ||||
|             pass | ||||
|  | ||||
|         # Context cycle: | ||||
|         # +-----------+ | ||||
| @ -1266,17 +1256,16 @@ class ExceptionTests(__TestCase): | ||||
|     def test_no_hang_on_context_chain_cycle3(self): | ||||
|         # See issue 25782. Longer context chain with cycle. | ||||
|  | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class A(Exception): | ||||
|                 pass | ||||
|             class B(Exception): | ||||
|                 pass | ||||
|             class C(Exception): | ||||
|                 pass | ||||
|             class D(Exception): | ||||
|                 pass | ||||
|             class E(Exception): | ||||
|                 pass | ||||
|         class A(Exception): | ||||
|             pass | ||||
|         class B(Exception): | ||||
|             pass | ||||
|         class C(Exception): | ||||
|             pass | ||||
|         class D(Exception): | ||||
|             pass | ||||
|         class E(Exception): | ||||
|             pass | ||||
|  | ||||
|         # Context cycle: | ||||
|         #             +-----------+ | ||||
| @ -1431,12 +1420,11 @@ class ExceptionTests(__TestCase): | ||||
|     def test_badisinstance(self): | ||||
|         # Bug #2542: if issubclass(e, MyException) raises an exception, | ||||
|         # it should be ignored | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class Meta(type): | ||||
|                 def __subclasscheck__(cls, subclass): | ||||
|                     raise ValueError() | ||||
|             class MyException(Exception, metaclass=Meta): | ||||
|                 pass | ||||
|         class Meta(type): | ||||
|             def __subclasscheck__(cls, subclass): | ||||
|                 raise ValueError() | ||||
|         class MyException(Exception, metaclass=Meta): | ||||
|             pass | ||||
|  | ||||
|         with captured_stderr() as stderr: | ||||
|             try: | ||||
| @ -1670,9 +1658,8 @@ class ExceptionTests(__TestCase): | ||||
|         self.assertTrue(issubclass(error3, error2)) | ||||
|  | ||||
|         # test with explicit base tuple | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class C(object): | ||||
|                 pass | ||||
|         class C(object): | ||||
|             pass | ||||
|         error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4, | ||||
|                                                    (error3, C)) | ||||
|         self.assertTrue(issubclass(error4, error3)) | ||||
| @ -1692,9 +1679,8 @@ class ExceptionTests(__TestCase): | ||||
|         # Issue #5437: preallocated MemoryError instances should not keep | ||||
|         # traceback objects alive. | ||||
|         from _testcapi import raise_memoryerror | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class C: | ||||
|                 pass | ||||
|         class C: | ||||
|             pass | ||||
|         wr = None | ||||
|         def inner(): | ||||
|             nonlocal wr | ||||
| @ -1714,9 +1700,8 @@ class ExceptionTests(__TestCase): | ||||
|     @no_tracing | ||||
|     def test_recursion_error_cleanup(self): | ||||
|         # Same test as above, but with "recursion exceeded" errors | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class C: | ||||
|                 pass | ||||
|         class C: | ||||
|             pass | ||||
|         wr = None | ||||
|         def inner(): | ||||
|             nonlocal wr | ||||
| @ -1741,12 +1726,11 @@ class ExceptionTests(__TestCase): | ||||
|  | ||||
|     def test_unraisable(self): | ||||
|         # Issue #22836: PyErr_WriteUnraisable() should give sensible reports | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class BrokenDel: | ||||
|                 def __del__(self): | ||||
|                     exc = ValueError("del is broken") | ||||
|                     # The following line is included in the traceback report: | ||||
|                     raise exc | ||||
|         class BrokenDel: | ||||
|             def __del__(self): | ||||
|                 exc = ValueError("del is broken") | ||||
|                 # The following line is included in the traceback report: | ||||
|                 raise exc | ||||
|  | ||||
|         obj = BrokenDel() | ||||
|         with support.catch_unraisable_exception() as cm: | ||||
| @ -1800,12 +1784,11 @@ class ExceptionTests(__TestCase): | ||||
|  | ||||
|     def test_yield_in_nested_try_excepts(self): | ||||
|         #Issue #25612 | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MainError(Exception): | ||||
|                 pass | ||||
|         class MainError(Exception): | ||||
|             pass | ||||
|  | ||||
|             class SubError(Exception): | ||||
|                 pass | ||||
|         class SubError(Exception): | ||||
|             pass | ||||
|  | ||||
|         def main(): | ||||
|             try: | ||||
| @ -1880,9 +1863,8 @@ class ExceptionTests(__TestCase): | ||||
|         # subclass object. Finally, it checks that creating a new MemoryError | ||||
|         # succeeds, proving that the freelist is not corrupted. | ||||
|  | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class TestException(MemoryError): | ||||
|                 pass | ||||
|         class TestException(MemoryError): | ||||
|             pass | ||||
|  | ||||
|         try: | ||||
|             raise MemoryError | ||||
| @ -1960,9 +1942,8 @@ class NameErrorTests(__TestCase): | ||||
|  | ||||
|     def test_gh_111654(self): | ||||
|         def f(): | ||||
|             with torch._dynamo.error_on_graph_break(False): | ||||
|                 class TestClass: | ||||
|                     TestClass | ||||
|             class TestClass: | ||||
|                 TestClass | ||||
|  | ||||
|         self.assertRaises(NameError, f) | ||||
|  | ||||
| @ -1982,9 +1963,8 @@ class AttributeErrorTests(__TestCase): | ||||
|         self.assertIs(exc.obj, sentinel) | ||||
|  | ||||
|     def test_getattr_has_name_and_obj(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class A: | ||||
|                 blech = None | ||||
|         class A: | ||||
|             blech = None | ||||
|  | ||||
|         obj = A() | ||||
|         try: | ||||
| @ -1999,10 +1979,9 @@ class AttributeErrorTests(__TestCase): | ||||
|             self.assertEqual(obj, exc.obj) | ||||
|  | ||||
|     def test_getattr_has_name_and_obj_for_method(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class A: | ||||
|                 def blech(self): | ||||
|                     return | ||||
|         class A: | ||||
|             def blech(self): | ||||
|                 return | ||||
|  | ||||
|         obj = A() | ||||
|         try: | ||||
| @ -2331,9 +2310,8 @@ class SyntaxErrorTests(__TestCase): | ||||
|                     the_exception = exc | ||||
|  | ||||
|     def test_subclass(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MySyntaxError(SyntaxError): | ||||
|                 pass | ||||
|         class MySyntaxError(SyntaxError): | ||||
|             pass | ||||
|  | ||||
|         try: | ||||
|             raise MySyntaxError("bad bad", ("bad.py", 1, 2, "abcdefg", 1, 7)) | ||||
| @ -2578,12 +2556,11 @@ class PEP626Tests(__TestCase): | ||||
|         self.lineno_after_raise(in_finally_except, 4) | ||||
|  | ||||
|     def test_lineno_after_with(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class Noop: | ||||
|                 def __enter__(self): | ||||
|                     return self | ||||
|                 def __exit__(self, *args): | ||||
|                     pass | ||||
|         class Noop: | ||||
|             def __enter__(self): | ||||
|                 return self | ||||
|             def __exit__(self, *args): | ||||
|                 pass | ||||
|         def after_with(): | ||||
|             with Noop(): | ||||
|                 1/0 | ||||
| @ -2598,12 +2575,11 @@ class PEP626Tests(__TestCase): | ||||
|         self.lineno_after_raise(f, None) | ||||
|  | ||||
|     def test_lineno_after_raise_in_with_exit(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class ExitFails: | ||||
|                 def __enter__(self): | ||||
|                     return self | ||||
|                 def __exit__(self, *args): | ||||
|                     raise ValueError | ||||
|         class ExitFails: | ||||
|             def __enter__(self): | ||||
|                 return self | ||||
|             def __exit__(self, *args): | ||||
|                 raise ValueError | ||||
|  | ||||
|         def after_with(): | ||||
|             with ExitFails(): | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_raise.py b/test/dynamo/cpython/3_13/test_raise.py | ||||
| -index 6d26a61bee4..ce748433d28 100644 | ||||
| +index 6d26a61bee4..042d1ae3d7c 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_raise.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_raise.py | ||||
| @@ -1,3 +1,58 @@ | ||||
| @ -70,35 +70,7 @@ index 6d26a61bee4..ce748433d28 100644 | ||||
|      def test_invalid_reraise(self): | ||||
|          try: | ||||
|              raise | ||||
| @@ -120,9 +175,10 @@ class TestRaise(unittest.TestCase): | ||||
|          self.assertRaises(StopIteration, lambda: next(g)) | ||||
|   | ||||
|      def test_erroneous_exception(self): | ||||
| -        class MyException(Exception): | ||||
| -            def __init__(self): | ||||
| -                raise RuntimeError() | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MyException(Exception): | ||||
| +                def __init__(self): | ||||
| +                    raise RuntimeError() | ||||
|   | ||||
|          try: | ||||
|              raise MyException | ||||
| @@ -133,9 +189,10 @@ class TestRaise(unittest.TestCase): | ||||
|   | ||||
|      def test_new_returns_invalid_instance(self): | ||||
|          # See issue #11627. | ||||
| -        class MyException(Exception): | ||||
| -            def __new__(cls, *args): | ||||
| -                return object() | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MyException(Exception): | ||||
| +                def __new__(cls, *args): | ||||
| +                    return object() | ||||
|   | ||||
|          with self.assertRaises(TypeError): | ||||
|              raise MyException | ||||
| @@ -148,7 +205,7 @@ class TestRaise(unittest.TestCase): | ||||
| @@ -148,7 +203,7 @@ class TestRaise(unittest.TestCase): | ||||
|   | ||||
|   | ||||
|   | ||||
| @ -107,37 +79,7 @@ index 6d26a61bee4..ce748433d28 100644 | ||||
|   | ||||
|      def testCauseSyntax(self): | ||||
|          try: | ||||
| @@ -186,10 +243,11 @@ class TestCause(unittest.TestCase): | ||||
|              self.fail("No exception raised") | ||||
|   | ||||
|      def test_class_cause_nonexception_result(self): | ||||
| -        class ConstructsNone(BaseException): | ||||
| -            @classmethod | ||||
| -            def __new__(*args, **kwargs): | ||||
| -                return None | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class ConstructsNone(BaseException): | ||||
| +                @classmethod | ||||
| +                def __new__(*args, **kwargs): | ||||
| +                    return None | ||||
|          try: | ||||
|              raise IndexError from ConstructsNone | ||||
|          except TypeError as e: | ||||
| @@ -209,9 +267,10 @@ class TestCause(unittest.TestCase): | ||||
|              self.fail("No exception raised") | ||||
|   | ||||
|      def test_erroneous_cause(self): | ||||
| -        class MyException(Exception): | ||||
| -            def __init__(self): | ||||
| -                raise RuntimeError() | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class MyException(Exception): | ||||
| +                def __init__(self): | ||||
| +                    raise RuntimeError() | ||||
|   | ||||
|          try: | ||||
|              raise IndexError from MyException | ||||
| @@ -221,7 +280,7 @@ class TestCause(unittest.TestCase): | ||||
| @@ -221,7 +276,7 @@ class TestCause(unittest.TestCase): | ||||
|              self.fail("No exception raised") | ||||
|   | ||||
|   | ||||
| @ -146,7 +88,7 @@ index 6d26a61bee4..ce748433d28 100644 | ||||
|   | ||||
|      def test_sets_traceback(self): | ||||
|          try: | ||||
| @@ -242,7 +301,7 @@ class TestTraceback(unittest.TestCase): | ||||
| @@ -242,7 +297,7 @@ class TestTraceback(unittest.TestCase): | ||||
|              self.fail("No exception raised") | ||||
|   | ||||
|   | ||||
| @ -155,7 +97,7 @@ index 6d26a61bee4..ce748433d28 100644 | ||||
|   | ||||
|      def raiser(self): | ||||
|          raise ValueError | ||||
| @@ -308,7 +367,7 @@ class TestTracebackType(unittest.TestCase): | ||||
| @@ -308,7 +363,7 @@ class TestTracebackType(unittest.TestCase): | ||||
|              types.TracebackType(other_tb, frame, 1, "nuh-uh") | ||||
|   | ||||
|   | ||||
| @ -164,45 +106,7 @@ index 6d26a61bee4..ce748433d28 100644 | ||||
|      def test_instance_context_instance_raise(self): | ||||
|          context = IndexError() | ||||
|          try: | ||||
| @@ -392,11 +451,12 @@ class TestContext(unittest.TestCase): | ||||
|              self.fail("No exception raised") | ||||
|   | ||||
|      def test_context_manager(self): | ||||
| -        class ContextManager: | ||||
| -            def __enter__(self): | ||||
| -                pass | ||||
| -            def __exit__(self, t, v, tb): | ||||
| -                xyzzy | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class ContextManager: | ||||
| +                def __enter__(self): | ||||
| +                    pass | ||||
| +                def __exit__(self, t, v, tb): | ||||
| +                    xyzzy | ||||
|          try: | ||||
|              with ContextManager(): | ||||
|                  1/0 | ||||
| @@ -471,12 +531,13 @@ class TestContext(unittest.TestCase): | ||||
|          import gc | ||||
|          # A re-raised exception in a __del__ caused the __context__ | ||||
|          # to be cleared | ||||
| -        class C: | ||||
| -            def __del__(self): | ||||
| -                try: | ||||
| -                    1/0 | ||||
| -                except: | ||||
| -                    raise | ||||
| +        with torch._dynamo.error_on_graph_break(False): | ||||
| +            class C: | ||||
| +                def __del__(self): | ||||
| +                    try: | ||||
| +                        1/0 | ||||
| +                    except: | ||||
| +                        raise | ||||
|   | ||||
|          def f(): | ||||
|              x = C() | ||||
| @@ -498,7 +559,7 @@ class TestContext(unittest.TestCase): | ||||
| @@ -498,7 +553,7 @@ class TestContext(unittest.TestCase): | ||||
|              self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type) | ||||
|   | ||||
|   | ||||
| @ -211,7 +115,7 @@ index 6d26a61bee4..ce748433d28 100644 | ||||
|      def test_tuples(self): | ||||
|          try: | ||||
|              raise (IndexError, KeyError) # This should be a tuple! | ||||
| @@ -517,4 +578,4 @@ class TestRemovedFunctionality(unittest.TestCase): | ||||
| @@ -517,4 +572,4 @@ class TestRemovedFunctionality(unittest.TestCase): | ||||
|   | ||||
|   | ||||
|  if __name__ == "__main__": | ||||
|  | ||||
| @ -175,10 +175,9 @@ class TestRaise(__TestCase): | ||||
|         self.assertRaises(StopIteration, lambda: next(g)) | ||||
|  | ||||
|     def test_erroneous_exception(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MyException(Exception): | ||||
|                 def __init__(self): | ||||
|                     raise RuntimeError() | ||||
|         class MyException(Exception): | ||||
|             def __init__(self): | ||||
|                 raise RuntimeError() | ||||
|  | ||||
|         try: | ||||
|             raise MyException | ||||
| @ -189,10 +188,9 @@ class TestRaise(__TestCase): | ||||
|  | ||||
|     def test_new_returns_invalid_instance(self): | ||||
|         # See issue #11627. | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MyException(Exception): | ||||
|                 def __new__(cls, *args): | ||||
|                     return object() | ||||
|         class MyException(Exception): | ||||
|             def __new__(cls, *args): | ||||
|                 return object() | ||||
|  | ||||
|         with self.assertRaises(TypeError): | ||||
|             raise MyException | ||||
| @ -243,11 +241,10 @@ class TestCause(__TestCase): | ||||
|             self.fail("No exception raised") | ||||
|  | ||||
|     def test_class_cause_nonexception_result(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class ConstructsNone(BaseException): | ||||
|                 @classmethod | ||||
|                 def __new__(*args, **kwargs): | ||||
|                     return None | ||||
|         class ConstructsNone(BaseException): | ||||
|             @classmethod | ||||
|             def __new__(*args, **kwargs): | ||||
|                 return None | ||||
|         try: | ||||
|             raise IndexError from ConstructsNone | ||||
|         except TypeError as e: | ||||
| @ -267,10 +264,9 @@ class TestCause(__TestCase): | ||||
|             self.fail("No exception raised") | ||||
|  | ||||
|     def test_erroneous_cause(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class MyException(Exception): | ||||
|                 def __init__(self): | ||||
|                     raise RuntimeError() | ||||
|         class MyException(Exception): | ||||
|             def __init__(self): | ||||
|                 raise RuntimeError() | ||||
|  | ||||
|         try: | ||||
|             raise IndexError from MyException | ||||
| @ -451,12 +447,11 @@ class TestContext(__TestCase): | ||||
|             self.fail("No exception raised") | ||||
|  | ||||
|     def test_context_manager(self): | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class ContextManager: | ||||
|                 def __enter__(self): | ||||
|                     pass | ||||
|                 def __exit__(self, t, v, tb): | ||||
|                     xyzzy | ||||
|         class ContextManager: | ||||
|             def __enter__(self): | ||||
|                 pass | ||||
|             def __exit__(self, t, v, tb): | ||||
|                 xyzzy | ||||
|         try: | ||||
|             with ContextManager(): | ||||
|                 1/0 | ||||
| @ -531,13 +526,12 @@ class TestContext(__TestCase): | ||||
|         import gc | ||||
|         # A re-raised exception in a __del__ caused the __context__ | ||||
|         # to be cleared | ||||
|         with torch._dynamo.error_on_graph_break(False): | ||||
|             class C: | ||||
|                 def __del__(self): | ||||
|                     try: | ||||
|                         1/0 | ||||
|                     except: | ||||
|                         raise | ||||
|         class C: | ||||
|             def __del__(self): | ||||
|                 try: | ||||
|                     1/0 | ||||
|                 except: | ||||
|                     raise | ||||
|  | ||||
|         def f(): | ||||
|             x = C() | ||||
|  | ||||
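The TestCause and TestContext hunks above exercise exception chaining under compilation. For reference, the semantics those tests pin down are standard Python, independent of the dynamo harness: `raise X from Y` sets `__cause__` explicitly, while raising inside an except block sets `__context__` implicitly. A self-contained illustration:

try:
    try:
        1 / 0
    except ZeroDivisionError as inner:
        # Explicit chaining: sets __cause__ (and __context__ as a side effect).
        raise KeyError("outer") from inner
except KeyError as outer:
    assert isinstance(outer.__cause__, ZeroDivisionError)
    assert isinstance(outer.__context__, ZeroDivisionError)
    assert outer.__suppress_context__  # set by "raise ... from ..."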
| @ -916,41 +916,43 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase): | ||||
|             dedent( | ||||
|                 """\ | ||||
| SeqNr|OrigAten|SrcFn|FwdSrcFn | ||||
| 0|aten.convolution.default|conv2d| | ||||
| 0|aten.add.Tensor|add_| | ||||
| 1|aten._native_batch_norm_legit_functional.default|batch_norm| | ||||
| 2|aten.relu.default|relu| | ||||
| 2|aten.detach.default|relu| | ||||
| 0|aten.convolution.default|l__self___conv1| | ||||
| 0|aten.add.Tensor|l__self___bn1| | ||||
| 1|aten._native_batch_norm_legit_functional.default|l__self___bn1| | ||||
| 2|aten.relu.default|l__self___relu1| | ||||
| 2|aten.detach.default|l__self___relu1| | ||||
| 2|aten.detach.default|l__self___relu1| | ||||
| 3|aten.add.Tensor|add| | ||||
| 4|aten.view.default|flatten| | ||||
| 5|aten.view.default|linear| | ||||
| 6|aten.t.default|linear| | ||||
| 7|aten.addmm.default|linear| | ||||
| 8|aten.view.default|linear| | ||||
| 9|aten.sub.Tensor|l1_loss| | ||||
| 10|aten.abs.default|l1_loss| | ||||
| 11|aten.mean.default|l1_loss| | ||||
| 11|aten.ones_like.default||l1_loss | ||||
| 11|aten.expand.default||l1_loss | ||||
| 11|aten.div.Scalar||l1_loss | ||||
| 10|aten.sgn.default||l1_loss | ||||
| 10|aten.mul.Tensor||l1_loss | ||||
| 8|aten.view.default||linear | ||||
| 7|aten.t.default||linear | ||||
| 7|aten.mm.default||linear | ||||
| 7|aten.t.default||linear | ||||
| 7|aten.mm.default||linear | ||||
| 7|aten.t.default||linear | ||||
| 7|aten.sum.dim_IntList||linear | ||||
| 7|aten.view.default||linear | ||||
| 6|aten.t.default||linear | ||||
| 5|aten.view.default||linear | ||||
| 5|aten.view.default|l__self___fc1| | ||||
| 6|aten.t.default|l__self___fc1| | ||||
| 7|aten.addmm.default|l__self___fc1| | ||||
| 8|aten.view.default|l__self___fc1| | ||||
| 9|aten.sub.Tensor|l__self___loss_fn| | ||||
| 10|aten.abs.default|l__self___loss_fn| | ||||
| 11|aten.mean.default|l__self___loss_fn| | ||||
| 11|aten.ones_like.default||l__self___loss_fn | ||||
| 11|aten.expand.default||l__self___loss_fn | ||||
| 11|aten.div.Scalar||l__self___loss_fn | ||||
| 10|aten.sgn.default||l__self___loss_fn | ||||
| 10|aten.mul.Tensor||l__self___loss_fn | ||||
| 8|aten.view.default||l__self___fc1 | ||||
| 7|aten.t.default||l__self___fc1 | ||||
| 7|aten.mm.default||l__self___fc1 | ||||
| 7|aten.t.default||l__self___fc1 | ||||
| 7|aten.mm.default||l__self___fc1 | ||||
| 7|aten.t.default||l__self___fc1 | ||||
| 7|aten.sum.dim_IntList||l__self___fc1 | ||||
| 7|aten.view.default||l__self___fc1 | ||||
| 6|aten.t.default||l__self___fc1 | ||||
| 5|aten.view.default||l__self___fc1 | ||||
| 4|aten.view.default||flatten | ||||
| 2|aten.detach.default||relu | ||||
| 2|aten.threshold_backward.default||relu | ||||
| 1|aten.native_batch_norm_backward.default||batch_norm | ||||
| 0|aten.convolution_backward.default||conv2d | ||||
| 11|aten.add.Tensor||l1_loss | ||||
| 2|aten.detach.default||l__self___relu1 | ||||
| 2|aten.detach.default||l__self___relu1 | ||||
| 2|aten.threshold_backward.default||l__self___relu1 | ||||
| 1|aten.native_batch_norm_backward.default||l__self___bn1 | ||||
| 0|aten.convolution_backward.default||l__self___conv1 | ||||
| 11|aten.add.Tensor||l__self___loss_fn | ||||
| """ | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
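For orientation, the expected output above maps autograd sequence numbers (SeqNr) to ATen ops and the source function that produced them; the change swaps short functional names (conv2d, batch_norm, linear, l1_loss) for Dynamo's fully qualified attribute names (l__self___conv1 and so on). Those names are consistent with a small module of roughly this shape — a hypothetical reconstruction from the table alone, not the test's actual model:

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical module whose attribute names would trace as l__self___conv1,
    # l__self___bn1, l__self___relu1, l__self___fc1, l__self___loss_fn.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, 3)
        self.bn1 = nn.BatchNorm2d(4)
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear(4 * 30 * 30, 8)
        self.loss_fn = nn.L1Loss()

    def forward(self, x, target):
        y = self.relu1(self.bn1(self.conv1(x)))  # SeqNr 0-2: convolution, batch norm, relu
        y = y + y                                # SeqNr 3: the bare "add" row
        y = torch.flatten(y, 1)                  # SeqNr 4: the "flatten" row
        y = self.fc1(y)                          # SeqNr 5-8: view/t/addmm/view from linear
        return self.loss_fn(y, target)           # SeqNr 9-11: sub/abs/mean from the L1 loss

net = TinyNet()
loss = net(torch.randn(2, 3, 32, 32), torch.randn(2, 8))
loss.backward()  # rows with an empty SrcFn and a FwdSrcFn come from this pass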
| @ -1,6 +1,5 @@ | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| import functools | ||||
| import inspect | ||||
| import os | ||||
| import pickle | ||||
| @ -204,39 +203,6 @@ class TestAOTCompile(torch._inductor.test_case.TestCase): | ||||
|             actual = compiled_fn(*example_inputs) | ||||
|             self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_decorated_function_with_functools_wrap_aot(self): | ||||
|         def check_inputs(fn): | ||||
|             @functools.wraps(fn) | ||||
|             def _fn(*args, **kwargs): | ||||
|                 for arg in args: | ||||
|                     assert arg.shape[0] > 1 | ||||
|  | ||||
|                 return fn(*args, **kwargs) | ||||
|  | ||||
|             return _fn | ||||
|  | ||||
|         @check_inputs | ||||
|         def foo(x, y): | ||||
|             a = x + x | ||||
|             b = y + y | ||||
|             c = a + b | ||||
|             return c | ||||
|  | ||||
|         example_inputs = (torch.ones(3), torch.ones(3)) | ||||
|         expected = foo(*example_inputs) | ||||
|  | ||||
|         def backend(gm, example_inputs): | ||||
|             return CustomCompiledFunction(gm, example_inputs) | ||||
|  | ||||
|         with torch.compiler.set_stance("fail_on_recompile"): | ||||
|             compiled_fn = torch.compile( | ||||
|                 foo, | ||||
|                 fullgraph=True, | ||||
|                 backend=backend, | ||||
|             ).aot_compile((example_inputs, {})) | ||||
|             actual = compiled_fn(*example_inputs) | ||||
|             self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_aot_compile_disable_guard_check(self): | ||||
|         def fn(x, y): | ||||
|             return x + y | ||||
|  | ||||