mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	Compare commits
	
		
			229 Commits
		
	
	
		
			dynamo_sta
			...
			ciflow/tru
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 918fea72b7 | |||
| fc7c7e0f28 | |||
| b4f7c19bbb | |||
| ad78f89ca0 | |||
| 41eb3b3cfc | |||
| b9d8aaad45 | |||
| 5a4997dcae | |||
| 47f638eae7 | |||
| 882b834082 | |||
| b146ea411e | |||
| 8625ffbd45 | |||
| 0977cc4474 | |||
| d9a55faccc | |||
| 75b8295868 | |||
| defb6a80d8 | |||
| f8fccb1e48 | |||
| 5aac4cfce4 | |||
| baf91bbbfc | |||
| cbcb4f7768 | |||
| 2b93d5b450 | |||
| 6b7cd48e7e | |||
| bf5aa9e42e | |||
| b1eb6dede5 | |||
| 673060beae | |||
| 2e8e9a59a8 | |||
| fb277a5916 | |||
| 73fa0d0c63 | |||
| 36c21cc84e | |||
| 0b68814b44 | |||
| e64a814ae7 | |||
| 0b58d87aec | |||
| 757975ad50 | |||
| 291712026b | |||
| 3e77a2b478 | |||
| 82ef1b5db3 | |||
| 5f370f5c42 | |||
| 05b2e02cb4 | |||
| 12f742941d | |||
| 35180fafee | |||
| c746feb86a | |||
| c5f26db5bf | |||
| 18e99b6d45 | |||
| ab9e466928 | |||
| af4ba78543 | |||
| 282f39a4bc | |||
| a479769488 | |||
| 26c7375477 | |||
| d01f15152c | |||
| 4fae6968b1 | |||
| f9953e0f61 | |||
| 34ed7a8f0d | |||
| 2fde10d914 | |||
| 0a93295da0 | |||
| 4b898b51b9 | |||
| 550e3e6efb | |||
| 715449ca76 | |||
| 84d8d06fc3 | |||
| 60992d98b2 | |||
| 59e015e3a1 | |||
| 8904a5a7c9 | |||
| f5df9ca03a | |||
| 2998abd777 | |||
| e13580e41c | |||
| f3b8e15f20 | |||
| 5211f4c108 | |||
| ad9027b80d | |||
| a1005427bf | |||
| 35153d0846 | |||
| 7773a22cdb | |||
| 7cb467a169 | |||
| 12aac12b8d | |||
| 2b748d0a56 | |||
| 16745a882a | |||
| 8daef35cf1 | |||
| 51319ca090 | |||
| d311a3d1dc | |||
| 04adfe5ba9 | |||
| 4be1e3bf92 | |||
| e7592f4005 | |||
| d334c3649d | |||
| 9f82535c5a | |||
| 5b35fc8777 | |||
| 2f38eece7c | |||
| 830e789a55 | |||
| ad4dc52bf6 | |||
| dac9ed9790 | |||
| 1c7fe8f861 | |||
| 4e643422f6 | |||
| 3c3b278872 | |||
| 0bd12c1168 | |||
| ce8a7764e2 | |||
| d1269a0434 | |||
| c87cf1be32 | |||
| 2fc5e45a41 | |||
| f9022ba93b | |||
| ff8be889ad | |||
| 292454942e | |||
| 6c4412f72b | |||
| 78bf6186f2 | |||
| c40048472c | |||
| 3dfd0c7584 | |||
| e6ba4d0725 | |||
| bdf7cb9d9c | |||
| 6aed378958 | |||
| 8b3dc0d1b0 | |||
| 06773663b5 | |||
| 0bff65503c | |||
| 21131a2444 | |||
| 1009790ad8 | |||
| 410e6a4321 | |||
| 23c55c5b66 | |||
| 1290b077f2 | |||
| 9f9ab881b2 | |||
| f2bb22ff84 | |||
| 03f3f7899c | |||
| 771170807b | |||
| ffa90d46e6 | |||
| 0e083942cc | |||
| ce1fcff03e | |||
| a238a9a100 | |||
| fe69a2bbbd | |||
| 0be0de4ffa | |||
| 7406d2e665 | |||
| 303c9cf048 | |||
| d7d4bb7c51 | |||
| 0b1c462979 | |||
| 4a6cf0a93e | |||
| 4c963a68d7 | |||
| b20deec3d1 | |||
| 51d0d8ee67 | |||
| 70592c6819 | |||
| 259cb945f5 | |||
| e20c9bf288 | |||
| 99c8640b5d | |||
| 96b0e7aaa6 | |||
| 850ba8c96d | |||
| 1bcd736f91 | |||
| df64c0c464 | |||
| 1891239a1d | |||
| cf280ca1e8 | |||
| efc277cac7 | |||
| 4f7f43253d | |||
| 779296a3fc | |||
| 8f06a1308f | |||
| 240c13394e | |||
| 150682ba7f | |||
| ca7360e996 | |||
| 0bf604320f | |||
| 9875e70da8 | |||
| 69a4bfe8bb | |||
| 62a263b8d4 | |||
| 0da1f911dc | |||
| 8700d68fef | |||
| ab82456c16 | |||
| b23f4687fd | |||
| 2705937080 | |||
| c1eda348be | |||
| ba93d5636e | |||
| 722b2b86c9 | |||
| e1e8491b31 | |||
| 767199fd9b | |||
| 602ace5eb4 | |||
| 47804ce467 | |||
| e8cb34dd52 | |||
| e9d8973427 | |||
| 61d9a5180e | |||
| 8a8329b51f | |||
| 6b80c94901 | |||
| 8951df03de | |||
| 8139f33fa5 | |||
| a88587348b | |||
| 633a3b7f67 | |||
| fa0db212e7 | |||
| 15ff1cd28b | |||
| c73f5080de | |||
| 22ae059d32 | |||
| 1b121d636e | |||
| 1ba808dd97 | |||
| b2f5c25b27 | |||
| a1114beed2 | |||
| 4888ed440e | |||
| 5d62b63a76 | |||
| 57ba575242 | |||
| ceb11a584d | |||
| 33adb276fe | |||
| e939651972 | |||
| 3255e7872b | |||
| c4f6619330 | |||
| f18041cca8 | |||
| 35e51893bd | |||
| 1f43d17ce6 | |||
| 032bed95cd | |||
| d14cbb4476 | |||
| f510d0dbc0 | |||
| beb6b62e8c | |||
| 4740ce7787 | |||
| ad67170c8b | |||
| fdab48a7c1 | |||
| a0948d4d23 | |||
| 0bbdd6b8db | |||
| 24520b8386 | |||
| c79dfdc655 | |||
| e595136187 | |||
| aaac8cb0f5 | |||
| 0f0b4bf029 | |||
| b8194268a6 | |||
| f02e3947f6 | |||
| 9095a9dfae | |||
| d9f94e0d7d | |||
| 23417ae50f | |||
| e4d6c56ffb | |||
| 017d2985f3 | |||
| c6a8db0b9a | |||
| de09bab4b6 | |||
| c137e222d4 | |||
| cf3a787bbc | |||
| de3da77cf7 | |||
| 543ddbf44c | |||
| e9f4999985 | |||
| 29b029648e | |||
| a25a649e70 | |||
| 69c33898fa | |||
| 1b397420f2 | |||
| fe80f03726 | |||
| e50dc40d28 | |||
| 2e22b1a61e | |||
| 616c6bdf8f | |||
| c18ddfc572 | |||
| 86ebce1766 | 
| @ -83,10 +83,6 @@ function build_cpython { | ||||
|         py_suffix=${py_ver::-1} | ||||
|         py_folder=$py_suffix | ||||
|     fi | ||||
|     # Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4 | ||||
|     if [ "$py_suffix" == "3.14.0" ]; then | ||||
|         py_suffix="3.14.0rc2" | ||||
|     fi | ||||
|     wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz | ||||
|     do_cpython_build $py_ver Python-$py_suffix | ||||
|  | ||||
|  | ||||
| @ -19,7 +19,7 @@ pip_install \ | ||||
|   transformers==4.36.2 | ||||
|  | ||||
| pip_install coloredlogs packaging | ||||
| pip_install onnxruntime==1.23.0 | ||||
| pip_install onnxruntime==1.23.1 | ||||
| pip_install onnxscript==0.5.4 | ||||
|  | ||||
| # Cache the transformers model to be used later by ONNX tests. We need to run the transformers | ||||
|  | ||||
| @ -334,12 +334,12 @@ sympy==1.13.3 | ||||
| #Pinned versions: | ||||
| #test that import: | ||||
|  | ||||
| onnx==1.18.0 | ||||
| onnx==1.19.1 | ||||
| #Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal | ||||
| #Pinned versions: | ||||
| #test that import: | ||||
|  | ||||
| onnxscript==0.5.3 | ||||
| onnxscript==0.5.4 | ||||
| #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal | ||||
| #Pinned versions: | ||||
| #test that import: | ||||
|  | ||||
| @ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules | ||||
|         logger.info("Successfully cloned %s", target) | ||||
|         return r, commit | ||||
|  | ||||
|     except GitCommandError as e: | ||||
|         logger.error("Git operation failed: %s", e) | ||||
|     except GitCommandError: | ||||
|         logger.exception("Git operation failed") | ||||
|         raise | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -6,7 +6,7 @@ dependencies = [ | ||||
|     "GitPython==3.1.45", | ||||
|     "docker==7.1.0", | ||||
|     "pytest==7.3.2", | ||||
|     "uv==0.8.6" | ||||
|     "uv==0.9.5" | ||||
| ] | ||||
|  | ||||
| [tool.setuptools] | ||||
|  | ||||
| @ -163,8 +163,13 @@ if [[ "$(uname)" != Darwin ]]; then | ||||
|   MEMORY_LIMIT_MAX_JOBS=12 | ||||
|   NUM_CPUS=$(( $(nproc) - 2 )) | ||||
|  | ||||
|   # Defaults here for **binary** linux builds so they can be changed in one place | ||||
|   export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))} | ||||
|   if [[ "$(uname)" == Linux ]]; then | ||||
|     # Defaults here for **binary** linux builds so they can be changed in one place | ||||
|     export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))} | ||||
|   else | ||||
|     # For other builds | ||||
|     export MAX_JOBS=${NUM_CPUS} | ||||
|   fi | ||||
|  | ||||
|   cat >>"$envfile" <<EOL | ||||
|   export MAX_JOBS="${MAX_JOBS}" | ||||
|  | ||||
							
								
								
									
										4
									
								
								.flake8
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								.flake8
									
									
									
									
									
								
							| @ -7,14 +7,12 @@ max-line-length = 120 | ||||
| # C408 ignored because we like the dict keyword argument syntax | ||||
| # E501 is not flexible enough, we're using B950 instead | ||||
| ignore = | ||||
|     E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824, | ||||
|     E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824, | ||||
|     # shebang has extra meaning in fbcode lints, so I think it's not worth trying | ||||
|     # to line this up with executable bit | ||||
|     EXE001, | ||||
|     # these ignores are from flake8-bugbear; please fix! | ||||
|     B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910 | ||||
|     # these ignores are from flake8-logging-format; please fix! | ||||
|     G100,G101,G200 | ||||
|     # these ignores are from flake8-simplify. please fix or ignore with commented reason | ||||
|     SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, | ||||
|     # SIM104 is already covered by pyupgrade ruff | ||||
|  | ||||
							
								
								
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -124,3 +124,10 @@ runs: | ||||
|       id: login-ecr | ||||
|       continue-on-error: true | ||||
|       uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1 | ||||
|  | ||||
|     - name: Preserve github env variables for use in docker | ||||
|       shell: bash | ||||
|       run: | | ||||
|         env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||
|         env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||
|         env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/audio.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/audio.txt
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | ||||
| 1b013f5b5a87a1882eb143c26d79d091150d6a37 | ||||
| 69bbe7363897764f9e758d851cd0340147d27f94 | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | ||||
| faffd5cf673615583da6517275e361cb3dbc77e6 | ||||
| 1752fe6809b74921644866275ab80244b96e80bc | ||||
|  | ||||
							
								
								
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							| @ -133,3 +133,32 @@ | ||||
|  | ||||
| "ciflow/vllm": | ||||
| - .github/ci_commit_pins/vllm.txt | ||||
|  | ||||
| "ciflow/b200": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/**/*cublas* | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
| "ciflow/h100": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/**/*cublas* | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
| "ciflow/rocm": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							| @ -33,6 +33,7 @@ ciflow_push_tags: | ||||
| - ciflow/rocm | ||||
| - ciflow/rocm-mi300 | ||||
| - ciflow/rocm-mi355 | ||||
| - ciflow/rocm-navi31 | ||||
| - ciflow/s390 | ||||
| - ciflow/slow | ||||
| - ciflow/torchbench | ||||
|  | ||||
							
								
								
									
										30
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										30
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							| @ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { | ||||
|         "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'" | ||||
|     ), | ||||
|     "12.9": ( | ||||
|         "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'" | ||||
|         "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | " | ||||
|         "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | " | ||||
|         "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | " | ||||
|         "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | " | ||||
|         "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | " | ||||
|         "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | " | ||||
|         "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | " | ||||
|         "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | " | ||||
|         "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | " | ||||
|         "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | " | ||||
|         "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | " | ||||
|         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" | ||||
|     ), | ||||
|     "13.0": ( | ||||
|         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | " | ||||
|  | ||||
| @ -26,9 +26,8 @@ name: !{{ build_environment }} | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}" | ||||
|           python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}" | ||||
|           freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }} | ||||
| {%- endmacro %} | ||||
|  | ||||
|  | ||||
| @ -79,9 +79,9 @@ jobs: | ||||
|     runs-on: "windows-11-arm64-preview" | ||||
|     {%- else %} | ||||
|     {%- if branches == "nightly" %} | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     {%- else %} | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral" | ||||
|     {%- endif %} | ||||
|     {%- endif %} | ||||
|     timeout-minutes: !{{ common.timeout_minutes_windows_binary }} | ||||
|  | ||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -224,7 +224,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_10-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -473,7 +473,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_11-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -722,7 +722,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_12-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -971,7 +971,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_13-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1220,7 +1220,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_13t-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1469,7 +1469,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_14-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1718,7 +1718,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_14t-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|  | ||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -259,7 +259,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_10-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_10-cuda12_9-test:  # Testing | ||||
| @ -925,7 +925,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_11-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_11-cuda12_9-test:  # Testing | ||||
| @ -1591,7 +1591,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_12-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_12-cuda12_9-test:  # Testing | ||||
| @ -2257,7 +2257,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_13-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_13-cuda12_9-test:  # Testing | ||||
| @ -2923,7 +2923,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_13t-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_13t-cuda12_9-test:  # Testing | ||||
| @ -3589,7 +3589,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_14-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_14-cuda12_9-test:  # Testing | ||||
| @ -4255,7 +4255,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_14t-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_14t-cuda12_9-test:  # Testing | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -63,7 +63,6 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.10.4" | ||||
|           freethreaded: false | ||||
|  | ||||
							
								
								
									
										11
									
								
								.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -59,7 +59,6 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.10.4" | ||||
|           freethreaded: false | ||||
| @ -169,7 +168,6 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.11.4" | ||||
|           freethreaded: false | ||||
| @ -279,7 +277,6 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.12.4" | ||||
|           freethreaded: false | ||||
| @ -389,7 +386,6 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.13.4" | ||||
|           freethreaded: false | ||||
| @ -499,7 +495,6 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.13.4" | ||||
|           freethreaded: true | ||||
| @ -609,9 +604,8 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.14.0-rc.2" | ||||
|           python-version: "3.14.0" | ||||
|           freethreaded: false | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@v4 | ||||
| @ -719,9 +713,8 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.14.0-rc.2" | ||||
|           python-version: "3.14.0" | ||||
|           freethreaded: true | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@v4 | ||||
|  | ||||
							
								
								
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | ||||
|   libtorch-cpu-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -291,7 +291,7 @@ jobs: | ||||
|   libtorch-cuda12_6-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -541,7 +541,7 @@ jobs: | ||||
|   libtorch-cuda12_8-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -791,7 +791,7 @@ jobs: | ||||
|   libtorch-cuda13_0-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|  | ||||
							
								
								
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | ||||
|   libtorch-cpu-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -291,7 +291,7 @@ jobs: | ||||
|   libtorch-cuda12_6-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -541,7 +541,7 @@ jobs: | ||||
|   libtorch-cuda12_8-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -791,7 +791,7 @@ jobs: | ||||
|   libtorch-cuda13_0-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|  | ||||
							
								
								
									
										70
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										70
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | ||||
|   wheel-py3_10-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -279,7 +279,7 @@ jobs: | ||||
|   wheel-py3_10-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -517,7 +517,7 @@ jobs: | ||||
|   wheel-py3_10-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -755,7 +755,7 @@ jobs: | ||||
|   wheel-py3_10-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -993,7 +993,7 @@ jobs: | ||||
|   wheel-py3_10-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -1229,7 +1229,7 @@ jobs: | ||||
|   wheel-py3_11-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -1464,7 +1464,7 @@ jobs: | ||||
|   wheel-py3_11-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -1702,7 +1702,7 @@ jobs: | ||||
|   wheel-py3_11-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -1940,7 +1940,7 @@ jobs: | ||||
|   wheel-py3_11-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -2178,7 +2178,7 @@ jobs: | ||||
|   wheel-py3_11-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -2414,7 +2414,7 @@ jobs: | ||||
|   wheel-py3_12-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -2649,7 +2649,7 @@ jobs: | ||||
|   wheel-py3_12-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -2887,7 +2887,7 @@ jobs: | ||||
|   wheel-py3_12-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -3125,7 +3125,7 @@ jobs: | ||||
|   wheel-py3_12-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -3363,7 +3363,7 @@ jobs: | ||||
|   wheel-py3_12-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -3599,7 +3599,7 @@ jobs: | ||||
|   wheel-py3_13-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -3834,7 +3834,7 @@ jobs: | ||||
|   wheel-py3_13-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -4072,7 +4072,7 @@ jobs: | ||||
|   wheel-py3_13-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -4310,7 +4310,7 @@ jobs: | ||||
|   wheel-py3_13-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -4548,7 +4548,7 @@ jobs: | ||||
|   wheel-py3_13-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -4784,7 +4784,7 @@ jobs: | ||||
|   wheel-py3_13t-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5019,7 +5019,7 @@ jobs: | ||||
|   wheel-py3_13t-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5257,7 +5257,7 @@ jobs: | ||||
|   wheel-py3_13t-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5495,7 +5495,7 @@ jobs: | ||||
|   wheel-py3_13t-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5733,7 +5733,7 @@ jobs: | ||||
|   wheel-py3_13t-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5969,7 +5969,7 @@ jobs: | ||||
|   wheel-py3_14-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -6204,7 +6204,7 @@ jobs: | ||||
|   wheel-py3_14-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -6442,7 +6442,7 @@ jobs: | ||||
|   wheel-py3_14-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -6680,7 +6680,7 @@ jobs: | ||||
|   wheel-py3_14-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -6918,7 +6918,7 @@ jobs: | ||||
|   wheel-py3_14-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -7154,7 +7154,7 @@ jobs: | ||||
|   wheel-py3_14t-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -7389,7 +7389,7 @@ jobs: | ||||
|   wheel-py3_14t-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -7627,7 +7627,7 @@ jobs: | ||||
|   wheel-py3_14t-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -7865,7 +7865,7 @@ jobs: | ||||
|   wheel-py3_14t-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -8103,7 +8103,7 @@ jobs: | ||||
|   wheel-py3_14t-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							| @ -88,7 +88,6 @@ jobs: | ||||
|     with: | ||||
|       build-environment: linux-jammy-rocm-py3_10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||
|  | ||||
							
								
								
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							| @ -147,15 +147,16 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 | ||||
|       cuda-arch-list: 8.9 | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|  | ||||
							
								
								
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							| @ -347,7 +347,8 @@ jobs: | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       sync-tag: linux-xpu-n-build | ||||
|       # This should sync with the build in xpu.yml but xpu uses a larger runner | ||||
|       # sync-tag: linux-xpu-n-build | ||||
|       runner_prefix: ${{ needs.get-label-type.outputs.label-type }} | ||||
|       build-environment: linux-jammy-xpu-n-py3.10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3 | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							| @ -45,7 +45,6 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-noble-rocm-py3.12-mi300 | ||||
|       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							| @ -42,7 +42,6 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-noble-rocm-py3.12-mi355 | ||||
|       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|  | ||||
							
								
								
									
										75
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,75 @@ | ||||
| name: rocm-navi31 | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - ciflow/rocm-navi31/* | ||||
|   workflow_dispatch: | ||||
|   schedule: | ||||
|     # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs. | ||||
|     # Also run less frequently on weekends. | ||||
|     - cron: 45 */2 * * 1-5 | ||||
|     - cron: 45 4,12 * * 0,6 | ||||
|  | ||||
| concurrency: | ||||
|   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} | ||||
|   cancel-in-progress: true | ||||
|  | ||||
| permissions: read-all | ||||
|  | ||||
| jobs: | ||||
|   target-determination: | ||||
|     if: github.repository_owner == 'pytorch' | ||||
|     name: before-test | ||||
|     uses: ./.github/workflows/target_determination.yml | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|  | ||||
|   get-label-type: | ||||
|     name: get-label-type | ||||
|     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main | ||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||
|     with: | ||||
|       triggering_actor: ${{ github.triggering_actor }} | ||||
|       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} | ||||
|       curr_branch: ${{ github.head_ref || github.ref_name }} | ||||
|       curr_ref_type: ${{ github.ref_type }} | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-build: | ||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||
|     name: linux-jammy-rocm-py3.10 | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||
|           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-test: | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|     name: linux-jammy-rocm-py3_10 | ||||
|     uses: ./.github/workflows/_rocm-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-rocm-py3_10-build | ||||
|       - target-determination | ||||
|     with: | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||
|       tests-to-include: >- | ||||
|          ${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs | ||||
|          test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark | ||||
|          inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor | ||||
|          inductor/test_torchinductor inductor/test_decompose_mem_bound_mm | ||||
|          inductor/test_flex_attention inductor/test_max_autotune' || '' }} | ||||
|     secrets: inherit | ||||
							
								
								
									
										38
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										38
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							| @ -26,11 +26,23 @@ jobs: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|  | ||||
|   get-label-type: | ||||
|     name: get-label-type | ||||
|     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main | ||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||
|     with: | ||||
|       triggering_actor: ${{ github.triggering_actor }} | ||||
|       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} | ||||
|       curr_branch: ${{ github.head_ref || github.ref_name }} | ||||
|       curr_ref_type: ${{ github.ref_type }} | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-build: | ||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||
|     name: linux-jammy-rocm-py3.10 | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
| @ -59,29 +71,3 @@ jobs: | ||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-gfx1100-test: | ||||
|     if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|     name: linux-jammy-rocm-py3_10-gfx1100 | ||||
|     uses: ./.github/workflows/_rocm-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-rocm-py3_10-build | ||||
|       - target-determination | ||||
|     with: | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||
|           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||
|         ]} | ||||
|       tests-to-include: > | ||||
|          test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs | ||||
|          test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark | ||||
|          inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor | ||||
|          inductor/test_torchinductor inductor/test_decompose_mem_bound_mm | ||||
|          inductor/test_flex_attention inductor/test_max_autotune | ||||
|     secrets: inherit | ||||
|  | ||||
							
								
								
									
										147
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										147
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							| @ -58,8 +58,10 @@ jobs: | ||||
|           else | ||||
|             COMMIT_SHA="${{ github.sha }}" | ||||
|           fi | ||||
|           echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}" | ||||
|           echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}" | ||||
|           { | ||||
|             echo "sha=${COMMIT_SHA}" | ||||
|             echo "tag_name=trunk/${COMMIT_SHA}" | ||||
|           } >> "${GITHUB_OUTPUT}" | ||||
|  | ||||
|       - name: Validate commit SHA | ||||
|         run: | | ||||
| @ -87,7 +89,7 @@ jobs: | ||||
|             echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)" | ||||
|           fi | ||||
|  | ||||
|       - name: Create and push tag with retry | ||||
|       - name: Create and push tag(s) with retry | ||||
|         id: check_tag | ||||
|         env: | ||||
|           TAG_NAME: ${{ steps.commit.outputs.tag_name }} | ||||
| @ -112,14 +114,23 @@ jobs: | ||||
|             return 1 | ||||
|           } | ||||
|  | ||||
|           # Exit early if tag already exists | ||||
|           if check_tag_exists; then | ||||
|             echo "✅ Tag already exists - no action needed" | ||||
|             echo "exists=true" >> "${GITHUB_OUTPUT}" | ||||
|             exit 0 | ||||
|           fi | ||||
|           # Counters for summary reporting | ||||
|           created_count=0 | ||||
|           skipped_count=0 | ||||
|           failed_count=0 | ||||
|  | ||||
|           echo "Tag ${TAG_NAME} does not exist, proceeding with creation" | ||||
|           # Always write outputs once on exit | ||||
|           finish() { | ||||
|             set +e | ||||
|             if [ -n "${GITHUB_OUTPUT:-}" ]; then | ||||
|               { | ||||
|                 echo "created_count=${created_count}" | ||||
|                 echo "skipped_count=${skipped_count}" | ||||
|                 echo "failed_count=${failed_count}" | ||||
|               } >> "${GITHUB_OUTPUT}" | ||||
|             fi | ||||
|           } | ||||
|           trap finish EXIT | ||||
|  | ||||
|           # Retry configuration | ||||
|           MAX_RETRIES=5 | ||||
| @ -194,31 +205,111 @@ jobs: | ||||
|             } | ||||
|           } | ||||
|  | ||||
|           # Execute with retry | ||||
|           if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then | ||||
|             echo "exists=false" >> "${GITHUB_OUTPUT}" | ||||
|           # New behavior for push events: enumerate commits in the push and tag each one. | ||||
|           # For workflow_dispatch, retain existing single-SHA behavior. | ||||
|  | ||||
|           # Always fetch tags once up front to improve idempotency in loops | ||||
|           git fetch origin --tags --quiet || true | ||||
|  | ||||
|           if [ "${{ github.event_name }}" = "push" ]; then | ||||
|             BEFORE_SHA="${{ github.event.before }}" | ||||
|             AFTER_SHA="${{ github.sha }}"  # same as event.after | ||||
|  | ||||
|             # List commits introduced by this push (old..new), oldest first for stable ordering | ||||
|             commits_file="$(mktemp)" | ||||
|             git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}" | ||||
|  | ||||
|             if [ ! -s "${commits_file}" ]; then | ||||
|               echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag." | ||||
|               rm -f "${commits_file}" | ||||
|               exit 0 | ||||
|             fi | ||||
|  | ||||
|             commit_count="$(wc -l < "${commits_file}" | tr -d ' ')" | ||||
|             echo "Found ${commit_count} commit(s) to tag for push:" | ||||
|             while IFS= read -r sha; do | ||||
|               printf '  %s\n' "${sha}" | ||||
|             done < "${commits_file}" | ||||
|  | ||||
|             while IFS= read -r sha; do | ||||
|               TAG_NAME="trunk/${sha}" | ||||
|               COMMIT_SHA="${sha}" | ||||
|  | ||||
|               # If tag already exists locally or remotely, skip (idempotent) | ||||
|               if check_tag_exists; then | ||||
|                 echo "✅ Tag ${TAG_NAME} already exists - skipping" | ||||
|                 skipped_count=$((skipped_count + 1)) | ||||
|                 continue | ||||
|               fi | ||||
|  | ||||
|               echo "Tag ${TAG_NAME} does not exist, proceeding with creation" | ||||
|  | ||||
|               if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then | ||||
|                 created_count=$((created_count + 1)) | ||||
|               else | ||||
|                 echo "Tag creation failed after all retry attempts for ${TAG_NAME}" | ||||
|                 failed_count=$((failed_count + 1)) | ||||
|               fi | ||||
|             done < "${commits_file}" | ||||
|  | ||||
|             rm -f "${commits_file}" | ||||
|  | ||||
|             if [ "${failed_count}" -gt 0 ]; then | ||||
|               exit 1 | ||||
|             fi | ||||
|             exit 0 | ||||
|           else | ||||
|             echo "Tag creation failed after all retry attempts" | ||||
|             exit 1 | ||||
|             # workflow_dispatch path (single SHA tagging preserved) | ||||
|  | ||||
|             # Exit early if tag already exists | ||||
|             if check_tag_exists; then | ||||
|               echo "✅ Tag already exists - no action needed" | ||||
|               skipped_count=1 | ||||
|               exit 0 | ||||
|             fi | ||||
|  | ||||
|             echo "Tag ${TAG_NAME} does not exist, proceeding with creation" | ||||
|  | ||||
|             if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then | ||||
|               created_count=1 | ||||
|               exit 0 | ||||
|             else | ||||
|               echo "Tag creation failed after all retry attempts" | ||||
|               failed_count=1 | ||||
|               exit 1 | ||||
|             fi | ||||
|           fi | ||||
|  | ||||
|       - name: Tag creation summary | ||||
|         if: always() | ||||
|         run: | | ||||
|           if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then | ||||
|             echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed" | ||||
|           elif [ "${{ job.status }}" = "success" ]; then | ||||
|             echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||
|           if [ "${{ github.event_name }}" = "push" ]; then | ||||
|             echo "Trigger: push on main" | ||||
|             echo "Created: ${{ steps.check_tag.outputs.created_count }}" | ||||
|             echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}" | ||||
|             echo "Failed: ${{ steps.check_tag.outputs.failed_count }}" | ||||
|             if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then | ||||
|               echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}" | ||||
|             else | ||||
|               echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}" | ||||
|             fi | ||||
|           else | ||||
|             echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||
|           fi | ||||
|             if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then | ||||
|               if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then | ||||
|                 echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed" | ||||
|               else | ||||
|                 echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||
|               fi | ||||
|             else | ||||
|               echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||
|             fi | ||||
|  | ||||
|           echo "" | ||||
|           echo "Tag details:" | ||||
|           echo "  Name: ${{ steps.commit.outputs.tag_name }}" | ||||
|           echo "  Commit: ${{ steps.commit.outputs.sha }}" | ||||
|           echo "  Trigger: ${{ github.event_name }}" | ||||
|           if [ -n "${{ github.event.inputs.commit_sha }}" ]; then | ||||
|             echo "  Manual commit: ${{ github.event.inputs.commit_sha }}" | ||||
|             echo "" | ||||
|             echo "Tag details:" | ||||
|             echo "  Name: ${{ steps.commit.outputs.tag_name }}" | ||||
|             echo "  Commit: ${{ steps.commit.outputs.sha }}" | ||||
|             echo "  Trigger: ${{ github.event_name }}" | ||||
|             if [ -n "${{ github.event.inputs.commit_sha }}" ]; then | ||||
|               echo "  Manual commit: ${{ github.event.inputs.commit_sha }}" | ||||
|             fi | ||||
|           fi | ||||
|  | ||||
							
								
								
									
										34
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										34
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							| @ -190,6 +190,40 @@ jobs: | ||||
|       runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-build: | ||||
|     if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} | ||||
|     name: linux-jammy-rocm-py3.10 | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||
|           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-test: | ||||
|     if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|     name: linux-jammy-rocm-py3.10 | ||||
|     uses: ./.github/workflows/_rocm-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-rocm-py3_10-build | ||||
|       - target-determination | ||||
|     with: | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||
|       tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor" | ||||
|     secrets: inherit | ||||
|  | ||||
|   inductor-build: | ||||
|     name: inductor-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|  | ||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -374,6 +374,7 @@ third_party/ruy/ | ||||
| third_party/glog/ | ||||
|  | ||||
| # Virtualenv | ||||
| .venv/ | ||||
| venv/ | ||||
|  | ||||
| # Log files | ||||
|  | ||||
| @ -1138,11 +1138,8 @@ command = [ | ||||
| [[linter]] | ||||
| code = 'WORKFLOWSYNC' | ||||
| include_patterns = [ | ||||
|     '.github/workflows/pull.yml', | ||||
|     '.github/workflows/trunk.yml', | ||||
|     '.github/workflows/periodic.yml', | ||||
|     '.github/workflows/mac-mps.yml', | ||||
|     '.github/workflows/slow.yml', | ||||
|     '.github/workflows/*.yml', | ||||
|     '.github/workflows/*.yaml', | ||||
| ] | ||||
| command = [ | ||||
|     'python3', | ||||
|  | ||||
							
								
								
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							| @ -201,3 +201,17 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A | ||||
| /torch/csrc/stable/ @janeyx99 @mikaylagawarecki | ||||
| /torch/headeronly/ @janeyx99 | ||||
| /torch/header_only_apis.txt @janeyx99 | ||||
|  | ||||
| # FlexAttention | ||||
| /torch/nn/attention/flex_attention.py @drisspg | ||||
| /torch/_higher_order_ops/flex_attention.py @drisspg | ||||
| /torch/_inductor/kernel/flex/ @drisspg | ||||
| /torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg | ||||
| /test/inductor/test_flex_attention.py @drisspg | ||||
| /test/inductor/test_flex_decoding.py @drisspg | ||||
|  | ||||
| # Low Precision GEMMs | ||||
| /aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58 | ||||
| /aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58 | ||||
| /aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58 | ||||
| /test/test_scaled_matmul_cuda.py @drisspg @slayton58 | ||||
|  | ||||
| @ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI) | ||||
|  | ||||
|     set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON) | ||||
|  | ||||
|     set(fbgemm_genai_mx8mx8bf16_grouped | ||||
|     set(fbgemm_genai_cuh | ||||
|       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/" | ||||
|       "${FBGEMM_GENAI_SRCS}/" | ||||
|     ) | ||||
|  | ||||
|     target_include_directories(fbgemm_genai PRIVATE | ||||
|       ${FBGEMM_THIRD_PARTY}/cutlass/include | ||||
|       ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include | ||||
|       ${fbgemm_genai_mx8mx8bf16_grouped} | ||||
|       ${fbgemm_genai_cuh} | ||||
|       ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp | ||||
|       ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h | ||||
|     ) | ||||
| @ -313,13 +314,14 @@ IF(USE_FBGEMM_GENAI) | ||||
|  | ||||
|     # Add additional HIPCC compiler flags for performance | ||||
|     set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS | ||||
|       -mllvm | ||||
|       -amdgpu-coerce-illegal-types=1 | ||||
|       -mllvm | ||||
|       -enable-post-misched=0 | ||||
|       -mllvm | ||||
|       -greedy-reverse-local-assignment=1 | ||||
|       -fhip-new-launch-api) | ||||
|     if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0") | ||||
|         list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) | ||||
|       endif() | ||||
|  | ||||
|     # Only compile for gfx942 for now. | ||||
|     # This is rather hacky, I could not figure out a clean solution :( | ||||
|  | ||||
| @ -19,6 +19,7 @@ | ||||
| #include <ATen/detail/MPSHooksInterface.h> | ||||
| #include <ATen/detail/MTIAHooksInterface.h> | ||||
| #include <ATen/detail/PrivateUse1HooksInterface.h> | ||||
| #include <ATen/detail/XLAHooksInterface.h> | ||||
| #include <ATen/detail/XPUHooksInterface.h> | ||||
| #include <c10/core/QEngine.h> | ||||
| #include <c10/core/impl/DeviceGuardImplInterface.h> | ||||
| @ -88,6 +89,8 @@ class TORCH_API Context { | ||||
|       return at::detail::getHIPHooks(); | ||||
|     } else if (opt_device_type == at::kHPU) { | ||||
|       return at::detail::getHPUHooks(); | ||||
|     } else if (opt_device_type == at::kXLA) { | ||||
|       return at::detail::getXLAHooks(); | ||||
|     } else { | ||||
|       TORCH_CHECK( | ||||
|           false, | ||||
| @ -196,7 +199,7 @@ class TORCH_API Context { | ||||
|     return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU); | ||||
|   } | ||||
|   static bool hasXLA() { | ||||
|     return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA); | ||||
|     return detail::getXLAHooks().hasXLA(); | ||||
|   } | ||||
|   static bool hasXPU() { | ||||
|     return detail::getXPUHooks().hasXPU(); | ||||
|  | ||||
| @ -130,6 +130,12 @@ c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) { | ||||
|   impl.uncheckedSetDevice({device_type, device_index}); | ||||
|   return impl.getDevice().index(); | ||||
| } | ||||
|  | ||||
| c10::DeviceCapability getDeviceCapability(c10::DeviceIndex device_index) { | ||||
|   const auto device_type = getAccelerator(true).value(); | ||||
|   c10::impl::VirtualGuardImpl impl(device_type); | ||||
|   return impl.getDeviceCapability({device_type, device_index}); | ||||
| } | ||||
| // NOLINTEND(bugprone-unchecked-optional-access) | ||||
|  | ||||
| } // namespace at::accelerator | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <c10/core/CachingDeviceAllocator.h> | ||||
| #include <c10/core/DeviceCapability.h> | ||||
| #include <c10/core/DeviceType.h> | ||||
| #include <c10/macros/Macros.h> | ||||
|  | ||||
| @ -94,6 +95,8 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { | ||||
|   at::getDeviceAllocator(device_type)->resetPeakStats(device_index); | ||||
| } | ||||
|  | ||||
| TORCH_API c10::DeviceCapability getDeviceCapability( | ||||
|     c10::DeviceIndex device_index); | ||||
| } // namespace at::accelerator | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| @ -39,7 +39,7 @@ struct HostBlock { | ||||
| }; | ||||
|  | ||||
| template <typename B> | ||||
| struct alignas(64) FreeBlockList { | ||||
| struct alignas(hardware_destructive_interference_size) FreeBlockList { | ||||
|   std::mutex mutex_; | ||||
|   std::deque<B*> list_; | ||||
| }; | ||||
| @ -122,7 +122,7 @@ struct TORCH_API HostStats { | ||||
| // Struct containing memory allocator summary statistics for host, as they | ||||
| // are staged for reporting. This is a temporary struct that is used to | ||||
| // avoid locking the allocator while collecting stats. | ||||
| struct alignas(64) HostStatsStaged { | ||||
| struct alignas(hardware_destructive_interference_size) HostStatsStaged { | ||||
|   std::mutex timing_mutex_; | ||||
|   // COUNT: total allocations (active + free) | ||||
|   // LOCK: access to this stat is protected by the allocator's blocks_mutex_ | ||||
| @ -669,7 +669,7 @@ struct CachingHostAllocatorImpl { | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); | ||||
|   } | ||||
|  | ||||
|   alignas(64) std::mutex blocks_mutex_; | ||||
|   alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_; | ||||
|   ska::flat_hash_set<B*> blocks_; // block list | ||||
|   ska::flat_hash_map<void*, B*> ptr_to_block_; | ||||
|  | ||||
| @ -677,17 +677,17 @@ struct CachingHostAllocatorImpl { | ||||
|   // size. This allows us to quickly find a free block of the right size. | ||||
|   // We use deque to store per size free list and guard the list with its own | ||||
|   // mutex. | ||||
|   alignas(64) std::vector<FreeBlockList<B>> free_list_ = | ||||
|   alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ = | ||||
|       std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX); | ||||
|  | ||||
|   alignas(64) std::mutex events_mutex_; | ||||
|   alignas(hardware_destructive_interference_size) std::mutex events_mutex_; | ||||
|   std::deque<std::pair<E, B*>> events_; // event queue paired with block | ||||
|  | ||||
|   // Indicates whether the object is active. | ||||
|   // Set to false in the destructor to signal background threads to stop. | ||||
|   std::atomic<bool> active_{true}; | ||||
| protected: | ||||
|   alignas(64) HostStatsStaged stats_; | ||||
|   alignas(hardware_destructive_interference_size) HostStatsStaged stats_; | ||||
| }; | ||||
|  | ||||
| struct TORCH_API HostAllocator : public at::Allocator { | ||||
|  | ||||
| @ -59,9 +59,7 @@ struct TORCH_API Generator { | ||||
|  | ||||
|   explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl) | ||||
|    : impl_(std::move(gen_impl)) { | ||||
|     if (impl_.get() == nullptr) { | ||||
|       throw std::runtime_error("GeneratorImpl with nullptr is not supported"); | ||||
|     } | ||||
|     TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported"); | ||||
|   } | ||||
|  | ||||
|   bool operator==(const Generator& rhs) const { | ||||
|  | ||||
| @ -111,9 +111,7 @@ class TORCH_API TensorBase { | ||||
|   explicit TensorBase( | ||||
|       c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) | ||||
|       : impl_(std::move(tensor_impl)) { | ||||
|     if (impl_.get() == nullptr) { | ||||
|       throw std::runtime_error("TensorImpl with nullptr is not supported"); | ||||
|     } | ||||
|     TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported"); | ||||
|   } | ||||
|   TensorBase(const TensorBase&) = default; | ||||
|   TensorBase(TensorBase&&) noexcept = default; | ||||
|  | ||||
| @ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) { | ||||
|     return it->second; | ||||
|  | ||||
|   auto pos = s.find("::"); | ||||
|   if (pos == std::string::npos) { | ||||
|     std::stringstream ss; | ||||
|     ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s; | ||||
|     throw std::runtime_error(ss.str()); | ||||
|   } | ||||
|   TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s); | ||||
|   Symbol ns = _symbol("namespaces::" + s.substr(0, pos)); | ||||
|  | ||||
|   Symbol sym(sym_to_info_.size()); | ||||
| @ -121,12 +117,7 @@ std::string Symbol::domainString() const { | ||||
| } | ||||
|  | ||||
| Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) { | ||||
|   if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) { | ||||
|     std::ostringstream ss; | ||||
|     ss << "Symbol: domain string is expected to be prefixed with '" | ||||
|        << domain_prefix() << "', e.g. 'org.pytorch.aten'"; | ||||
|     throw std::runtime_error(ss.str()); | ||||
|   } | ||||
|   TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'"); | ||||
|   std::string qualString = d.substr(domain_prefix().size()) + "::" + s; | ||||
|   return fromQualString(qualString); | ||||
| } | ||||
|  | ||||
| @ -7,6 +7,7 @@ | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <ATen/core/stack.h> | ||||
| #include <ATen/core/type_factory.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/StringUtil.h> | ||||
| #include <c10/util/hash.h> | ||||
| #include <c10/util/irange.h> | ||||
| @ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) { | ||||
|     case Tag::Enum: | ||||
|     case Tag::Stream: | ||||
|     case Tag::Uninitialized: | ||||
|       throw std::runtime_error( | ||||
|       TORCH_CHECK(false, | ||||
|           "unhashable type: '" + v.type()->repr_str() + "'"); | ||||
|   } | ||||
|   // the above switch should be exhaustive | ||||
|  | ||||
| @ -8,6 +8,7 @@ | ||||
| #include <ATen/core/type_factory.h> | ||||
| #include <ATen/core/qualified_name.h> | ||||
| #include <c10/util/TypeList.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <optional> | ||||
| #include <c10/core/SymFloat.h> | ||||
| #include <c10/core/SymBool.h> | ||||
| @ -116,10 +117,8 @@ struct SingleElementType : public SharedType { | ||||
|  | ||||
|  protected: | ||||
|   SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) { | ||||
|     if (!this->elem) { | ||||
|       throw std::runtime_error(c10::str( | ||||
|     TORCH_CHECK(this->elem, c10::str( | ||||
|             "Can not create ", typeKindToString(Kind), " with None type")); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  private: | ||||
| @ -416,16 +415,12 @@ struct TORCH_API SymbolicShape { | ||||
|   } | ||||
|  | ||||
|   ShapeSymbol operator[](size_t i) const { | ||||
|     if (!dims_) { | ||||
|       throw std::runtime_error("Rank isn't fixed"); | ||||
|     } | ||||
|     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||
|     return (*dims_).at(i); | ||||
|   } | ||||
|  | ||||
|   ShapeSymbol at(size_t i) const { | ||||
|     if (!dims_) { | ||||
|       throw std::runtime_error("Rank isn't fixed"); | ||||
|     } | ||||
|     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||
|     return (*dims_).at(i); | ||||
|   } | ||||
|  | ||||
| @ -520,9 +515,7 @@ struct VaryingShape { | ||||
|   } | ||||
|  | ||||
|   const std::optional<T> &operator[](size_t i) const { | ||||
|     if (!dims_) { | ||||
|       throw std::runtime_error("Rank isn't fixed"); | ||||
|     } | ||||
|     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||
|     return (*dims_).at(i); | ||||
|   } | ||||
|  | ||||
| @ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType { | ||||
|  | ||||
|   TypePtr createWithContained( | ||||
|       std::vector<TypePtr> contained_types) const override { | ||||
|     if (contained_types.size() != 2) { | ||||
|       throw std::runtime_error("Expected 2 contained types"); | ||||
|     } | ||||
|     TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types"); | ||||
|     return create(std::move(contained_types.at(0)), std::move(contained_types.at(1))); | ||||
|   } | ||||
|  | ||||
|  | ||||
| @ -8,6 +8,7 @@ | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/env.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/flat_hash_map.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <array> | ||||
| @ -826,9 +827,7 @@ TupleType::TupleType( | ||||
|     : NamedType(TypeKind::TupleType, std::move(name)), | ||||
|       elements_(std::move(elements)), | ||||
|       has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) { | ||||
|         if (!v) { | ||||
|           throw std::runtime_error("Can not create tuple with None type"); | ||||
|         } | ||||
|         TORCH_CHECK(v, "Can not create tuple with None type"); | ||||
|         return v->hasFreeVariables(); | ||||
|       })), schema_(std::move(schema)) { | ||||
|  | ||||
|  | ||||
| @ -6,9 +6,11 @@ | ||||
| #ifdef __aarch64__ | ||||
| #if !defined(CPU_CAPABILITY_SVE) | ||||
| #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_double_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_float_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_half_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_int_aarch64.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h> | ||||
| #endif | ||||
|  | ||||
| #include <ATen/cpu/vec/vec128/vec128_convert.h> | ||||
|  | ||||
| @ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16< | ||||
|  | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) | ||||
|   Vectorized frac() const; | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) | ||||
|  | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   Vectorized<c10::BFloat16> neg() const { | ||||
|     return -values; | ||||
|   } | ||||
|   Vectorized<c10::BFloat16> reciprocal() const { | ||||
|     return 1.0f / values; | ||||
|   } | ||||
|   Vectorized<c10::BFloat16> operator==( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values == other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator!=( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values != other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator<( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values < other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator<=( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values <= other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator>( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values > other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator>=( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values >= other.values; | ||||
|   } | ||||
| #else | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) | ||||
| @ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16< | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=) | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>) | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=) | ||||
| #endif | ||||
|  | ||||
| #undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD | ||||
| #undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD | ||||
| @ -412,28 +451,52 @@ template <> | ||||
| Vectorized<c10::BFloat16> inline operator+( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   return x + y; | ||||
| #else | ||||
|   return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<c10::BFloat16> inline operator-( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   return x - y; | ||||
| #else | ||||
|   return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<c10::BFloat16> inline operator*( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   return x * y; | ||||
| #else | ||||
|   return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<c10::BFloat16> inline operator/( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   return x / y; | ||||
| #else | ||||
|   return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // frac. Implement this here so we can use subtraction | ||||
| @ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b, | ||||
|     const Vectorized<c10::BFloat16>& c) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   bfloat16x8_t z = c; | ||||
|   return x * y + z; | ||||
| #else | ||||
|   // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also, | ||||
|   // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered | ||||
|   // elements, not the bottom and top half, so they don't seem | ||||
|   // particularly useful here. Ideally we would include dot product in | ||||
|   // the Vectorized interface... | ||||
|   return a * b + c; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| @ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b, | ||||
|     const Vectorized<c10::BFloat16>& c) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   bfloat16x8_t z = c; | ||||
|   return (-x) * y + z; | ||||
| #else | ||||
|   // See NOTE [BF16 FMA] above. | ||||
|   return -a * b + c; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| @ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b, | ||||
|     const Vectorized<c10::BFloat16>& c) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   bfloat16x8_t z = c; | ||||
|   return x * y - z; | ||||
| #else | ||||
|   // See NOTE [BF16 FMA] above. | ||||
|   return a * b - c; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| @ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b, | ||||
|     const Vectorized<c10::BFloat16>& c) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   bfloat16x8_t z = c; | ||||
|   return (-x) * y - z; | ||||
| #else | ||||
|   // See NOTE [BF16 FMA] above. | ||||
|   return -a * b - c; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #endif // !defined(C10_MOBILE) && defined(__aarch64__) | ||||
|  | ||||
							
								
								
									
										586
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										586
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_double_neon.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,586 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/cpu/vec/intrinsics.h> | ||||
| #include <ATen/cpu/vec/vec_base.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <cmath> | ||||
|  | ||||
| namespace at::vec { | ||||
| // Note [CPU_CAPABILITY namespace] | ||||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| // This header, and all of its subheaders, will be compiled with | ||||
| // different architecture flags for each supported set of vector | ||||
| // intrinsics. So we need to make sure they aren't inadvertently | ||||
| // linked together. We do this by declaring objects in an `inline | ||||
| // namespace` which changes the name mangling, but can still be | ||||
| // accessed as `at::vec`. | ||||
| inline namespace CPU_CAPABILITY { | ||||
|  | ||||
| template <> | ||||
| struct is_vec_specialized_for<double> : std::bool_constant<true> {}; | ||||
|  | ||||
| template <> | ||||
| class Vectorized<double> { | ||||
|  private: | ||||
|   float64x2_t values; | ||||
|  | ||||
|  public: | ||||
|   using value_type = double; | ||||
|   using size_type = int; | ||||
|   static constexpr size_type size() { | ||||
|     return 2; | ||||
|   } | ||||
|   Vectorized() { | ||||
|     values = vdupq_n_f64(0.0); | ||||
|   } | ||||
|   Vectorized(float64x2_t v) : values(v) {} | ||||
|   Vectorized(double val) { | ||||
|     values = vdupq_n_f64(val); | ||||
|   } | ||||
|   template < | ||||
|       typename... Args, | ||||
|       typename = std::enable_if_t<(sizeof...(Args) == size())>> | ||||
|   Vectorized(Args... vals) { | ||||
|     __at_align__ double buffer[size()] = {vals...}; | ||||
|     values = vld1q_f64(buffer); | ||||
|   } | ||||
|   operator float64x2_t() const { | ||||
|     return values; | ||||
|   } | ||||
|   template <int64_t mask> | ||||
|   static Vectorized<double> blend( | ||||
|       const Vectorized<double>& a, | ||||
|       const Vectorized<double>& b) { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint64x2_t maskArray = { | ||||
|         (mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0, | ||||
|         (mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0}; | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_f64(maskArray, b.values, a.values); | ||||
|   } | ||||
|   static Vectorized<double> blendv( | ||||
|       const Vectorized<double>& a, | ||||
|       const Vectorized<double>& b, | ||||
|       const Vectorized<double>& mask_) { | ||||
|     return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values); | ||||
|   } | ||||
|   template <typename step_t> | ||||
|   static Vectorized<double> arange( | ||||
|       double base = 0., | ||||
|       step_t step = static_cast<step_t>(1)) { | ||||
|     return {base, base + static_cast<double>(step)}; | ||||
|   } | ||||
|   static inline Vectorized<double> set( | ||||
|       const Vectorized<double>& a, | ||||
|       const Vectorized<double>& b, | ||||
|       int64_t count = size()) { | ||||
|     if (count == 0) { | ||||
|       return a; | ||||
|     } else if (count >= 2) { | ||||
|       return b; | ||||
|     } else { | ||||
|       float64x2_t c = {b.values[0], a.values[1]}; | ||||
|       return c; | ||||
|     } | ||||
|   } | ||||
|   static Vectorized<double> loadu(const void* ptr, int64_t count = size()) { | ||||
|     if (count == size()) { | ||||
|       return vld1q_f64(reinterpret_cast<const double*>(ptr)); | ||||
|     } else if (count == 1) { | ||||
|       float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr)); | ||||
|       float64x1_t z = {0.0}; | ||||
|       return vcombine_f64(x, z); | ||||
|     } else { | ||||
|       return vdupq_n_f64(0.0); | ||||
|     } | ||||
|   } | ||||
|   void store(void* ptr, int64_t count = size()) const { | ||||
|     if (count == size()) { | ||||
|       vst1q_f64(reinterpret_cast<double*>(ptr), values); | ||||
|     } else if (count == 1) { | ||||
|       vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values)); | ||||
|     } | ||||
|   } | ||||
|   const double& operator[](int idx) const = delete; | ||||
|   double& operator[](int idx) = delete; | ||||
|   int64_t zero_mask() const { | ||||
|     // returns an integer mask where all zero elements are translated to 1-bit | ||||
|     // and others are translated to 0-bit | ||||
|     uint64x2_t cmpReg = vceqzq_f64(values); | ||||
|     uint64x2_t mask = {1, 2}; | ||||
|     uint64x2_t res = vandq_u64(cmpReg, mask); | ||||
|     return res[0] | res[1]; | ||||
|   } | ||||
|   Vectorized<double> isnan() const { | ||||
|     // NaN check | ||||
|     return vreinterpretq_f64_u32( | ||||
|         vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values)))); | ||||
|   } | ||||
|   bool has_inf_nan() const { | ||||
|     Vectorized<double> x = vsubq_f64(values, values); | ||||
|     float64x2_t r = x.isnan(); | ||||
|     uint64x2_t u = vreinterpretq_u64_f64(r); | ||||
|     return u[0] | u[1]; | ||||
|   } | ||||
|   Vectorized<double> map(double (*f)(double)) const { | ||||
|     float64x2_t result; | ||||
|     result[0] = f(values[0]); | ||||
|     result[1] = f(values[1]); | ||||
|     return result; | ||||
|   } | ||||
|   Vectorized<double> map2( | ||||
|       const Vectorized<double>& second, | ||||
|       double (*const f)(double, double)) const { | ||||
|     float64x2_t result; | ||||
|     result[0] = f(values[0], second.values[0]); | ||||
|     result[1] = f(values[1], second.values[1]); | ||||
|     return result; | ||||
|   } | ||||
|   Vectorized<double> abs() const { | ||||
|     return vabsq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> angle() const { | ||||
|     auto zero = Vectorized<double>(0.0); | ||||
|     auto pi = Vectorized<double>(c10::pi<double>); | ||||
|     auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values))); | ||||
|     return blendv(tmp, *this, isnan()); | ||||
|   } | ||||
|   Vectorized<double> real() const { | ||||
|     return *this; | ||||
|   } | ||||
|   Vectorized<double> imag() const { | ||||
|     return Vectorized<double>(0.0); | ||||
|   } | ||||
|   Vectorized<double> conj() const { | ||||
|     return *this; | ||||
|   } | ||||
|   Vectorized<double> acos() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos)); | ||||
|   } | ||||
|   Vectorized<double> acosh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh)); | ||||
|   } | ||||
|   Vectorized<double> asin() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin)); | ||||
|   } | ||||
|   Vectorized<double> asinh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh)); | ||||
|   } | ||||
|   Vectorized<double> atan() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan)); | ||||
|   } | ||||
|   Vectorized<double> atanh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh)); | ||||
|   } | ||||
|   Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF( | ||||
|       { return Vectorized<double>(Sleef_atan2d2_u10(values, b)); }, | ||||
|       { | ||||
|         __at_align__ double tmp[size()]; | ||||
|         __at_align__ double tmp_b[size()]; | ||||
|         store(tmp); | ||||
|         b.store(tmp_b); | ||||
|         for (int64_t i = 0; i < size(); i++) { | ||||
|           tmp[i] = std::atan2(tmp[i], tmp_b[i]); | ||||
|         } | ||||
|         return loadu(tmp); | ||||
|       })} Vectorized<double> copysign(const Vectorized<double>& sign) const { | ||||
|       USE_SLEEF( | ||||
|           { return Vectorized<double>(Sleef_copysignd2(values, sign)); }, | ||||
|           { | ||||
|             __at_align__ double tmp[size()]; | ||||
|             __at_align__ double tmp_sign[size()]; | ||||
|             store(tmp); | ||||
|             sign.store(tmp_sign); | ||||
|             for (int64_t i = 0; i < size(); i++) { | ||||
|               tmp[i] = std::copysign(tmp[i], tmp_sign[i]); | ||||
|             } | ||||
|             return loadu(tmp); | ||||
|           })} Vectorized<double> erf() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf)); | ||||
|   } | ||||
|   Vectorized<double> erfc() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc)); | ||||
|   } | ||||
|   Vectorized<double> exp() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp)); | ||||
|   } | ||||
|   Vectorized<double> exp2() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2)); | ||||
|   } | ||||
|   Vectorized<double> expm1() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1)); | ||||
|   } | ||||
|   Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF( | ||||
|       { return Vectorized<double>(Sleef_fmodd2(values, q)); }, | ||||
|       { | ||||
|         __at_align__ double tmp[size()]; | ||||
|         __at_align__ double tmp_q[size()]; | ||||
|         store(tmp); | ||||
|         q.store(tmp_q); | ||||
|         for (int64_t i = 0; i < size(); i++) { | ||||
|           tmp[i] = std::fmod(tmp[i], tmp_q[i]); | ||||
|         } | ||||
|         return loadu(tmp); | ||||
|       })} Vectorized<double> hypot(const Vectorized<double>& b) const { | ||||
|       USE_SLEEF( | ||||
|           { return Vectorized<double>(Sleef_hypotd2_u05(values, b)); }, | ||||
|           { | ||||
|             __at_align__ double tmp[size()]; | ||||
|             __at_align__ double tmp_b[size()]; | ||||
|             store(tmp); | ||||
|             b.store(tmp_b); | ||||
|             for (int64_t i = 0; i < size(); i++) { | ||||
|               tmp[i] = std::hypot(tmp[i], tmp_b[i]); | ||||
|             } | ||||
|             return loadu(tmp); | ||||
|           })} Vectorized<double> i0() const { | ||||
|     return map(calc_i0); | ||||
|   } | ||||
|   Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF( | ||||
|       { return Vectorized<double>(Sleef_nextafterd2(values, b)); }, | ||||
|       { | ||||
|         __at_align__ double tmp[size()]; | ||||
|         __at_align__ double tmp_b[size()]; | ||||
|         store(tmp); | ||||
|         b.store(tmp_b); | ||||
|         for (int64_t i = 0; i < size(); ++i) { | ||||
|           tmp[i] = std::nextafter(tmp[i], tmp_b[i]); | ||||
|         } | ||||
|         return loadu(tmp); | ||||
|       })} Vectorized<double> log() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_logd2_u10(values)), map(std::log)); | ||||
|   } | ||||
|   Vectorized<double> log2() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2)); | ||||
|   } | ||||
|   Vectorized<double> log10() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10)); | ||||
|   } | ||||
|   Vectorized<double> log1p() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p)); | ||||
|   } | ||||
|   Vectorized<double> frac() const; | ||||
|   Vectorized<double> sin() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin)); | ||||
|   } | ||||
|   Vectorized<double> sinh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh)); | ||||
|   } | ||||
|   Vectorized<double> cos() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos)); | ||||
|   } | ||||
|   Vectorized<double> cosh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh)); | ||||
|   } | ||||
|   Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF( | ||||
|       { return Vectorized<double>(Sleef_powd2_u10(values, b)); }, | ||||
|       { | ||||
|         __at_align__ double tmp[size()]; | ||||
|         __at_align__ double tmp_b[size()]; | ||||
|         store(tmp); | ||||
|         b.store(tmp_b); | ||||
|         for (int64_t i = 0; i < size(); i++) { | ||||
|           tmp[i] = std::pow(tmp[i], tmp_b[i]); | ||||
|         } | ||||
|         return loadu(tmp); | ||||
|       })} // Comparison using the _CMP_**_OQ predicate. | ||||
|           //   `O`: get false if an operand is NaN | ||||
|           //   `Q`: do not raise if an operand is NaN | ||||
|   Vectorized<double> tan() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan)); | ||||
|   } | ||||
|   Vectorized<double> tanh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh)); | ||||
|   } | ||||
|   Vectorized<double> lgamma() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma)); | ||||
|   } | ||||
|   Vectorized<double> erfinv() const { | ||||
|     return map(calc_erfinv); | ||||
|   } | ||||
|   Vectorized<double> exp_u20() const { | ||||
|     return exp(); | ||||
|   } | ||||
|   Vectorized<double> fexp_u20() const { | ||||
|     return exp(); | ||||
|   } | ||||
|   Vectorized<double> i0e() const { | ||||
|     return map(calc_i0e); | ||||
|   } | ||||
|   Vectorized<double> digamma() const { | ||||
|     return map(calc_digamma); | ||||
|   } | ||||
|   Vectorized<double> igamma(const Vectorized<double>& x) const { | ||||
|     __at_align__ double tmp[size()]; | ||||
|     __at_align__ double tmp_x[size()]; | ||||
|     store(tmp); | ||||
|     x.store(tmp_x); | ||||
|     for (int64_t i = 0; i < size(); i++) { | ||||
|       tmp[i] = calc_igamma(tmp[i], tmp_x[i]); | ||||
|     } | ||||
|     return loadu(tmp); | ||||
|   } | ||||
|   Vectorized<double> igammac(const Vectorized<double>& x) const { | ||||
|     __at_align__ double tmp[size()]; | ||||
|     __at_align__ double tmp_x[size()]; | ||||
|     store(tmp); | ||||
|     x.store(tmp_x); | ||||
|     for (int64_t i = 0; i < size(); i++) { | ||||
|       tmp[i] = calc_igammac(tmp[i], tmp_x[i]); | ||||
|     } | ||||
|     return loadu(tmp); | ||||
|   } | ||||
|   Vectorized<double> ceil() const { | ||||
|     return vrndpq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> floor() const { | ||||
|     return vrndmq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> neg() const { | ||||
|     return vnegq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> round() const { | ||||
|     return vrndiq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> trunc() const { | ||||
|     return vrndq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> sqrt() const { | ||||
|     return vsqrtq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> reciprocal() const { | ||||
|     return vdivq_f64(vdupq_n_f64(1.0), values); | ||||
|   } | ||||
|   Vectorized<double> rsqrt() const { | ||||
|     return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values)); | ||||
|   } | ||||
|   double reduce_add() const { | ||||
|     return vaddvq_f64(values); | ||||
|   } | ||||
|   double reduce_max() const { | ||||
|     return vmaxvq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> operator==(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vceqq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator!=(const Vectorized<double>& other) const { | ||||
|     float64x2_t r0 = vreinterpretq_f64_u32( | ||||
|         vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values)))); | ||||
|     return Vectorized<double>(r0); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator<(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vcltq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator<=(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vcleq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator>(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vcgtq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator>=(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vcgeq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> eq(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> ne(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> gt(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> ge(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> lt(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> le(const Vectorized<double>& other) const; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator+( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vaddq_f64(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator-( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vsubq_f64(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator*( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vmulq_f64(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator/( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vdivq_f64(a, b); | ||||
| } | ||||
|  | ||||
| // frac. Implement this here so we can use subtraction | ||||
| Vectorized<double> inline Vectorized<double>::frac() const { | ||||
|   return *this - this->trunc(); | ||||
| } | ||||
|  | ||||
| // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if | ||||
| // either input is a NaN. | ||||
| template <> | ||||
| Vectorized<double> inline maximum( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vmaxq_f64(a, b); | ||||
| } | ||||
|  | ||||
| // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if | ||||
| // either input is a NaN. | ||||
| template <> | ||||
| Vectorized<double> inline minimum( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vminq_f64(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline clamp( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& min, | ||||
|     const Vectorized<double>& max) { | ||||
|   return vminq_f64(max, vmaxq_f64(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline clamp_max( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& max) { | ||||
|   return vminq_f64(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline clamp_min( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& min) { | ||||
|   return vmaxq_f64(min, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator&( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vreinterpretq_f64_u64( | ||||
|       vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator|( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vreinterpretq_f64_u64( | ||||
|       vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator^( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vreinterpretq_f64_u64( | ||||
|       veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::eq( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this == other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::ne( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this != other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::gt( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this > other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::ge( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this >= other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::lt( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this < other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::le( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this <= other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline fmadd( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b, | ||||
|     const Vectorized<double>& c) { | ||||
|   return vfmaq_f64(c, a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline fnmadd( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b, | ||||
|     const Vectorized<double>& c) { | ||||
|   return vfmsq_f64(c, a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline fmsub( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b, | ||||
|     const Vectorized<double>& c) { | ||||
|   return vfmaq_f64(vnegq_f64(c), a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline fnmsub( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b, | ||||
|     const Vectorized<double>& c) { | ||||
|   return vfmsq_f64(vnegq_f64(c), a, b); | ||||
| } | ||||
|  | ||||
| } // namespace CPU_CAPABILITY | ||||
| } // namespace at::vec | ||||
							
								
								
									
										378
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_uint_aarch64.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										378
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_uint_aarch64.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,378 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/cpu/vec/intrinsics.h> | ||||
| #include <ATen/cpu/vec/vec_base.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/irange.h> | ||||
|  | ||||
| namespace at::vec { | ||||
| // Note [CPU_CAPABILITY namespace] | ||||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| // This header, and all of its subheaders, will be compiled with | ||||
| // different architecture flags for each supported set of vector | ||||
| // intrinsics. So we need to make sure they aren't inadvertently | ||||
| // linked together. We do this by declaring objects in an `inline | ||||
| // namespace` which changes the name mangling, but can still be | ||||
| // accessed as `at::vec`. | ||||
| inline namespace CPU_CAPABILITY { | ||||
|  | ||||
| #define VEC_UINT_NEON_TEMPLATE(vl, bit)                                       \ | ||||
|   template <>                                                                 \ | ||||
|   struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \ | ||||
|                                                                               \ | ||||
|   template <>                                                                 \ | ||||
|   class Vectorized<uint##bit##_t> {                                           \ | ||||
|     using neon_type = uint##bit##x##vl##_t;                                   \ | ||||
|                                                                               \ | ||||
|    private:                                                                   \ | ||||
|     neon_type values;                                                         \ | ||||
|                                                                               \ | ||||
|    public:                                                                    \ | ||||
|     using value_type = uint##bit##_t;                                         \ | ||||
|     using size_type = int;                                                    \ | ||||
|     static constexpr size_type size() {                                       \ | ||||
|       return vl;                                                              \ | ||||
|     }                                                                         \ | ||||
|     Vectorized() {                                                            \ | ||||
|       values = vdupq_n_u##bit(0);                                             \ | ||||
|     }                                                                         \ | ||||
|     Vectorized(neon_type v) : values(v) {}                                    \ | ||||
|     Vectorized(uint##bit##_t val);                                            \ | ||||
|     template <                                                                \ | ||||
|         typename... Args,                                                     \ | ||||
|         typename = std::enable_if_t<(sizeof...(Args) == size())>>             \ | ||||
|     Vectorized(Args... vals) {                                                \ | ||||
|       __at_align__ uint##bit##_t buffer[size()] = {vals...};                  \ | ||||
|       values = vld1q_u##bit(buffer);                                          \ | ||||
|     }                                                                         \ | ||||
|     operator neon_type() const {                                              \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     static Vectorized<uint##bit##_t> loadu(                                   \ | ||||
|         const void* ptr,                                                      \ | ||||
|         uint64_t count = size());                                             \ | ||||
|     void store(void* ptr, uint64_t count = size()) const;                     \ | ||||
|     template <uint64_t mask>                                                  \ | ||||
|     static Vectorized<uint##bit##_t> blend(                                   \ | ||||
|         const Vectorized<uint##bit##_t>& a,                                   \ | ||||
|         const Vectorized<uint##bit##_t>& b);                                  \ | ||||
|     static Vectorized<uint##bit##_t> blendv(                                  \ | ||||
|         const Vectorized<uint##bit##_t>& a,                                   \ | ||||
|         const Vectorized<uint##bit##_t>& b,                                   \ | ||||
|         const Vectorized<uint##bit##_t>& mask_) {                             \ | ||||
|       return vbslq_u##bit(mask_.values, b, a);                                \ | ||||
|     }                                                                         \ | ||||
|     template <typename step_t>                                                \ | ||||
|     static Vectorized<uint##bit##_t> arange(                                  \ | ||||
|         value_type base = 0,                                                  \ | ||||
|         step_t step = static_cast<step_t>(1));                                \ | ||||
|     static Vectorized<uint##bit##_t> set(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& a,                                   \ | ||||
|         const Vectorized<uint##bit##_t>& b,                                   \ | ||||
|         uint64_t count = size());                                             \ | ||||
|     const uint##bit##_t& operator[](uint idx) const = delete;                 \ | ||||
|     uint##bit##_t& operator[](uint idx) = delete;                             \ | ||||
|     Vectorized<uint##bit##_t> abs() const {                                   \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> real() const {                                  \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> imag() const {                                  \ | ||||
|       return vdupq_n_u##bit(0);                                               \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> conj() const {                                  \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> neg() const {                                   \ | ||||
|       return vreinterpretq_u##bit##_s##bit(                                   \ | ||||
|           vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values)));               \ | ||||
|     }                                                                         \ | ||||
|     uint##bit##_t reduce_add() const {                                        \ | ||||
|       return vaddvq_u##bit(values);                                           \ | ||||
|     }                                                                         \ | ||||
|     uint##bit##_t reduce_max() const;                                         \ | ||||
|     Vectorized<uint##bit##_t> operator==(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vceqq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> operator!=(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> operator<(                                      \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vcltq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> operator<=(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vcleq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> operator>(                                      \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vcgtq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> operator>=(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vcgeq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> eq(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> ne(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> gt(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> ge(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> lt(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> le(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|   };                                                                          \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator+(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return vaddq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator-(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return vsubq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator&(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return vandq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator|(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return vorrq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator^(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return veorq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this == other) & Vectorized<uint##bit##_t>(1);                   \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this != other) & Vectorized<uint##bit##_t>(1);                   \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this > other) & Vectorized<uint##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this >= other) & Vectorized<uint##bit##_t>(1);                   \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this < other) & Vectorized<uint##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this <= other) & Vectorized<uint##bit##_t>(1);                   \ | ||||
|   } | ||||
|  | ||||
| VEC_UINT_NEON_TEMPLATE(16, 8) | ||||
|  | ||||
| inline uint8_t Vectorized<uint8_t>::reduce_max() const { | ||||
|   return vmaxvq_u8(values); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline operator*( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   return vmulq_u8(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) { | ||||
|   return vmvnq_u8(a); | ||||
| } | ||||
|  | ||||
| inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=( | ||||
|     const Vectorized<uint8_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline minimum( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   return vminq_u8(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline maximum( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   return vmaxq_u8(a, b); | ||||
| } | ||||
|  | ||||
| template <uint64_t mask> | ||||
| Vectorized<uint8_t> Vectorized<uint8_t>::blend( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint8x16_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFF : 0, | ||||
|       (mask & 2LL) ? 0xFF : 0, | ||||
|       (mask & 4LL) ? 0xFF : 0, | ||||
|       (mask & 8LL) ? 0xFF : 0, | ||||
|       (mask & 16LL) ? 0xFF : 0, | ||||
|       (mask & 32LL) ? 0xFF : 0, | ||||
|       (mask & 64LL) ? 0xFF : 0, | ||||
|       (mask & 128LL) ? 0xFF : 0, | ||||
|       (mask & 256LL) ? 0xFF : 0, | ||||
|       (mask & 512LL) ? 0xFF : 0, | ||||
|       (mask & 1024LL) ? 0xFF : 0, | ||||
|       (mask & 2048LL) ? 0xFF : 0, | ||||
|       (mask & 4096LL) ? 0xFF : 0, | ||||
|       (mask & 8192LL) ? 0xFF : 0, | ||||
|       (mask & 16384LL) ? 0xFF : 0, | ||||
|       (mask & 32768LL) ? 0xFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_u8(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| #define VEC_UINT_NEON_OPS(vl, bit)                                             \ | ||||
|   inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) {            \ | ||||
|     values = vdupq_n_u##bit(val);                                              \ | ||||
|   }                                                                            \ | ||||
|   inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu(           \ | ||||
|       const void* ptr, uint64_t count) {                                       \ | ||||
|     if (count == size()) {                                                     \ | ||||
|       return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr));        \ | ||||
|     } else {                                                                   \ | ||||
|       __at_align__ uint##bit##_t tmp_values[size()];                           \ | ||||
|       for (const auto i : c10::irange(size())) {                               \ | ||||
|         tmp_values[i] = 0;                                                     \ | ||||
|       }                                                                        \ | ||||
|       std::memcpy(                                                             \ | ||||
|           tmp_values,                                                          \ | ||||
|           reinterpret_cast<const uint##bit##_t*>(ptr),                         \ | ||||
|           count * sizeof(uint##bit##_t));                                      \ | ||||
|       return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \ | ||||
|     }                                                                          \ | ||||
|   }                                                                            \ | ||||
|   inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count)      \ | ||||
|       const {                                                                  \ | ||||
|     if (count == size()) {                                                     \ | ||||
|       vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values);             \ | ||||
|     } else {                                                                   \ | ||||
|       uint##bit##_t tmp_values[size()];                                        \ | ||||
|       vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values);      \ | ||||
|       std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t));             \ | ||||
|     }                                                                          \ | ||||
|   } | ||||
|  | ||||
| VEC_UINT_NEON_OPS(16, 8) | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<uint8_t> Vectorized<uint8_t>::arange( | ||||
|     uint8_t base, | ||||
|     step_t step) { | ||||
|   const Vectorized<uint8_t> base_vec(base); | ||||
|   const Vectorized<uint8_t> step_vec(step); | ||||
|   const uint8x16_t step_sizes = { | ||||
|       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; | ||||
|   return vmlaq_u8(base_vec, step_sizes, step_vec); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline operator>>( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   uint8x16_t x = a; | ||||
|   uint8x16_t bound = vdupq_n_u8(8); | ||||
|   uint8x16_t z = vminq_u8(b, bound); | ||||
|   return x >> z; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline operator<<( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   uint8x16_t bound = vdupq_n_u8(8); | ||||
|   uint8x16_t z = vminq_u8(b, bound); | ||||
|   return vshlq_u8(a, vreinterpretq_s8_u8(z)); | ||||
| } | ||||
|  | ||||
| inline Vectorized<uint8_t> Vectorized<uint8_t>::set( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b, | ||||
|     uint64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 16) { | ||||
|     return b; | ||||
|   } else { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint8x16_t maskArray = { | ||||
|         static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0), | ||||
|         0}; | ||||
|  | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_u8(maskArray, b.values, a.values); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline operator/( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   uint8x16_t x = a; | ||||
|   uint8x16_t y = b; | ||||
|   return x / y; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline clamp( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& min, | ||||
|     const Vectorized<uint8_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline clamp_max( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline clamp_min( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| } // namespace CPU_CAPABILITY | ||||
| } // namespace at::vec | ||||
| @ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||
|  | ||||
| std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||
|     at::vec::Vectorized<uint8_t> src) { | ||||
|   auto u8x8 = vld1_u8(src.operator const uint8_t*()); | ||||
|   auto u8x8 = vget_low_u8(src); | ||||
|   auto u16x8 = vmovl_u8(u8x8); | ||||
|   auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); | ||||
|   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); | ||||
| @ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float( | ||||
|  | ||||
| Vectorized<float> inline convert_int8_half_register_to_float( | ||||
|     at::vec::Vectorized<uint8_t> src) { | ||||
|   auto u8x8 = vld1_u8(src.operator const uint8_t*()); | ||||
|   auto u8x8 = vget_low_u8(src); | ||||
|   auto u16x8 = vmovl_u8(u8x8); | ||||
|   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); | ||||
|  | ||||
|  | ||||
							
								
								
									
										192
									
								
								aten/src/ATen/cuda/CUDAGreenContext.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										192
									
								
								aten/src/ATen/cuda/CUDAGreenContext.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,192 @@ | ||||
| #include <ATen/cuda/CUDAGreenContext.h> | ||||
|  | ||||
| namespace at::cuda { | ||||
|   GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     int driver_version; | ||||
|     C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); | ||||
|     TORCH_CHECK( | ||||
|         driver_version >= 12080, "cuda driver too old to use green context!"); | ||||
|     CUcontext pctx = nullptr; | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); | ||||
|     if (C10_UNLIKELY(!pctx)) { | ||||
|       TORCH_WARN( | ||||
|           "Attempted to create a green context but" | ||||
|           " there was no primary context! Creating a primary context..."); | ||||
|  | ||||
|       cudaFree(0); | ||||
|     } | ||||
|  | ||||
|     CUdevice device; | ||||
|     device_id_ = device_id; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); | ||||
|  | ||||
|     // Get device resources | ||||
|     CUdevResource device_resource; | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( | ||||
|         device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); | ||||
|  | ||||
|     // Split resources | ||||
|     std::vector<CUdevResource> result(1); | ||||
|     auto result_data = result.data(); | ||||
|     unsigned int nb_groups = 1; | ||||
|     CUdevResource remaining; | ||||
|  | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( | ||||
|             result_data, | ||||
|             &nb_groups, | ||||
|             &device_resource, | ||||
|             &remaining, | ||||
|             0, // default flags | ||||
|             num_sms)); | ||||
|  | ||||
|     TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); | ||||
|  | ||||
|     // Generate resource descriptor | ||||
|     CUdevResourceDesc desc; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( | ||||
|             &desc, result_data, 1)); | ||||
|  | ||||
|     // Create green context | ||||
|     // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: | ||||
|     // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( | ||||
|         &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||||
|  | ||||
|     // Convert to regular context | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); | ||||
|     TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   std::unique_ptr<GreenContext> GreenContext::create( | ||||
|       uint32_t num_sms, | ||||
|       std::optional<uint32_t> device_id) { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     if (!device_id.has_value()) { | ||||
|       device_id = at::cuda::current_device(); | ||||
|     } | ||||
|     return std::make_unique<GreenContext>(device_id.value(), num_sms); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Implement move operations | ||||
|   GreenContext::GreenContext(GreenContext&& other) noexcept{ | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     device_id_ = std::exchange(other.device_id_, -1); | ||||
|     green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||
|     context_ = std::exchange(other.context_, nullptr); | ||||
|     parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{ | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     if (this != &other) { | ||||
|       // Clean up current resources | ||||
|       if (green_ctx_) { | ||||
|         CUcontext current = nullptr; | ||||
|         C10_CUDA_DRIVER_CHECK( | ||||
|             c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||
|         if (current == context_) { | ||||
|           TORCH_CHECK( | ||||
|               false, | ||||
|               "attempting to overwrite current green ctx " | ||||
|               "when it is active!"); | ||||
|         } | ||||
|         C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||
|       } | ||||
|  | ||||
|       // Take ownership of other's resources | ||||
|       device_id_ = std::exchange(other.device_id_, -1); | ||||
|       green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||
|       context_ = std::exchange(other.context_, nullptr); | ||||
|       parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||
|     } | ||||
|     return *this; | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   GreenContext::~GreenContext() noexcept{ | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Get the underlying CUDA context | ||||
|   CUcontext GreenContext::getContext() const { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     return context_; | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Get the underlying green context | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|   CUgreenCtx GreenContext::getGreenContext() const { | ||||
|     return green_ctx_; | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   // Make this context current | ||||
|   void GreenContext::setContext() { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     auto current_stream = c10::cuda::getCurrentCUDAStream(); | ||||
|     parent_stream_ = current_stream.stream(); | ||||
|  | ||||
|     at::cuda::CUDAEvent ev; | ||||
|     ev.record(current_stream); | ||||
|  | ||||
|     CUcontext current = nullptr; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||
|     if (!current) { | ||||
|       C10_CUDA_DRIVER_CHECK( | ||||
|           c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_)); | ||||
|     } else { | ||||
|       C10_CUDA_DRIVER_CHECK( | ||||
|           c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_)); | ||||
|     } | ||||
|     // currently hardcodes the new green context to use the default stream | ||||
|     // TODO(eqy): consider creating a new stream if e.g., it allows interop | ||||
|     // with CUDA Graph captures etc. | ||||
|     auto default_stream = c10::cuda::getDefaultCUDAStream(); | ||||
|     ev.block(default_stream); | ||||
|     c10::cuda::setCurrentCUDAStream(default_stream); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   void GreenContext::popContext() { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     // see above note about stream being hardcoded to the default stream | ||||
|     at::cuda::CUDAEvent ev; | ||||
|     ev.record(c10::cuda::getCurrentCUDAStream()); | ||||
|     CUcontext popped; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped)); | ||||
|     TORCH_INTERNAL_ASSERT( | ||||
|         popped == context_, "expected popped context to be the current ctx"); | ||||
|     ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_)); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
| } // namespace at::cuda | ||||
							
								
								
									
										53
									
								
								aten/src/ATen/cuda/CUDAGreenContext.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								aten/src/ATen/cuda/CUDAGreenContext.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | ||||
| #pragma once | ||||
| #include <ATen/cuda/CUDAEvent.h> | ||||
|  | ||||
| #if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) | ||||
| #include <c10/cuda/driver_api.h> | ||||
| #include <cuda.h> | ||||
| #include <memory> | ||||
| #include <stdexcept> | ||||
| #include <vector> | ||||
| #define CUDA_HAS_GREEN_CONTEXT 1 | ||||
| #else | ||||
| #define CUDA_HAS_GREEN_CONTEXT 0 | ||||
| #endif | ||||
|  | ||||
| namespace at::cuda { | ||||
|  | ||||
| class TORCH_CUDA_CPP_API GreenContext { | ||||
|  public: | ||||
|   GreenContext(uint32_t device_id, uint32_t num_sms); | ||||
|  | ||||
|   static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id); | ||||
|  | ||||
|   // Delete copy constructor and assignment | ||||
|   GreenContext(const GreenContext&) = delete; | ||||
|   GreenContext& operator=(const GreenContext&) = delete; | ||||
|  | ||||
|   // Implement move operations | ||||
|   GreenContext(GreenContext&& other) noexcept; | ||||
|   GreenContext& operator=(GreenContext&& other) noexcept; | ||||
|   ~GreenContext() noexcept; | ||||
|  | ||||
|   // Get the underlying CUDA context | ||||
|   CUcontext getContext() const; | ||||
|  | ||||
|   // Get the underlying green context | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|   CUgreenCtx getGreenContext() const; | ||||
| #endif | ||||
|  | ||||
|   // Make this context current | ||||
|   void setContext(); | ||||
|  | ||||
|   void popContext(); | ||||
|  | ||||
|  private: | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|   int32_t device_id_ = -1; | ||||
|   CUgreenCtx green_ctx_ = nullptr; | ||||
|   CUcontext context_ = nullptr; | ||||
|   cudaStream_t parent_stream_ = nullptr; | ||||
| #endif | ||||
| }; | ||||
| } // namespace at::cuda | ||||
| @ -183,11 +183,6 @@ struct CUDACachingHostAllocatorImpl | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   bool pinned_use_background_threads() override { | ||||
|     return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: | ||||
|         pinned_use_background_threads(); | ||||
|   } | ||||
|  | ||||
|   EventPool::Event create_event_internal(DeviceIndex idx) { | ||||
|     // Leak the event pool to avoid shutdown issue. | ||||
|     static auto* event_pool = new EventPool(); | ||||
|  | ||||
| @ -70,11 +70,7 @@ | ||||
| #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max() | ||||
| #endif | ||||
|  | ||||
| #if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM) | ||||
|  | ||||
| #if !defined(USE_ROCM) | ||||
| namespace at_cuda_detail { | ||||
| #endif | ||||
| #if defined(USE_ROCM) | ||||
|  | ||||
| // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 | ||||
|  | ||||
| @ -96,10 +92,6 @@ template <> | ||||
| struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>: | ||||
|        ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {}; | ||||
|  | ||||
| #if !defined(USE_ROCM) | ||||
| } // namespace at_cuda_detail | ||||
| #endif | ||||
|  | ||||
| #endif | ||||
|  | ||||
| #if !defined(USE_ROCM) | ||||
| @ -121,7 +113,7 @@ struct cuda_type<c10::Half> { | ||||
|   using type = __half; | ||||
| }; | ||||
|  | ||||
| #if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16() | ||||
| #if !defined(USE_ROCM) | ||||
|  | ||||
| template<> | ||||
| struct cuda_type<c10::BFloat16> { | ||||
| @ -177,7 +169,6 @@ inline void segmented_sort_pairs( | ||||
|   } | ||||
| } | ||||
|  | ||||
| #if CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT> | ||||
| inline void unique_by_key( | ||||
|   KeysInputIteratorT keys_in, ValuesInputIteratorT values_in, | ||||
| @ -193,7 +184,6 @@ inline void unique_by_key( | ||||
|   CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey, | ||||
|     keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream()); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| namespace impl { | ||||
|  | ||||
| @ -205,36 +195,6 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera | ||||
|   *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b)); | ||||
| } | ||||
|  | ||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() | ||||
| template<typename ValueT, typename InputIteratorT> | ||||
| struct chained_iterator { | ||||
|   using iterator_category = std::random_access_iterator_tag; | ||||
|   using difference_type   = std::ptrdiff_t; | ||||
|   using value_type        = ValueT; | ||||
|   using pointer           = ValueT*; | ||||
|   using reference         = ValueT&; | ||||
|  | ||||
|   InputIteratorT iter; | ||||
|   ValueT *first; | ||||
|   difference_type offset = 0; | ||||
|  | ||||
|   __device__ ValueT operator[](difference_type i) { | ||||
|     i +=  offset; | ||||
|     if (i == 0) { | ||||
|       return *first; | ||||
|     } else { | ||||
|       return ValueT(iter[i - 1]); | ||||
|     } | ||||
|   } | ||||
|   __device__ chained_iterator operator+(difference_type i) { | ||||
|     return chained_iterator{iter, first, i}; | ||||
|   } | ||||
|   __device__ ValueT operator*() { | ||||
|     return (*this)[0]; | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
| // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, | ||||
| // so split at int_max/2 | ||||
| constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30 | ||||
| @ -279,25 +239,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
|         first_elem_ptr, | ||||
|         scan_op); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() | ||||
|     using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>; | ||||
|     using tuple = typename ArgIndexInputIterator::value_type; | ||||
|     auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  { | ||||
|       if (x.key == 0) { | ||||
|         return *first_elem_ptr; | ||||
|       } else { | ||||
|         return x.value; | ||||
|       } | ||||
|     }; | ||||
|     auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)( | ||||
|       ArgIndexInputIterator(input + i), input_iter_transform); | ||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, | ||||
|         input_, | ||||
|         output + i, | ||||
|         scan_op, | ||||
|         size_cub, | ||||
|         at::cuda::getCurrentCUDAStream()); | ||||
| #else | ||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, | ||||
|         input + i + 1, | ||||
|         output + i, | ||||
| @ -305,7 +246,6 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
|         ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr), | ||||
|         size_cub, | ||||
|         at::cuda::getCurrentCUDAStream()); | ||||
| #endif | ||||
|   } | ||||
| #endif | ||||
| } | ||||
| @ -557,16 +497,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
|         first_elem_ptr, | ||||
|         scan_op); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() | ||||
|     auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{ | ||||
|       input + i, first_elem_ptr}; | ||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, | ||||
|         input_, | ||||
|         output + i, | ||||
|         scan_op, | ||||
|         size_cub, | ||||
|         at::cuda::getCurrentCUDAStream()); | ||||
| #else | ||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, | ||||
|         input + i, | ||||
|         output + i, | ||||
| @ -574,12 +504,10 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
|         ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr), | ||||
|         size_cub, | ||||
|         at::cuda::getCurrentCUDAStream()); | ||||
| #endif | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|  | ||||
| template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT> | ||||
| inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) { | ||||
| @ -607,7 +535,6 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #endif | ||||
|  | ||||
| template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT> | ||||
| void unique(InputIteratorT input, OutputIteratorT output, | ||||
|  | ||||
| @ -10,14 +10,6 @@ | ||||
| #define CUB_VERSION 200001 | ||||
| #endif | ||||
|  | ||||
| // cub sort support for __nv_bfloat16 is added to cub 1.13 in: | ||||
| // https://github.com/NVIDIA/cub/pull/306 | ||||
| #if CUB_VERSION >= 101300 | ||||
| #define CUB_SUPPORTS_NV_BFLOAT16() true | ||||
| #else | ||||
| #define CUB_SUPPORTS_NV_BFLOAT16() false | ||||
| #endif | ||||
|  | ||||
| // cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: | ||||
| // https://github.com/NVIDIA/cub/pull/326 | ||||
| // CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake | ||||
| @ -28,30 +20,6 @@ | ||||
| #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | ||||
| #endif | ||||
|  | ||||
| // cub support for UniqueByKey is added to cub 1.16 in: | ||||
| // https://github.com/NVIDIA/cub/pull/405 | ||||
| #if CUB_VERSION >= 101600 | ||||
| #define CUB_SUPPORTS_UNIQUE_BY_KEY() true | ||||
| #else | ||||
| #define CUB_SUPPORTS_UNIQUE_BY_KEY() false | ||||
| #endif | ||||
|  | ||||
| // cub support for scan by key is added to cub 1.15 | ||||
| // in https://github.com/NVIDIA/cub/pull/376 | ||||
| #if CUB_VERSION >= 101500 | ||||
| #define CUB_SUPPORTS_SCAN_BY_KEY() 1 | ||||
| #else | ||||
| #define CUB_SUPPORTS_SCAN_BY_KEY() 0 | ||||
| #endif | ||||
|  | ||||
| // cub support for cub::FutureValue is added to cub 1.15 in: | ||||
| // https://github.com/NVIDIA/cub/pull/305 | ||||
| #if CUB_VERSION >= 101500 | ||||
| #define CUB_SUPPORTS_FUTURE_VALUE() true | ||||
| #else | ||||
| #define CUB_SUPPORTS_FUTURE_VALUE() false | ||||
| #endif | ||||
|  | ||||
| // There were many bc-breaking changes in major version release of CCCL v3.0.0 | ||||
| // Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html | ||||
| #if CUB_VERSION >= 200800 | ||||
|  | ||||
							
								
								
									
										23
									
								
								aten/src/ATen/detail/XLAHooksInterface.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								aten/src/ATen/detail/XLAHooksInterface.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,23 @@ | ||||
| #include <ATen/detail/XLAHooksInterface.h> | ||||
|  | ||||
| namespace at { | ||||
| namespace detail { | ||||
|  | ||||
| const XLAHooksInterface& getXLAHooks() { | ||||
|   auto create_impl = [] { | ||||
|     // Create XLA hooks using the registry | ||||
|     auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{}); | ||||
|     if (hooks) { | ||||
|       return hooks; | ||||
|     } | ||||
|     // If hooks creation fails, fall back to default implementation | ||||
|     return std::make_unique<XLAHooksInterface>(); | ||||
|   }; | ||||
|   static auto hooks = create_impl(); | ||||
|   return *hooks; | ||||
| } | ||||
| } // namespace detail | ||||
|  | ||||
| C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs) | ||||
|  | ||||
| } // namespace at | ||||
							
								
								
									
										79
									
								
								aten/src/ATen/detail/XLAHooksInterface.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								aten/src/ATen/detail/XLAHooksInterface.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,79 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <c10/core/Device.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/Registry.h> | ||||
|  | ||||
| #include <ATen/detail/AcceleratorHooksInterface.h> | ||||
|  | ||||
| C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| constexpr const char* XLA_HELP = | ||||
|   "This error has occurred because you are trying " | ||||
|   "to use some XLA functionality, but the XLA library has not been " | ||||
|   "loaded by the dynamic linker. You must load xla libraries by `import torch_xla`"; | ||||
|  | ||||
| struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface { | ||||
|   ~XLAHooksInterface() override = default; | ||||
|  | ||||
|   void init() const override { | ||||
|     TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   virtual bool hasXLA() const { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   virtual std::string showConfig() const { | ||||
|     TORCH_CHECK( | ||||
|         false, | ||||
|         "Cannot query detailed XLA version without torch_xla library. ", | ||||
|         XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   const Generator& getDefaultGenerator( | ||||
|       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||
|     TORCH_CHECK( | ||||
|         false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   Generator getNewGenerator( | ||||
|       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||
|     TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   virtual DeviceIndex getCurrentDevice() const override { | ||||
|     TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   Device getDeviceFromPtr(void* /*data*/) const override { | ||||
|     TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   Allocator* getPinnedMemoryAllocator() const override { | ||||
|     TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   bool isPinnedPtr(const void* data) const override { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   bool hasPrimaryContext(DeviceIndex device_index) const override { | ||||
|     TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
| }; | ||||
|  | ||||
| struct TORCH_API XLAHooksArgs {}; | ||||
|  | ||||
| TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs); | ||||
| #define REGISTER_XLA_HOOKS(clsname) \ | ||||
|   C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname) | ||||
|  | ||||
| namespace detail { | ||||
| TORCH_API const XLAHooksInterface& getXLAHooks(); | ||||
| } // namespace detail | ||||
| } // namespace at | ||||
| C10_DIAGNOSTIC_POP() | ||||
| @ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({ | ||||
|   DispatchKey::CUDA, | ||||
|   DispatchKey::CPU, | ||||
|   DispatchKey::PrivateUse1, | ||||
|   DispatchKey::SparseCPU, | ||||
|   DispatchKey::SparseCUDA, | ||||
|   DispatchKey::SparseCsrCPU, | ||||
|   DispatchKey::SparseCsrCUDA, | ||||
| }); | ||||
|  | ||||
| inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { | ||||
|  | ||||
| @ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) | ||||
|     try { | ||||
|       mkldnn_matmul_i8i8i32(self, mat2, result); | ||||
|       dispatched = true; | ||||
|     } catch (const std::exception& e) { | ||||
|     } catch ([[maybe_unused]] const std::exception& e) { | ||||
|       TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -11,6 +11,8 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto | ||||
|               "pixel_shuffle expects a positive upscale_factor, but got ", | ||||
|               upscale_factor); | ||||
|   int64_t c = self.size(-3); | ||||
|   TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor, | ||||
|         "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor); | ||||
|   int64_t upscale_factor_squared = upscale_factor * upscale_factor; | ||||
|   TORCH_CHECK(c % upscale_factor_squared == 0, | ||||
|               "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " | ||||
|  | ||||
| @ -259,11 +259,20 @@ inline void winograd_f2k3_input_transform_inplace__rvv( | ||||
|   const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4); | ||||
|   const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4); | ||||
|   const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4); | ||||
|  | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0); | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1); | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2); | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3); | ||||
|   /* GCC 14.2 (RISC-V RVV) ICE workaround: | ||||
|    * Avoid single-statement read-modify-write on MEM_REF like: | ||||
|    *   *input_tile_val = | ||||
|    *     __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val); | ||||
|    * This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin) | ||||
|    * with -march=rv64gcv. Use a temporary then write back. | ||||
|    * Do NOT refactor into the single-statement form. Clang is unaffected. | ||||
|    */ | ||||
|   vfloat32m1x4_t tmp_input_tile_val = *input_tile_val; | ||||
|   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0); | ||||
|   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1); | ||||
|   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2); | ||||
|   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3); | ||||
|   *input_tile_val = tmp_input_tile_val; | ||||
| } | ||||
|  | ||||
| inline void winograd_f2k3_output_transform_inplace__rvv( | ||||
| @ -277,9 +286,15 @@ inline void winograd_f2k3_output_transform_inplace__rvv( | ||||
|   const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4); | ||||
|   const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4); | ||||
|   const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4); | ||||
|  | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0); | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1); | ||||
|   /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above. | ||||
|    * Keep the temporary + write-back pattern to avoid ICE. | ||||
|    * Do NOT rewrite into: | ||||
|    *   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val); | ||||
|    */ | ||||
|   vfloat32m1x4_t tmp_output_tile_val = *input_tile_val; | ||||
|   tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0); | ||||
|   tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1); | ||||
|   *input_tile_val = tmp_output_tile_val; | ||||
| } | ||||
|  | ||||
| inline vfloat32m1_t | ||||
| @ -300,11 +315,17 @@ inline void winograd_f2k3_kernel_transform__rvv( | ||||
|   const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4); | ||||
|   const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4); | ||||
|   vfloat32m1_t half_g0_plus_g2 =  __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4); | ||||
|  | ||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0); | ||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1)); | ||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1)); | ||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2); | ||||
|   /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above. | ||||
|    * Keep the temporary + write-back pattern to avoid ICE. | ||||
|    * Do NOT rewrite into: | ||||
|    *   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val); | ||||
|    */ | ||||
|   vfloat32m1x4_t tmp_transform = *transform; | ||||
|   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0); | ||||
|   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1)); | ||||
|   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1)); | ||||
|   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2); | ||||
|   *transform = tmp_transform; | ||||
| } | ||||
|  | ||||
| inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) { | ||||
|  | ||||
| @ -272,28 +272,110 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa | ||||
|   } | ||||
| } | ||||
|  | ||||
| static bool getDisableAddmmCudaLt() { | ||||
|     static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT"); | ||||
|     if (env_value == "1") { | ||||
|       return true; | ||||
|     } | ||||
|     return false; | ||||
| /* | ||||
|  * Checks whether DISABLE_ADDMM_CUDA_LT is set. | ||||
|  * Additionally, for ROCM we test whether the architecture supports the Lt. | ||||
|  */ | ||||
| static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) { | ||||
|   // When hipBLASLt is not supported on the architecture, return true | ||||
|   #ifdef USE_ROCM | ||||
|   static const std::vector<std::string> archs = { | ||||
|         "gfx90a", "gfx942", | ||||
|     #if ROCM_VERSION >= 60300 | ||||
|         "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908", | ||||
|     #endif | ||||
|     #if ROCM_VERSION >= 70000 | ||||
|         "gfx950", "gfx1150", "gfx1151" | ||||
|     #endif | ||||
|   }; | ||||
|   const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index()); | ||||
|   if (!is_hipblas_lt_arch_supported) { | ||||
|     return true; | ||||
|   } | ||||
|   #endif | ||||
|  | ||||
|   // Check whether it is disabled in the env | ||||
|   static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT"); | ||||
|   if (is_addmm_cuda_lt_disabled == "1") { | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| static bool isSupportedHipLtROCmArch(int index) { | ||||
|     static const std::vector<std::string> archs = { | ||||
|         "gfx90a", "gfx942", | ||||
| #if ROCM_VERSION >= 60300 | ||||
|         "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908", | ||||
| #endif | ||||
| #if ROCM_VERSION >= 70000 | ||||
|         "gfx950", "gfx1150", "gfx1151" | ||||
| #endif | ||||
|     }; | ||||
|     return at::detail::getCUDAHooks().isGPUArch(archs, index); | ||||
| /* | ||||
|  * Check whether for the given input we want to enable the Lt interface | ||||
|  */ | ||||
| static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) { | ||||
|   // Implies 2D bias which we currently not send through Lt. | ||||
|   // TODO: this check is done pre col-major input preparation, | ||||
|   // so, this condition can be ralexed in cases when a col-major | ||||
|   // copy of result is needed. | ||||
|   if (result.is_same(self)) { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   #if defined(USE_ROCM) && ROCM_VERSION == 60400 | ||||
|   // hipblaslt TT fp32 regression on ROCm 6.4, cannot use | ||||
|   const auto args = cublasCommonArgs(mat1, mat2, result); | ||||
|   if (args.transa == 't' && args.transb == 't') { | ||||
|     return false; | ||||
|   } | ||||
|   #endif | ||||
|  | ||||
|   const auto mat1_sizes = mat1.sizes(); | ||||
|   const auto mat2_sizes = mat2.sizes(); | ||||
|   #if defined(CUDA_VERSION) || defined(USE_ROCM) | ||||
|   const auto scalar_type = mat1.scalar_type(); | ||||
|   return (beta.toComplexDouble() == 1.0 | ||||
|     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] | ||||
|     // is to use lt interface only when self is bias. | ||||
|     && self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous() | ||||
|     && result.dim() == 2 && result.is_contiguous() | ||||
|     && ( // some dtype restrictions | ||||
|       #ifndef USE_ROCM | ||||
|       scalar_type == at::ScalarType::Double || | ||||
|       #endif | ||||
|       scalar_type == at::ScalarType::Float || | ||||
|       scalar_type == at::ScalarType::Half || | ||||
|       scalar_type == at::ScalarType::BFloat16 | ||||
|     ) | ||||
|     && ( // some shape/stride restrictions | ||||
|       // Strangely, if mat2 has only 1 row or column, we get | ||||
|       // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. | ||||
|       // NOTE: extension to mat1 because mat1/mat2 can be swapped based off | ||||
|       // their row-/col-majorness. | ||||
|       mat1_sizes[0] > 1 && mat1_sizes[1] > 1 && | ||||
|       mat2_sizes[0] > 1 && mat2_sizes[1] > 1 | ||||
|       // The last conditions is to skip 16b transA and non-trans-B having | ||||
|       // leading dim >> rows when they are sliced from a large tensor | ||||
|       // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul | ||||
|       #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM)) | ||||
|       // Related to avoiding the leading stride >> leading dim problematic case | ||||
|       // with 16b dtypes described above. For such dtypes we only allow inputs | ||||
|       // which are either row- or col-major (i.e. non-overlapping, compact memory layout). | ||||
|       // In that case the leading stride will be equal to the outer dim len. | ||||
|       // Why do we catch this case here? The following `prepare_matrix_for_cublas` method | ||||
|       // does not modify inputs as long as there is a stride of length 1 | ||||
|       // and the leading stride is at least max(1, other dim length), so we might | ||||
|       // end up with contiguous cols but not rows (i.e. holes between different rows) | ||||
|       // and vice versa. | ||||
|       mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 && | ||||
|       mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 && | ||||
|       && ( | ||||
|         // filter by dtype | ||||
|         (scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) || | ||||
|         // check mat1/mat2 is row-/col-major | ||||
|         (mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense()) | ||||
|       ) | ||||
|       #endif | ||||
|     ) | ||||
|   ); | ||||
|   #endif | ||||
|  | ||||
|   // no compliance by default | ||||
|   return false; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) { | ||||
| @ -335,7 +417,70 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename res_scalar_t = scalar_t> | ||||
| bool launchGemmAndBiasCublasLt( | ||||
|     // args contains result which is modified | ||||
|     cublasCommonArgs& args, | ||||
|     const Tensor& self, | ||||
|     const Scalar& alpha, | ||||
|     Activation activation = Activation::None | ||||
| ) { | ||||
|   const auto* self_ptr = self.const_data_ptr<scalar_t>(); | ||||
|  | ||||
|   const auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|   if (tuning_ctx->IsTunableOpEnabled()) { | ||||
|     // TODO: maybe also return some success state? | ||||
|     launchTunableGemmAndBias<scalar_t>( | ||||
|       args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation) | ||||
|     ); | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>( | ||||
|     args.transa == 't', | ||||
|     args.transb == 't', | ||||
|     args.m, | ||||
|     args.n, | ||||
|     args.k, | ||||
|     alpha.to<at::opmath_type<scalar_t>>(), | ||||
|     args.mata->const_data_ptr<scalar_t>(), | ||||
|     args.lda, | ||||
|     args.matb->const_data_ptr<scalar_t>(), | ||||
|     args.ldb, | ||||
|     self_ptr, | ||||
|     args.result->data_ptr<res_scalar_t>(), | ||||
|     args.result_ld, | ||||
|     activation_to_gemm_and_blas_arg(activation) | ||||
|   ); | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename res_scalar_t = scalar_t> | ||||
| bool launchGemmCublas( | ||||
|     // args contains result which is modified | ||||
|     cublasCommonArgs& args, | ||||
|     const Scalar& alpha, | ||||
|     const Scalar& beta | ||||
| ) { | ||||
|   at::cuda::blas::gemm<scalar_t, res_scalar_t>( | ||||
|     args.transa, | ||||
|     args.transb, | ||||
|     args.m, | ||||
|     args.n, | ||||
|     args.k, | ||||
|     alpha.to<at::opmath_type<scalar_t>>(), | ||||
|     args.mata->const_data_ptr<scalar_t>(), | ||||
|     args.lda, | ||||
|     args.matb->const_data_ptr<scalar_t>(), | ||||
|     args.ldb, | ||||
|     beta.to<at::opmath_type<scalar_t>>(), | ||||
|     args.result->data_ptr<res_scalar_t>(), | ||||
|     args.result_ld | ||||
|   ); | ||||
|   return true; // success! | ||||
| } | ||||
|  | ||||
| Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) { | ||||
|   // Shape checks { | ||||
|   // Make sure to keep addmm_cuda below in sync with this code; it | ||||
|   // preflights a check to try to avoid actually needing to call | ||||
|   // expand(). | ||||
| @ -345,105 +490,62 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma | ||||
|     "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() | ||||
|   ) | ||||
|  | ||||
|   if (result.is_same(self)) { | ||||
|     TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); | ||||
|     TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0"); | ||||
|     TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1"); | ||||
|   } | ||||
|   // } Shape checks | ||||
|  | ||||
|   // NOLINTNEXTLINE(*c-array*) | ||||
|   TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}}; | ||||
|   checkAllSameGPU(__func__, targs); | ||||
|  | ||||
|   IntArrayRef mat1_sizes = mat1.sizes(); | ||||
|   IntArrayRef mat2_sizes = mat2.sizes(); | ||||
|   IntArrayRef self__sizes; | ||||
|   bool useLtInterface = false; | ||||
| #if defined(USE_ROCM) | ||||
|   // When hipBLASLt is not supported on the architecture, | ||||
|   // disable_addmm_cuda_lt will always be to set to true | ||||
|   static bool disable_addmm_cuda_lt = | ||||
|     !isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt(); | ||||
| #else | ||||
|   static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt(); | ||||
| #endif | ||||
|   // Handle whether to use the Lt interface { | ||||
|   static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()); | ||||
|   // if lt path fails, we recurse back into this function here and force the lt path to off | ||||
|   // we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent | ||||
|   bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override; | ||||
| #if defined(USE_ROCM) && ROCM_VERSION == 60400 | ||||
|   // hipblaslt TT fp32 regression on ROCm 6.4, cannot use | ||||
|   cublasCommonArgs _args(mat1, mat2, result); | ||||
|   if (_args.transa == 't' && _args.transb == 't') { | ||||
|     disable_addmm_cuda_lt_final = true; | ||||
|   } | ||||
| #endif | ||||
|   bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override; | ||||
|   #ifdef USE_ROCM | ||||
|   // Conditioned on the device index, which is not persistent | ||||
|   disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt; | ||||
|   #endif | ||||
|   // Condition on the input | ||||
|   disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt; | ||||
|   // } | ||||
|  | ||||
|   at::ScalarType scalar_type = mat1.scalar_type(); | ||||
|   bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float; | ||||
|   c10::MaybeOwned<Tensor> self_; | ||||
|   if (&result != &self) { | ||||
| #if defined(CUDA_VERSION) || defined(USE_ROCM) | ||||
|     // Strangely, if mat2 has only 1 row or column, we get | ||||
|     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. | ||||
|     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] | ||||
|     // is to use lt interface only when self is bias. | ||||
|     // for cuda 11.4, cublasLtMatmul is activated | ||||
|     // the last two conditions is to skip 16b transA and non-trans-B having | ||||
|     // leading dim >> rows when they are sliced from a large tensor | ||||
|     // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul | ||||
|     if (!disable_addmm_cuda_lt_final) { | ||||
|       useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 && | ||||
|           result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] && | ||||
|           self.is_contiguous() && result.is_contiguous() && | ||||
| #ifdef USE_ROCM | ||||
|           (scalar_type == at::ScalarType::Float || | ||||
|            scalar_type == at::ScalarType::Half || | ||||
|            scalar_type == at::ScalarType::BFloat16) && | ||||
| #else | ||||
|           (scalar_type == at::ScalarType::Double || | ||||
|            scalar_type == at::ScalarType::Float || | ||||
|            scalar_type == at::ScalarType::Half || | ||||
|            scalar_type == at::ScalarType::BFloat16) && | ||||
| #endif | ||||
| #if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM)) | ||||
|           mat2_sizes[0] > 1 && mat2_sizes[1] > 1; | ||||
| #else | ||||
|           mat2_sizes[0] > 1 && mat2_sizes[1] > 1 && | ||||
|           mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 && | ||||
|           mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 && | ||||
|           // avoid leading dim >> rows bugs | ||||
|           ((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) || | ||||
|            (mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) || | ||||
|            (scalar_type != at::ScalarType::Half && | ||||
|             scalar_type != at::ScalarType::BFloat16)) && | ||||
|           ((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) || | ||||
|            (mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) || | ||||
|            (scalar_type != at::ScalarType::Half && | ||||
|             scalar_type != at::ScalarType::BFloat16)); | ||||
| #endif | ||||
|     } | ||||
| #endif | ||||
|     if (!useLtInterface) { | ||||
|       self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm"); | ||||
|     } | ||||
|     self__sizes = self_->sizes(); | ||||
|   } else { | ||||
|     self_ = c10::MaybeOwned<Tensor>::borrowed(self); | ||||
|     self__sizes = self_->sizes(); | ||||
|     TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); | ||||
|     TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0"); | ||||
|     TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1"); | ||||
|   } | ||||
|  | ||||
|   if (&result != &self) { | ||||
|     at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]}); | ||||
|     if (beta.toComplexDouble() != 0.0 && !useLtInterface) { | ||||
|       at::native::copy_(result, *self_); | ||||
|   // Handle result/self shapes | ||||
|   if (!result.is_same(self)) { | ||||
|     at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]}); | ||||
|  | ||||
|     const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> { | ||||
|       if (disable_addmm_cuda_lt) { | ||||
|         // When in non-Lt path we do expand self even before | ||||
|         // check for beta != 0.0 to make sure that | ||||
|         // test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_* | ||||
|         // runs green. | ||||
|         return expand_size(self, result.sizes(), "addmm"); | ||||
|       } | ||||
|       // copy next, should broadcast | ||||
|       return c10::MaybeOwned<Tensor>::borrowed(self); | ||||
|     }(); | ||||
|     // We copy bias when in the non-Lt path | ||||
|     if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) { | ||||
|       // NOTE: self should broadcast over result | ||||
|       at::native::copy_(result, *self_maybe_expanded); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  | ||||
|   IntArrayRef result_sizes = result.sizes(); | ||||
|   if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) { | ||||
|   // Short circuit on empty result | ||||
|   if (result.numel() == 0) { | ||||
|     return result; | ||||
|   } | ||||
|  | ||||
|   cublasCommonArgs args(mat1, mat2, result); | ||||
|  | ||||
|   if (mat1.numel() == 0) { | ||||
|   // Short circuit if the reduction dim is empty | ||||
|   if (mat1.sizes()[1] == 0) { | ||||
|     // By definition, when beta==0, values in self should be ignored. nans and infs | ||||
|     // should not propagate | ||||
|     if (beta.toComplexDouble() == 0.) { | ||||
| @ -455,158 +557,64 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma | ||||
|         result, | ||||
|         self.expand(result.sizes()), | ||||
|         at::native::scalar_tensor( | ||||
|             beta, | ||||
|             self.scalar_type(), | ||||
|             std::nullopt /* layout */, | ||||
|             at::kCPU, | ||||
|             std::nullopt /* pin_memory */)); | ||||
|           beta, | ||||
|           self.scalar_type(), | ||||
|           std::nullopt /* layout */, | ||||
|           at::kCPU, | ||||
|           std::nullopt /* pin_memory */ | ||||
|         ) | ||||
|     ); | ||||
|   } | ||||
|  | ||||
|   cublasCommonArgs args(mat1, mat2, result); | ||||
|   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj()); | ||||
|  | ||||
|   if (useLtInterface) { | ||||
| #if defined(USE_ROCM) | ||||
|     bool okay = true; | ||||
|   // The Lt path | ||||
|   if (!disable_addmm_cuda_lt) { | ||||
|     bool lt_success = false; | ||||
|     if (is_float_output_with_half_input) { | ||||
|       #ifdef USE_ROCM | ||||
|       TORCH_CHECK(false, "float output with half input is not enabled for ROCm"); | ||||
|     } else { | ||||
|       AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|         at::ScalarType::Half, | ||||
|         at::ScalarType::BFloat16, | ||||
|         scalar_type, | ||||
|         "addmm_cuda_lt", | ||||
|         [&] { | ||||
|         auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|         if (tuning_ctx->IsTunableOpEnabled()) { | ||||
|           launchTunableGemmAndBias<scalar_t>( | ||||
|               args, | ||||
|               alpha, | ||||
|               (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr, | ||||
|               activation_to_gemm_and_blas_arg(activation)); | ||||
|         } else { | ||||
|           okay = at::cuda::blas::gemm_and_bias<scalar_t>( | ||||
|             args.transa == 't', | ||||
|             args.transb == 't', | ||||
|             args.m, | ||||
|             args.n, | ||||
|             args.k, | ||||
|             alpha.to<at::opmath_type<scalar_t>>(), | ||||
|             args.mata->const_data_ptr<scalar_t>(), | ||||
|             args.lda, | ||||
|             args.matb->const_data_ptr<scalar_t>(), | ||||
|             args.ldb, | ||||
|             // This condition is needed for mm case on ROCm for hipblasLt path. | ||||
|             // Passing the bias ptr as null to avoid accuracy issues for mm case. | ||||
|             (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr, | ||||
|             args.result->data_ptr<scalar_t>(), | ||||
|             args.result_ld, | ||||
|             activation_to_gemm_and_blas_arg(activation) | ||||
|           ); | ||||
|         } | ||||
|       }); | ||||
|     } | ||||
|     if (!okay) { | ||||
|       // lt path failed; recurse but disable lt path | ||||
|       return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true); | ||||
|     } | ||||
| #else | ||||
|     auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); | ||||
|     bool okay = true; | ||||
|     if (is_float_output_with_half_input) { | ||||
|       #else | ||||
|       if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) { | ||||
|        TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input"); | ||||
|       } | ||||
|       AT_DISPATCH_REDUCED_FLOATING_TYPES( | ||||
|         scalar_type, | ||||
|         "addmm_cuda_lt", | ||||
|         [&] { | ||||
|         auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|         if (tuning_ctx->IsTunableOpEnabled()) { | ||||
|           TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input"); | ||||
|           lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation); | ||||
|         } | ||||
|         else { | ||||
|           okay = at::cuda::blas::gemm_and_bias<scalar_t, float>( | ||||
|               args.transa == 't', | ||||
|               args.transb == 't', | ||||
|               args.m, | ||||
|               args.n, | ||||
|               args.k, | ||||
|               alpha.to<at::opmath_type<scalar_t>>(), | ||||
|               args.mata->const_data_ptr<scalar_t>(), | ||||
|               args.lda, | ||||
|               args.matb->const_data_ptr<scalar_t>(), | ||||
|               args.ldb, | ||||
|               self.const_data_ptr<scalar_t>(), | ||||
|               args.result->data_ptr<float>(), | ||||
|               args.result_ld, | ||||
|               activation_epilogue | ||||
|           ); | ||||
|         }}); | ||||
|       ); | ||||
|       #endif | ||||
|     } else { | ||||
|       // !is_float_output_with_half_input | ||||
|       AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|         at::ScalarType::Half, | ||||
|         at::ScalarType::BFloat16, | ||||
|         scalar_type, | ||||
|         "addmm_cuda_lt", | ||||
|         [&] { | ||||
|         auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|         if (tuning_ctx->IsTunableOpEnabled()) { | ||||
|           launchTunableGemmAndBias<scalar_t>( | ||||
|               args, | ||||
|               alpha, | ||||
|               self.const_data_ptr<scalar_t>(), | ||||
|               activation_epilogue); | ||||
|           lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation); | ||||
|         } | ||||
|         else { | ||||
|           okay = at::cuda::blas::gemm_and_bias<scalar_t>( | ||||
|               args.transa == 't', | ||||
|               args.transb == 't', | ||||
|               args.m, | ||||
|               args.n, | ||||
|               args.k, | ||||
|               alpha.to<at::opmath_type<scalar_t>>(), | ||||
|               args.mata->const_data_ptr<scalar_t>(), | ||||
|               args.lda, | ||||
|               args.matb->const_data_ptr<scalar_t>(), | ||||
|               args.ldb, | ||||
|               self.const_data_ptr<scalar_t>(), | ||||
|               args.result->data_ptr<scalar_t>(), | ||||
|               args.result_ld, | ||||
|               activation_epilogue | ||||
|           ); | ||||
|       }}); | ||||
|     } | ||||
|     if (!okay) { | ||||
|       // lt path failed; recurse but disable lt path | ||||
|       ); | ||||
|     } // end is_float_output_with_half_input | ||||
|  | ||||
|     if (!lt_success) { | ||||
|     // lt path failed; recurse but disable lt path | ||||
|       return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true); | ||||
|     } | ||||
| #endif | ||||
|   } else | ||||
|   { | ||||
|     // end Lt path | ||||
|   } else { | ||||
|     // No Lt, we use a GEMM instead | ||||
|     if (is_float_output_with_half_input) { | ||||
|       AT_DISPATCH_REDUCED_FLOATING_TYPES( | ||||
|         scalar_type, | ||||
|         "addmm_cuda", | ||||
|         [&] { | ||||
|           using opmath_t = at::opmath_type<scalar_t>; | ||||
|           opmath_t alpha_val = alpha.to<opmath_t>(); | ||||
|           opmath_t beta_val = beta.to<opmath_t>(); | ||||
|           const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>(); | ||||
|           const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>(); | ||||
|  | ||||
|           float* result_ptr = args.result->mutable_data_ptr<float>(); | ||||
|           at::cuda::blas::gemm<scalar_t, float>( | ||||
|               args.transa, | ||||
|               args.transb, | ||||
|               args.m, | ||||
|               args.n, | ||||
|               args.k, | ||||
|               alpha_val, | ||||
|               mat1_ptr, | ||||
|               args.lda, | ||||
|               mat2_ptr, | ||||
|               args.ldb, | ||||
|               beta_val, | ||||
|               result_ptr, | ||||
|               args.result_ld); | ||||
|         }); | ||||
|           launchGemmCublas<scalar_t, float>(args, alpha, beta); | ||||
|         } | ||||
|       ); | ||||
|     } else { | ||||
|       AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( | ||||
|         at::ScalarType::Half, | ||||
| @ -614,28 +622,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma | ||||
|         scalar_type, | ||||
|         "addmm_cuda", | ||||
|         [&] { | ||||
|           using opmath_t = at::opmath_type<scalar_t>; | ||||
|           opmath_t alpha_val = alpha.to<opmath_t>(); | ||||
|           opmath_t beta_val = beta.to<opmath_t>(); | ||||
|           const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>(); | ||||
|           const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>(); | ||||
|           scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>(); | ||||
|           at::cuda::blas::gemm<scalar_t>( | ||||
|               args.transa, | ||||
|               args.transb, | ||||
|               args.m, | ||||
|               args.n, | ||||
|               args.k, | ||||
|               alpha_val, | ||||
|               mat1_ptr, | ||||
|               args.lda, | ||||
|               mat2_ptr, | ||||
|               args.ldb, | ||||
|               beta_val, | ||||
|               result_ptr, | ||||
|               args.result_ld); | ||||
|         }); | ||||
|           launchGemmCublas<scalar_t>(args, alpha, beta); | ||||
|         } | ||||
|       ); | ||||
|     } | ||||
|  | ||||
|     // Apply epilogue | ||||
|     switch (activation) { | ||||
|       case Activation::RELU: | ||||
|         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) | ||||
| @ -647,14 +639,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma | ||||
|         break; | ||||
|       default: break; | ||||
|     } | ||||
|   } | ||||
|   } // end GEMM path | ||||
|  | ||||
| // Preprocessor gate here needs to match the inverse of the check | ||||
| // gating activation_to_gemm_and_blas_arg above; here we are manually | ||||
| // performing a post-GELU because we weren't able to use the GELU | ||||
| // epilogue above. | ||||
| #if !defined(CUDA_VERSION) && !defined(USE_ROCM) | ||||
|   if (useLtInterface && activation == Activation::GELU) { | ||||
|   if (!disable_addmm_cuda_lt && activation == Activation::GELU) { | ||||
|     at::gelu_(const_cast<Tensor&>(*args.result), "tanh"); | ||||
|   } | ||||
| #endif | ||||
| @ -2322,12 +2314,23 @@ _scaled_nvfp4_nvfp4( | ||||
|           const Tensor& scale_b, const SwizzleType swizzle_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const c10::ScalarType out_dtype, | ||||
|           const bool single_scale, | ||||
|           Tensor& out) { | ||||
|           Tensor& out, | ||||
|           const std::optional<Tensor>& global_scale_a = std::nullopt, | ||||
|           const std::optional<Tensor>& global_scale_b = std::nullopt) { | ||||
| #ifdef USE_ROCM | ||||
|   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM"); | ||||
| #endif | ||||
|   TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported"); | ||||
|   std::optional<Tensor> alpha = std::nullopt; | ||||
|   // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise, | ||||
|   //       if this is "And" we would silently do nothing in the case where one global scale is | ||||
|   //       passed and not the other. | ||||
|   if (global_scale_a.has_value() || global_scale_b.has_value()) { | ||||
|     TORCH_CHECK_VALUE(global_scale_a.has_value(), | ||||
|         "For two-level-scaled NVFP4, global_scale_a must have a value"); | ||||
|     TORCH_CHECK_VALUE(global_scale_b.has_value(), | ||||
|         "For two-level-scaled NVFP4, global_scale_b must have a value"); | ||||
|     alpha = global_scale_a.value().mul(global_scale_b.value()); | ||||
|   } | ||||
|   // Restrictions: | ||||
|   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 | ||||
|   // Scales must be swizzled | ||||
| @ -2349,7 +2352,7 @@ _scaled_nvfp4_nvfp4( | ||||
|  | ||||
|   auto scaling_choice_a = ScalingType::BlockWise1x16; | ||||
|   auto scaling_choice_b = ScalingType::BlockWise1x16; | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha); | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -2555,9 +2558,10 @@ _scaled_mm_cuda_v2_out( | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) { | ||||
|     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) { | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported"); | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out, | ||||
|                                scale_a[1], scale_b[1]); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) { | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out); | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) { | ||||
|     return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|   } else { | ||||
|  | ||||
| @ -15,9 +15,7 @@ | ||||
| #include <ATen/native/cuda/block_reduce.cuh> | ||||
| #include <ATen/native/cuda/thread_constants.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| #include <thrust/iterator/reverse_iterator.h> | ||||
| #endif | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -240,10 +238,6 @@ __global__ void renorm_kernel( | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); | ||||
| #endif | ||||
|  | ||||
| Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_, | ||||
|                                int64_t num_weights, int64_t padding_idx, | ||||
| @ -306,7 +300,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice | ||||
|  | ||||
|   if (scale_grad_by_freq) { | ||||
|     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { | ||||
|       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
| @ -333,11 +326,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice | ||||
|         num_indices | ||||
|       ); | ||||
|     }); | ||||
| #else | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { | ||||
|       embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); | ||||
|     }); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   return embedding_backward_cuda_kernel(grad, orig_indices, | ||||
|  | ||||
| @ -10,9 +10,7 @@ | ||||
|  | ||||
| #include <c10/macros/Macros.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| #include <thrust/iterator/counting_iterator.h> | ||||
| #endif | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -196,18 +194,9 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm | ||||
|             partials_per_segment_offset[num_of_segments-1]; | ||||
| } | ||||
|  | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| __global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) { | ||||
|   *num_of_segments_ptr = num_of_segments; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| } // anon namespace | ||||
|  | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| template<typename index_t> | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
| #endif | ||||
|  | ||||
| Tensor embedding_backward_cuda_kernel( | ||||
|         const Tensor &grad, | ||||
| @ -234,20 +223,12 @@ Tensor embedding_backward_cuda_kernel( | ||||
|   auto segment_offsets = at::empty({numel}, orig_indices.options()); | ||||
|   auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong)); | ||||
|   int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>(); | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
|   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { | ||||
|     int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets); | ||||
|     write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   }); | ||||
| #else | ||||
|   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { | ||||
|     cuda::cub::unique_by_key( | ||||
|       sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0), | ||||
|       segment_offsets.mutable_data_ptr<index_t>(), | ||||
|       num_of_segments_ptr, sorted_indices.numel()); | ||||
|   }); | ||||
| #endif | ||||
|  | ||||
|   int64_t max_segments = std::min<int64_t>(numel, num_weights); | ||||
|  | ||||
|  | ||||
| @ -31,16 +31,10 @@ | ||||
|  | ||||
| #include <c10/macros/Macros.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| #include <thrust/iterator/reverse_iterator.h> | ||||
| #endif | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); | ||||
| #endif | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| @ -199,7 +193,6 @@ Tensor embedding_bag_backward_cuda_sum_avg( | ||||
|  | ||||
|   if (scale_grad_by_freq) { | ||||
|     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { | ||||
|       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
| @ -226,11 +219,6 @@ Tensor embedding_bag_backward_cuda_sum_avg( | ||||
|         num_indices | ||||
|       ); | ||||
|     }); | ||||
| #else | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { | ||||
|       embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); | ||||
|     }); | ||||
| #endif | ||||
|   } | ||||
|   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, | ||||
|       count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag, | ||||
|  | ||||
| @ -1,90 +0,0 @@ | ||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||
| #include <ATen/core/Tensor.h> | ||||
| #include <ATen/native/cuda/SortingCommon.cuh> | ||||
| #include <ATen/cuda/cub_definitions.cuh> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| #else | ||||
| #include <ATen/ops/empty_like.h> | ||||
| #endif | ||||
|  | ||||
| #include <ATen/cuda/ThrustAllocator.h> | ||||
| #include <thrust/device_ptr.h> | ||||
| #include <thrust/execution_policy.h> | ||||
| #include <thrust/sort.h> | ||||
| #include <thrust/unique.h> | ||||
| #include <thrust/device_ptr.h> | ||||
| #include <thrust/iterator/constant_iterator.h> | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|  | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) { | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   at::cuda::ThrustAllocator allocator; | ||||
|   auto policy = thrust::cuda::par(allocator).on(stream); | ||||
|  | ||||
|   auto num_indices = count.numel(); | ||||
|  | ||||
|   // Compute an increasing sequence per unique item in sortedIndices: | ||||
|   // sorted: 2 5 5 5 7 7 8 9 9 | ||||
|   //  count: 1 1 2 3 1 2 1 1 2 | ||||
|   auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>()); | ||||
|   auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>()); | ||||
|   thrust::inclusive_scan_by_key( | ||||
|     policy, | ||||
|     sorted_data, | ||||
|     sorted_data + num_indices, | ||||
|     thrust::make_constant_iterator(1), | ||||
|     count_data | ||||
|   ); | ||||
|  | ||||
|   // Take the maximum of each count per unique key in reverse: | ||||
|   // sorted: 2 5 5 5 7 7 8 9 9 | ||||
|   //  count: 1 3 3 3 2 2 1 2 2 | ||||
|   thrust::inclusive_scan_by_key( | ||||
|     policy, | ||||
|     thrust::make_reverse_iterator(sorted_data + num_indices), | ||||
|     thrust::make_reverse_iterator(sorted_data), | ||||
|     thrust::make_reverse_iterator(count_data + num_indices), | ||||
|     thrust::make_reverse_iterator(count_data + num_indices), | ||||
|     thrust::equal_to<index_t>(), | ||||
|     thrust::maximum<index_t>() | ||||
|   ); | ||||
| } | ||||
|  | ||||
| template | ||||
| void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count); | ||||
| template | ||||
| void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count); | ||||
|  | ||||
| #endif | ||||
|  | ||||
| template<typename index_t> | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) { | ||||
|   auto stream = at::cuda::getCurrentCUDAStream(); | ||||
|   at::cuda::ThrustAllocator allocator; | ||||
|   auto policy = thrust::cuda::par(allocator).on(stream); | ||||
|   const ptrdiff_t numel = sorted_indices.numel(); | ||||
|   auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>()); | ||||
|   auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
|   auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>()); | ||||
|   auto ends = thrust::unique_by_key_copy( | ||||
|           policy, | ||||
|           sorted_indices_dev, | ||||
|           sorted_indices_dev + numel, | ||||
|           thrust::make_counting_iterator(0), | ||||
|           dummy_dev, | ||||
|           thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>())); | ||||
|   return thrust::get<0>(ends) - dummy_dev; | ||||
| } | ||||
|  | ||||
| template | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
| template | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
|  | ||||
| } // namespace at::native | ||||
| @ -1,18 +1,17 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/OpMathType.h> | ||||
| #include <ATen/cuda/detail/OffsetCalculator.cuh> | ||||
| #include <ATen/detail/FunctionTraits.h> | ||||
| #include <ATen/native/TensorIterator.h> | ||||
| #include <ATen/native/TensorIteratorDynamicCasting.h> | ||||
| #include <ATen/cuda/detail/OffsetCalculator.cuh> | ||||
| #include <ATen/OpMathType.h> | ||||
| #include <ATen/native/cuda/thread_constants.h> | ||||
|  | ||||
| #include <thrust/tuple.h> | ||||
|  | ||||
| #include <ATen/native/cuda/MemoryAccess.cuh> | ||||
|  | ||||
| #include <tuple> | ||||
|  | ||||
|  | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| template<int N> | ||||
| @ -62,7 +61,11 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { | ||||
|   #pragma unroll | ||||
|   for (int i = 0; i < elems_per_thread; i++) { | ||||
|     if (policy.check_inbounds(i)) { | ||||
| #if defined(__HIP__) | ||||
|       results[i] = c10::guts::apply(f, args[i]); | ||||
| #else | ||||
|       results[i] = std::apply(f, args[i]); | ||||
| #endif | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  | ||||
| @ -146,6 +146,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel( | ||||
|   int64_t batch_size = target.size(0); | ||||
|   int64_t H = target.size(1); | ||||
|   int64_t W = target.size(2); | ||||
|   int64_t n_classes = grad_input.size(1); | ||||
|  | ||||
|   CUDA_KERNEL_LOOP(index, n_threads) { | ||||
|     const int64_t b = index % batch_size; | ||||
| @ -156,6 +157,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel( | ||||
|     if (cur_target == ignore_index) { | ||||
|       continue; | ||||
|     } | ||||
|     CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes); | ||||
|     scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1)); | ||||
|     grad_input[b][cur_target][h][w] = value * grad_output[b][h][w]; | ||||
|   } | ||||
|  | ||||
| @ -23,7 +23,7 @@ namespace at::native { | ||||
|  | ||||
| // The maximum number of threads in a block | ||||
| #if defined(USE_ROCM) | ||||
| constexpr int MAX_BLOCK_SIZE = 256; | ||||
| constexpr int MAX_BLOCK_SIZE = 1024; | ||||
| #else | ||||
| constexpr int MAX_BLOCK_SIZE = 512; | ||||
| #endif | ||||
| @ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u; | ||||
| // Number of threads in a block given an input size up to MAX_BLOCK_SIZE | ||||
| static int getNumThreads(int nElem) { | ||||
| #if defined(USE_ROCM) | ||||
|   int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; | ||||
|   int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE }; | ||||
| #else | ||||
|   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; | ||||
| #endif | ||||
| @ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { | ||||
|   // first the reductions each thread does separately | ||||
|   scalar_t sum = static_cast<scalar_t>(0); | ||||
|   for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { | ||||
| #if defined(USE_ROCM) | ||||
|     constexpr int UNRL = 4; // load deserilize factor | ||||
|     scalar_t tmp[UNRL]; | ||||
|     for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) { | ||||
| #pragma unroll | ||||
|       for (int u = 0; u < UNRL; u++) | ||||
|         tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x))); | ||||
| #pragma unroll | ||||
|       for (int u = 0; u < UNRL; u++) | ||||
|         if (x+u*blockDim.x < tensor.size(2)) | ||||
|           sum += tmp[u]; | ||||
|     } | ||||
| #else | ||||
|     for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { | ||||
|       sum += op(batch, plane, x); | ||||
|     } | ||||
| #endif | ||||
|   } | ||||
|   __shared__ scalar_t shared[C10_WARP_SIZE]; | ||||
|   SumReduceOp<scalar_t> reduce_op; | ||||
| @ -292,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel( | ||||
|   stat_accscalar_t var_n = 0; | ||||
|   int n = 0; | ||||
|   for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { | ||||
| #if defined(USE_ROCM) | ||||
|     constexpr int UNRL = 4; | ||||
|     stat_accscalar_t v_[UNRL]; | ||||
|     for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) { | ||||
|       for (int u = 0; u < UNRL; u++) | ||||
|         v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)]; | ||||
|       for (int u = 0; u < UNRL; u++) { | ||||
|         if (x+u*blockDim.x < input.size(2)) { | ||||
|           stat_accscalar_t d1 = v_[u] - avg; | ||||
|           n++; | ||||
|           avg += d1 / n; | ||||
|           var_n += d1 * (v_[u] - avg); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| #else | ||||
|     for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { | ||||
|       stat_accscalar_t v = input[batch][plane][x]; | ||||
|       stat_accscalar_t d1 = v - avg; | ||||
| @ -299,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel( | ||||
|       avg += d1 / n; | ||||
|       var_n += d1 * (v - avg); | ||||
|     } | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // first warpSum to get one value per thread to | ||||
|  | ||||
| @ -413,14 +413,12 @@ struct ReduceOp { | ||||
|       value = thread_reduce<output_vec_size>(input_slice); | ||||
|     } | ||||
|  | ||||
|     if (config.should_block_y_reduce()) { | ||||
|       value = block_y_reduce<output_vec_size>(value, shared_memory); | ||||
|     } | ||||
|     __syncthreads(); | ||||
|     if (config.should_block_x_reduce()) { | ||||
|       value = block_x_reduce<output_vec_size>(value, shared_memory); | ||||
|     } | ||||
|  | ||||
|     if (config.should_block_y_reduce()) { | ||||
|       value = block_y_reduce<output_vec_size>(value, shared_memory); | ||||
|     } | ||||
|     using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>; | ||||
|     using offset_vec_t = std::array<index_t, output_vec_size>; | ||||
|     offset_vec_t base_offsets; | ||||
| @ -657,8 +655,8 @@ struct ReduceOp { | ||||
|     __syncthreads(); | ||||
|     // Intra-warp reduction, fix CUDA to have offset decreasing for better numerics | ||||
|     // matching Triton, etc. | ||||
|     // todo for AMD | ||||
|     #ifdef USE_ROCM | ||||
|     // TODO(PaulZhang12): AMD and internal | ||||
|     #if defined(USE_ROCM) || defined(FBCODE_CAFFE2) | ||||
|     for (int offset = 1; offset < dim_x; offset <<= 1) { | ||||
|     #else | ||||
|     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { | ||||
|  | ||||
| @ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t>  get_index_mapping2d( | ||||
|     output_offset + output_y * output_dim_x + output_x); | ||||
| } | ||||
|  | ||||
| __device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) { | ||||
|   const int64_t two = (len - 1) * 2; | ||||
|   if (two <= 0) { | ||||
|     return 0; | ||||
|   } | ||||
|   int64_t m = x % two; | ||||
|   if (m < 0) m += two; | ||||
|   return (m < len) ? m : (two - m); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| __global__ void reflection_pad1d_out_kernel( | ||||
|     const scalar_t * input, scalar_t * output, | ||||
| @ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel( | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| __global__ void reflection_pad1d_flat( | ||||
|     const scalar_t* __restrict__ input, | ||||
|     scalar_t* __restrict__ output, | ||||
|     int64_t input_w, int64_t pad_l, int64_t pad_r, | ||||
|     int64_t out_w, int64_t plane_count) { | ||||
|  | ||||
|   const int64_t bx = blockDim.x; | ||||
|   const int64_t tx = threadIdx.x; | ||||
|  | ||||
|   const int64_t total = plane_count * out_w; | ||||
|   const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x; | ||||
|   int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx; | ||||
|  | ||||
|   for (; linear < total; linear += grid_stride) { | ||||
|     const int64_t plane = linear / out_w; | ||||
|     const int64_t x = linear - plane * out_w; | ||||
|     const int64_t j = reflect_index(x - pad_l, input_w); | ||||
|     output[plane * out_w + x] = input[plane * input_w + j]; | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| __global__ void reflection_pad1d_backward_out_kernel( | ||||
|     scalar_t * grad_input, const scalar_t * grad_output, | ||||
| @ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda) | ||||
|   int64_t input_w = input_.size(dim_w); | ||||
|   int64_t output_w = input_w + pad_l + pad_r; | ||||
|  | ||||
|   dim3 block_size(output_w > 256 ? 256 : output_w); | ||||
|   dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch); | ||||
|  | ||||
|   Tensor input = input_.contiguous(); | ||||
|  | ||||
|   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( | ||||
|       kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] { | ||||
|         reflection_pad1d_out_kernel<<< | ||||
|             grid_size, | ||||
|             block_size, | ||||
|             0, | ||||
|             at::cuda::getCurrentCUDAStream()>>>( | ||||
|             input.const_data_ptr<scalar_t>(), | ||||
|             output.mutable_data_ptr<scalar_t>(), | ||||
|             input_w, | ||||
|             pad_l, | ||||
|             pad_r); | ||||
|         C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|       }); | ||||
|   const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w))); | ||||
|   const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); | ||||
|   const int max_x = prop->maxGridSize[0]; | ||||
|   const int max_y = prop->maxGridSize[1]; | ||||
|   const int max_z = prop->maxGridSize[2]; | ||||
|  | ||||
|   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] { | ||||
|     auto stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
|     const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x)); | ||||
|  | ||||
|     const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x); | ||||
|  | ||||
|     if (fits3d) { | ||||
|       dim3 block(block_x, 1, 1); | ||||
|       dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch)); | ||||
|       reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|           input.const_data_ptr<scalar_t>(), | ||||
|           output.mutable_data_ptr<scalar_t>(), | ||||
|           input_w, pad_l, pad_r); | ||||
|     } else { | ||||
|       dim3 block(block_x, 1, 1); | ||||
|       const int64_t plane_count = nplane * nbatch; | ||||
|       const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x)); | ||||
|       const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks))); | ||||
|       dim3 grid(grid_x, 1, 1); | ||||
|  | ||||
|       reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>( | ||||
|           input.const_data_ptr<scalar_t>(), | ||||
|           output.mutable_data_ptr<scalar_t>(), | ||||
|           input_w, pad_l, pad_r, output_w, plane_count); | ||||
|     } | ||||
|  | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_, | ||||
|  | ||||
| @ -19,7 +19,6 @@ | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| // TODO: remove this when CUDA <11.6 is no longer supported | ||||
| void topk_out_with_sort( | ||||
|   const Tensor& self, | ||||
|   int64_t k, int64_t dim, bool largest, | ||||
| @ -31,21 +30,12 @@ void topk_out_with_sort( | ||||
|   indices.copy_(sorted_indices.narrow(dim, 0, k)); | ||||
| } | ||||
|  | ||||
| // TODO: remove this when CUDA <11.6 is no longer supported | ||||
| bool disable_sort_for_topk(); | ||||
| bool should_use_sort(const Tensor& self, int64_t dim) { | ||||
| #if defined(USE_ROCM) | ||||
|   if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972 | ||||
|   return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387 | ||||
| #else | ||||
|   if (disable_sort_for_topk()) return false; | ||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632 | ||||
|   if (self.dim() == 0) return false; | ||||
|   if (self.dtype() == kBool) return false; // Bool is not support by topk | ||||
|   int64_t slice_size = self.size(dim); | ||||
|   if (slice_size == 0) return false; | ||||
|   int64_t num_slices = self.numel() / slice_size; | ||||
|   return num_slices <= 10 && slice_size >= 100000; | ||||
|   return false; | ||||
| #endif | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -21,11 +21,6 @@ using namespace at::native; | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| // TODO: remove this when CUDA <11.6 is no longer supported | ||||
| bool disable_sort_for_topk() { | ||||
|   return CUB_SUPPORTS_SCAN_BY_KEY(); | ||||
| } | ||||
|  | ||||
| namespace sbtopk { // single_block_topk | ||||
|  | ||||
| template <typename T> | ||||
| @ -418,10 +413,6 @@ __global__ void computeBlockwiseWithinKCounts( | ||||
|   } | ||||
|   __syncthreads(); | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|   return; | ||||
| #endif | ||||
|  | ||||
|   Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS); | ||||
|  | ||||
|   // if largest, then only threads that has tidx > desired_digit are active | ||||
| @ -477,7 +468,6 @@ __global__ void computeBlockwiseWithinKCounts( | ||||
|   } | ||||
| } | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| // Assumption: slice_size can not be larger than UINT32_MAX | ||||
| template <typename Bitwise> | ||||
| __global__ void computeBlockwiseKthCounts( | ||||
| @ -609,7 +599,6 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> inpu | ||||
|     } | ||||
|   } | ||||
| } | ||||
| #endif | ||||
|  | ||||
| int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) { | ||||
|   // occupancy of this kernel is limited by registers per threads | ||||
| @ -687,16 +676,12 @@ void launch( | ||||
|   uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get()); | ||||
|   AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream)); | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|   auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); | ||||
|   uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get()); | ||||
|   AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream)); | ||||
|  | ||||
|   auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); | ||||
|   uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get()); | ||||
| #else | ||||
|   uint32_t* withinKCounts = nullptr; | ||||
| #endif | ||||
|  | ||||
|   Bitwise desiredMask = 0; | ||||
|   dim3 grid; | ||||
| @ -743,7 +728,6 @@ void launch( | ||||
|   } | ||||
|   desired = desired_in; | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|   computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>( | ||||
|     desired, counts, num_blocks, blocks_per_slice, kthCounts); | ||||
|   C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
| @ -759,28 +743,6 @@ void launch( | ||||
|     topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread, | ||||
|     blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks); | ||||
|   C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
| #else | ||||
|   // Find topk values based on kth values | ||||
|   { | ||||
|     dim3 grid; | ||||
|     TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk"); | ||||
|     int warp_size = at::cuda::warp_size(); | ||||
|     dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024)); | ||||
|     sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>( | ||||
|         input, | ||||
|         inputSliceSize, | ||||
|         outputSliceSize, | ||||
|         largest, | ||||
|         numInputSlices, | ||||
|         inputWithinSliceStride, | ||||
|         topK, | ||||
|         topKWithinSliceStride, | ||||
|         indices, | ||||
|         indicesWithinSliceStride, | ||||
|         kthValues); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| } // namespace mbtopk | ||||
| @ -788,7 +750,6 @@ void launch( | ||||
| bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { | ||||
|   if (num_slices > std::numeric_limits<uint32_t>::max() || | ||||
|       slice_size > std::numeric_limits<uint32_t>::max()) return false; | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267 | ||||
|   return (num_slices <= 20 && slice_size >= 20000) || | ||||
|       (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) || | ||||
| @ -797,12 +758,6 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { | ||||
|       (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) || | ||||
|       (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) || | ||||
|       (num_slices > 4000 && slice_size >= 400); | ||||
| #else | ||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081 | ||||
|   return (num_slices <= 400 && slice_size >= 5000) || | ||||
|       (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) || | ||||
|       (num_slices >= 4000 && slice_size >= 300); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| void launch_gather_topk_kernel( | ||||
|  | ||||
| @ -44,7 +44,7 @@ __global__ void triu_tril_kernel( | ||||
|     const int64_t k, | ||||
|     const int64_t N_padded, | ||||
|     const IndexType last_dim_padded) { | ||||
|   int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread; | ||||
|   int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread; | ||||
|   if (linear_idx >= N_padded) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
| @ -52,7 +52,7 @@ struct FusedAdagradMathFunctor { | ||||
|   using opmath_t = at::opmath_type<scalar_t>; | ||||
|  | ||||
|   C10_DEVICE __forceinline__ void operator()( | ||||
|       int chunk_size, | ||||
|       int64_t chunk_size, | ||||
|       FusedOptimizerTensorListMetadata<3>& tl, | ||||
|       const float* lr_ptr, | ||||
|       const double& lr, | ||||
| @ -133,4 +133,4 @@ struct FusedAdagradMathFunctor { | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| } // namespace at::native | ||||
| } // namespace at::native | ||||
|  | ||||
| @ -466,7 +466,7 @@ struct ReduceJitOp { | ||||
|  | ||||
|     __syncthreads(); | ||||
|  | ||||
|     #ifdef USE_ROCM | ||||
|     #if defined(USE_ROCM) || defined(FBCODE_CAFFE2) | ||||
|     for (int offset = 1; offset < dim_x; offset <<= 1) { | ||||
|     #else | ||||
|     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { | ||||
|  | ||||
| @ -487,9 +487,7 @@ std::unique_ptr<fe::graph::Graph> build_graph( | ||||
|   auto scaled_dot_product_flash_attention_options = | ||||
|       fe::graph::SDPA_attributes() | ||||
|           .set_name("CUDNN_SDPA") | ||||
|           .set_is_inference(return_softmaxstats == false) | ||||
|           // TODO(eqy): switch to this API once cuDNN FE is upgraded | ||||
|           // .set_generate_stats(return_softmaxstats) | ||||
|           .set_generate_stats(return_softmaxstats) | ||||
|           .set_causal_mask(is_causal) | ||||
|           .set_attn_scale(attn_scale); | ||||
|   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { | ||||
| @ -707,9 +705,7 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor( | ||||
|   auto scaled_dot_product_flash_attention_options = | ||||
|       fe::graph::SDPA_attributes() | ||||
|           .set_name("CUDNN_SDPA_NESTEDTENSOR") | ||||
|           .set_is_inference(return_softmaxstats == false) | ||||
|           // TODO(eqy): switch to this API once cuDNN FE is upgraded | ||||
|           // .set_generate_stats(return_softmaxstats) | ||||
|           .set_generate_stats(return_softmaxstats) | ||||
|           .set_causal_mask(is_causal) | ||||
|           .set_attn_scale(attn_scale) | ||||
|           .set_seq_len_q(SEQ_LEN_Q_) | ||||
|  | ||||
| @ -441,7 +441,7 @@ kernel void applySYRK( | ||||
|     uint3 tid [[thread_position_in_threadgroup]], | ||||
|     uint3 tgid [[threadgroup_position_in_grid]], | ||||
|     uint3 tpg [[threads_per_threadgroup]], | ||||
|     uint sgitg [[simdgroup_index_in_threadgroup]]) { | ||||
|     uint warp_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   const uint tx = tid.x; | ||||
|   const uint ty = tid.y; | ||||
|   const uint simdGroupsPerThreadgroup = (tpg.x * tpg.y + 31) / 32; | ||||
| @ -474,11 +474,8 @@ kernel void applySYRK( | ||||
|       (actSize_j % 8 == 0) && (actSize_h % 8 == 0) && (actSize_k % 8 == 0); | ||||
|  | ||||
|   if (use_simdgroup) { | ||||
|     uint warp_id = sgitg; | ||||
|  | ||||
|     simdgroup_matrix<float, 8, 8> negative_identity = | ||||
|         simdgroup_matrix<float, 8, 8>(-1.0); | ||||
|     simdgroup_matrix<float, 8, 8> identity = simdgroup_matrix<float, 8, 8>(1.0); | ||||
|     simdgroup_matrix<float, 8, 8> Prod; | ||||
|     simdgroup_matrix<float, 8, 8> Afrag; | ||||
|     simdgroup_matrix<float, 8, 8> Bfrag; | ||||
| @ -521,8 +518,7 @@ kernel void applySYRK( | ||||
|             /* transpose = */ upper); | ||||
|  | ||||
|         simdgroup_multiply(Prod, Afrag, Bfrag); | ||||
|         simdgroup_multiply(Prod, Prod, negative_identity); | ||||
|         simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod); | ||||
|         simdgroup_multiply_accumulate(Cfrag, Prod, negative_identity, Cfrag); | ||||
|       } | ||||
|  | ||||
|       simdgroup_store( | ||||
|  | ||||
| @ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query, | ||||
|           } | ||||
|  | ||||
|           // upcasting to float32 if needed to improve precision when multiplying by the scale factor | ||||
|           if ([maskedMM dataType] != MPSDataTypeFloat32) { | ||||
|             maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil]; | ||||
|           } | ||||
|           maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32); | ||||
|           maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil]; | ||||
|           if ([maskedMM dataType] != qTensor.dataType) { | ||||
|             maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil]; | ||||
|           } | ||||
|  | ||||
|           if (is_causal) { | ||||
|             auto causalMask = [mpsGraph constantWithScalar:1.0f | ||||
| @ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query, | ||||
|                                                       name:nil]; | ||||
|           } else if (attn_mask) { | ||||
|             graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask); | ||||
|             maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil]; | ||||
|             maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM | ||||
|                                            secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType) | ||||
|                                                       name:nil]; | ||||
|           } | ||||
|  | ||||
|           // Account for case where all values were masked causing division by 0 in softmax (issue:#156707) | ||||
| @ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query, | ||||
|           graph->qTensor = qTensor; | ||||
|           graph->kTensor = kTensor; | ||||
|           graph->vTensor = vTensor; | ||||
|           graph->outputTensor = output; | ||||
|           graph->attnTensor = sm; | ||||
|           graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType); | ||||
|           graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType); | ||||
|         }); | ||||
|     auto qPlaceholder = Placeholder(cachedGraph->qTensor, query); | ||||
|     auto kPlaceholder = Placeholder(cachedGraph->kTensor, key); | ||||
|  | ||||
| @ -338,6 +338,8 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, | ||||
|           ". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details."); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   map_mps_decomposition_error_code_to_blas(info); | ||||
| } | ||||
|  | ||||
| static void linalg_solve_out_mps_impl(const Tensor& A, | ||||
| @ -1448,20 +1450,6 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps) | ||||
|   mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info); | ||||
| } | ||||
|  | ||||
| std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) { | ||||
|   Tensor info = at::empty({}, A.options().dtype(kInt)); | ||||
|   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false); | ||||
|   return std::tie(LU, pivots); | ||||
| } | ||||
|  | ||||
| std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) { | ||||
|   Tensor LU = at::empty({0}, A.options()); | ||||
|   Tensor pivots = at::empty({0}, A.options().dtype(kInt)); | ||||
|   Tensor info = at::empty({}, A.options().dtype(kInt)); | ||||
|   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false); | ||||
|   return std::make_tuple(std::move(LU), std::move(pivots)); | ||||
| } | ||||
|  | ||||
| TORCH_IMPL_FUNC(lu_unpack_out_mps) | ||||
| (const Tensor& LU_data, | ||||
|  const Tensor& LU_pivots, | ||||
|  | ||||
| @ -706,6 +706,7 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_all | ||||
|   tags: reduction | ||||
|  | ||||
|  | ||||
| - func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor | ||||
| @ -715,6 +716,7 @@ | ||||
|   cpp_no_default_args: ['dim'] | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: all_dims_default | ||||
|   tags: reduction | ||||
|  | ||||
| - func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -723,6 +725,7 @@ | ||||
|     CPU, CUDA: all_out | ||||
|     MPS: all_out_mps | ||||
|     MTIA: all_out_mtia | ||||
|   tags: reduction | ||||
|  | ||||
| - func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -731,13 +734,16 @@ | ||||
|     CPU, CUDA: all_dims_out | ||||
|     CompositeExplicitAutograd: all_dims_out_default | ||||
|   cpp_no_default_args: ['dim'] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool | ||||
|   variants: function, method | ||||
| @ -749,14 +755,14 @@ | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   structured_delegate: any.out | ||||
|   variants: function, method | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   structured_delegate: any.dims_out | ||||
|   variants: function, method | ||||
|   cpp_no_default_args: ['dim'] | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: any_dims_default | ||||
|  | ||||
| @ -766,6 +772,7 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: any_out | ||||
|     MPS: any_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -774,13 +781,16 @@ | ||||
|     CPU, CUDA: any_dims_out | ||||
|     CompositeExplicitAutograd: any_dims_out_default | ||||
|   cpp_no_default_args: ['dim'] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor | ||||
|   dispatch: | ||||
| @ -826,25 +836,27 @@ | ||||
|   structured_delegate: argmax.out | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
|   dispatch: | ||||
|     CPU, CUDA: argmax_out | ||||
|     MPS: argmax_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor | ||||
|   structured_delegate: argmin.out | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
|   dispatch: | ||||
|     CPU, CUDA: argmin_out | ||||
|     MPS: argmin_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: acosh(Tensor self) -> Tensor | ||||
|   variants: function, method | ||||
| @ -1370,6 +1382,7 @@ | ||||
|   dispatch: | ||||
|     SparseCPU: bmm_sparse_cpu | ||||
|     SparseCUDA: bmm_sparse_cuda | ||||
|     SparseMPS: bmm_sparse_mps | ||||
|     NestedTensorCPU: bmm_nested | ||||
|     NestedTensorCUDA: bmm_nested_cuda | ||||
|   tags: core | ||||
| @ -1385,6 +1398,7 @@ | ||||
|     MTIA: bmm_out_mtia | ||||
|     SparseCPU: bmm_out_sparse_cpu | ||||
|     SparseCUDA: bmm_out_sparse_cuda | ||||
|     SparseMPS: bmm_out_sparse_mps | ||||
|     SparseCsrCUDA: bmm_out_sparse_csr_cuda | ||||
|  | ||||
| - func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor | ||||
| @ -1867,12 +1881,14 @@ | ||||
|     CUDA: count_nonzero_cuda | ||||
|     MPS: count_nonzero_mps | ||||
|   autogen: count_nonzero.dim_IntList_out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: count_nonzero(Tensor self, int? dim=None) -> Tensor | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: count_nonzero | ||||
|   autogen: count_nonzero.out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor | ||||
|   variants: function, method | ||||
| @ -3793,19 +3809,23 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: logsumexp | ||||
|   tags: reduction | ||||
|  | ||||
| - func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   dispatch: | ||||
|     # calls squeeze | ||||
|     CompositeExplicitAutogradNonFunctional: logsumexp_out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor | ||||
|  | ||||
| @ -3855,6 +3875,7 @@ | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   structured_delegate: aminmax.out | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -3862,6 +3883,7 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA, MTIA: aminmax_out | ||||
|     MPS: aminmax_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor | ||||
|   dispatch: | ||||
| @ -3877,7 +3899,7 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     QuantizedCPU, QuantizedCUDA: qmax | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -3887,13 +3909,16 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA, MTIA: max_out | ||||
|     MPS: max_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor | ||||
|   variants: function | ||||
| @ -3906,13 +3931,14 @@ | ||||
| - func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor | ||||
|   variants: function, method | ||||
|   structured_delegate: amax.out | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
|   dispatch: | ||||
|     CPU, CUDA, MTIA: amax_out | ||||
|     MPS: amax_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| # Return: (Tensor output, Tensor indices) | ||||
| - func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) | ||||
| @ -3974,13 +4000,14 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: mean | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| # For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this. | ||||
| - func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: mean_dtype_out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||
|   structured_delegate: mean.out | ||||
| @ -3988,7 +4015,7 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     QuantizedCPU: mean_quantized_cpu | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
| @ -3997,13 +4024,16 @@ | ||||
|     CPU, CUDA: mean_out | ||||
|     MPS: mean_out_mps | ||||
|     QuantizedCPU: mean_out_quantized_cpu | ||||
|   tags: reduction | ||||
|  | ||||
| - func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||
|   device_check: NoCheck   # Composite | ||||
| @ -4066,7 +4096,7 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     QuantizedCPU, QuantizedCUDA: qmin | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -4076,24 +4106,28 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA, MTIA: min_out | ||||
|     MPS: min_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor | ||||
|   variants: function, method | ||||
|   structured_delegate: amin.out | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
|   dispatch: | ||||
|     CPU, CUDA, MTIA: amin_out | ||||
|     MPS: amin_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| # TODO: Add this function to MPS dispatch key so that we avoid declaring it in | ||||
| # native_functions.yaml | ||||
| @ -4173,7 +4207,7 @@ | ||||
|   structured_delegate: mm.out | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     SparseCPU, SparseCUDA: _sparse_mm | ||||
|     SparseCPU, SparseCUDA, SparseMPS: _sparse_mm | ||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm | ||||
|   tags: core | ||||
|  | ||||
| @ -5858,6 +5892,7 @@ | ||||
|     SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sum_coo | ||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr | ||||
|   autogen: sum.out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||
|   # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype | ||||
| @ -5868,11 +5903,12 @@ | ||||
|     NestedTensorCPU: NestedTensor_sum_dim_CPU | ||||
|     SparseCPU, SparseCUDA, SparseMPS: sum_sparse_coo | ||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
| @ -5880,9 +5916,11 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: sum_out | ||||
|     MPS: sum_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| # TODO: this function will be replaced once nested expand semantics have been settled on | ||||
| - func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor | ||||
| @ -5894,11 +5932,13 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: nansum | ||||
|     MPS: nansum_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   dispatch: | ||||
|     CPU, CUDA: nansum_out | ||||
|     MPS: nansum_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor | ||||
|   variants: function, method | ||||
| @ -5962,11 +6002,13 @@ | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -5975,16 +6017,19 @@ | ||||
|     CPU, CUDA: std | ||||
|     MPS: std_mps | ||||
|     QuantizedCPU: std_quantized_cpu | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -5993,42 +6038,51 @@ | ||||
|     CPU, CUDA: std_mean | ||||
|     MPS: std_mean_mps | ||||
|   autogen: std_mean.correction_out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   dispatch: | ||||
|     CPU, CUDA: std_out | ||||
|     QuantizedCPU: std_out_quantized_cpu | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   tags: reduction | ||||
|  | ||||
| - func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -6037,13 +6091,13 @@ | ||||
|     CPU, CUDA: prod | ||||
|     MPS: prod_mps | ||||
|   autogen: prod.out | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||
|   structured_delegate: prod.int_out | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
| @ -6051,13 +6105,16 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: prod_out | ||||
|     MPS: prod_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: t(Tensor(a) self) -> Tensor(a) | ||||
|   device_check: NoCheck | ||||
| @ -6518,11 +6575,12 @@ | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|  | ||||
| - func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor | ||||
| @ -6532,43 +6590,51 @@ | ||||
|     CPU, CUDA: var | ||||
|     MPS: var_mps | ||||
|     MTIA: var_mtia | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   dispatch: | ||||
|     CPU, CUDA: var_out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -6577,15 +6643,18 @@ | ||||
|     CPU, CUDA: var_mean | ||||
|     MPS: var_mean_mps | ||||
|   autogen: var_mean.correction_out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   cpp_no_default_args: ["unbiased"] | ||||
|   tags: reduction | ||||
|  | ||||
| - func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function | ||||
|   tags: reduction | ||||
|  | ||||
| - func: view_as(Tensor(a) self, Tensor other) -> Tensor(a) | ||||
|   variants: method | ||||
| @ -6845,6 +6914,7 @@ | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: norm | ||||
|   autogen: norm.ScalarOpt_dtype_out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -6852,6 +6922,7 @@ | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: norm | ||||
|   autogen: norm.Scalar_out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor | ||||
|   structured_delegate: norm.dtype_out | ||||
| @ -6859,6 +6930,7 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     SparseCPU, SparseCUDA, SparseMPS: sparse_dtype_norm | ||||
|   tags: reduction | ||||
|  | ||||
| - func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor | ||||
|   structured_delegate: norm.out | ||||
| @ -6866,6 +6938,7 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     SparseCPU, SparseCUDA, SparseMPS: sparse_norm | ||||
|   tags: reduction | ||||
|  | ||||
| - func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
| @ -6873,6 +6946,7 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: norm_dtype_out | ||||
|     MPS: norm_dtype_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   structured: True | ||||
| @ -6880,21 +6954,26 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: norm_out | ||||
|     MPS: norm_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| # These four redispatch in their implementation, so OK to be CompositeImplicitAutograd | ||||
| - func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   variants: function, method | ||||
|   tags: reduction | ||||
|  | ||||
| - func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   tags: reduction | ||||
|  | ||||
| - func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) | ||||
|   variants: method, function | ||||
| @ -7112,6 +7191,7 @@ | ||||
|     MTIA: addmm_out_mtia | ||||
|     SparseCPU: addmm_out_sparse_dense_cpu | ||||
|     SparseCUDA: addmm_out_sparse_dense_cuda | ||||
|     SparseMPS: addmm_out_sparse_dense_mps | ||||
|     SparseCsrCPU: addmm_out_sparse_compressed_cpu | ||||
|     SparseCsrCUDA: addmm_out_sparse_compressed_cuda | ||||
|  | ||||
| @ -7121,6 +7201,7 @@ | ||||
|   dispatch: | ||||
|     SparseCPU: addmm_sparse_dense_cpu | ||||
|     SparseCUDA: addmm_sparse_dense_cuda | ||||
|     SparseMPS: addmm_sparse_dense_mps | ||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense | ||||
|   tags: core | ||||
|  | ||||
| @ -10078,12 +10159,14 @@ | ||||
|     CPU, CUDA: min | ||||
|     MPS: min_mps | ||||
|     QuantizedCPU: min_quantized_cpu | ||||
|   tags: [reduction] | ||||
|  | ||||
| - func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   dispatch: | ||||
|     CPU, CUDA: min_unary_out | ||||
|     QuantizedCPU: min_quantized_unary_out | ||||
|   tags: [reduction] | ||||
|  | ||||
| - func: fmin(Tensor self, Tensor other) -> Tensor | ||||
|   structured_delegate: fmin.out | ||||
| @ -10106,6 +10189,7 @@ | ||||
|     CPU, CUDA: max | ||||
|     MPS: max_mps | ||||
|     QuantizedCPU: max_quantized_cpu | ||||
|   tags: [reduction] | ||||
|  | ||||
| - func: fmax(Tensor self, Tensor other) -> Tensor | ||||
|   structured_delegate: fmax.out | ||||
| @ -10152,6 +10236,7 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: max_unary_out | ||||
|     QuantizedCPU: max_quantized_unary_out | ||||
|   tags: [reduction] | ||||
|  | ||||
| - func: minimum(Tensor self, Tensor other) -> Tensor | ||||
|   structured_delegate: minimum.out | ||||
| @ -10271,6 +10356,7 @@ | ||||
|   device_check: NoCheck   # TensorIterator | ||||
|   structured_delegate: all.all_out | ||||
|   variants: method, function | ||||
|   tags: reduction | ||||
|  | ||||
| - func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck | ||||
| @ -10279,6 +10365,7 @@ | ||||
|     CPU, CUDA: all_all_out | ||||
|     MTIA: all_all_out_mtia | ||||
|     MPS: all_all_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: any(Tensor self) -> Tensor | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -10286,7 +10373,7 @@ | ||||
|   variants: method, function | ||||
|   dispatch: | ||||
|     SparseCPU, SparseCUDA, SparseMPS: any_sparse | ||||
|   tags: core | ||||
|   tags: [core, reduction] | ||||
|  | ||||
| - func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck | ||||
| @ -10294,6 +10381,7 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: any_all_out | ||||
|     MPS: any_all_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   device_check: NoCheck   # TensorIterator | ||||
| @ -14069,16 +14157,10 @@ | ||||
| - func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) | ||||
|   python_module: linalg | ||||
|   variants: function | ||||
|   dispatch: | ||||
|     CompositeImplicitAutograd: linalg_lu_factor | ||||
|     MPS: linalg_lu_factor_mps | ||||
|  | ||||
| - func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) | ||||
|   python_module: linalg | ||||
|   variants: function | ||||
|   dispatch: | ||||
|     CompositeImplicitAutograd: linalg_lu_factor_out | ||||
|     MPS: linalg_lu_factor_out_mps | ||||
|  | ||||
| - func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) | ||||
|   python_module: linalg | ||||
| @ -14345,6 +14427,7 @@ | ||||
|   python_module: linalg | ||||
|   variants: function | ||||
|   structured_delegate: linalg_vector_norm.out | ||||
|   tags: reduction | ||||
|  | ||||
| - func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||
|   python_module: linalg | ||||
| @ -14352,6 +14435,7 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: linalg_vector_norm_out | ||||
|     MPS: linalg_vector_norm_out_mps | ||||
|   tags: reduction | ||||
|  | ||||
| - func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||
|   python_module: linalg | ||||
|  | ||||
| @ -40,15 +40,7 @@ | ||||
| #include <thrust/iterator/discard_iterator.h> | ||||
|  | ||||
|  | ||||
| #if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300)) | ||||
| #define IS_CUSPARSE11_AVAILABLE() 1 | ||||
| #else | ||||
| #define IS_CUSPARSE11_AVAILABLE() 0 | ||||
| #endif | ||||
|  | ||||
| #if IS_CUSPARSE11_AVAILABLE() | ||||
| #include <library_types.h> | ||||
| #endif | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| @ -103,17 +95,9 @@ struct csrMatrixRef { | ||||
|   int nnz_{0}; | ||||
|   std::vector<int> size_{}; | ||||
|  | ||||
|   #if IS_CUSPARSE11_AVAILABLE() | ||||
|     cusparseSpMatDescr_t description_{0}; | ||||
|   #else | ||||
|     cusparseMatDescr_t description_{0}; | ||||
|   #endif | ||||
|   cusparseSpMatDescr_t description_{0}; | ||||
|  | ||||
|   csrMatrixRef() { | ||||
|     #if !IS_CUSPARSE11_AVAILABLE() | ||||
|       create_general_description_(description_); | ||||
|     #endif | ||||
|   } | ||||
|   csrMatrixRef() = default; | ||||
|  | ||||
|   csrMatrixRef( | ||||
|       int* csr_indices, | ||||
| @ -126,7 +110,6 @@ struct csrMatrixRef { | ||||
|         csr_values_{csr_values}, | ||||
|         nnz_{nnz}, | ||||
|         size_{size} { | ||||
|     #if IS_CUSPARSE11_AVAILABLE() | ||||
|       cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>(); | ||||
|       TORCH_CUDASPARSE_CHECK(cusparseCreateCsr( | ||||
|         &description_, | ||||
| @ -140,17 +123,10 @@ struct csrMatrixRef { | ||||
|         CUSPARSE_INDEX_32I, | ||||
|         CUSPARSE_INDEX_BASE_ZERO, | ||||
|         cuda_data_type)); | ||||
|     #else | ||||
|       create_general_description_(description_); | ||||
|     #endif | ||||
|   } | ||||
|  | ||||
|   ~csrMatrixRef() { | ||||
|     #if IS_CUSPARSE11_AVAILABLE() | ||||
|       cusparseDestroySpMat(description_); | ||||
|     #else | ||||
|       cusparseDestroyMatDescr(description_); | ||||
|     #endif | ||||
|     cusparseDestroySpMat(description_); | ||||
|   } | ||||
|  | ||||
|   int size(int index) const { | ||||
| @ -196,8 +172,6 @@ struct csrOutput { | ||||
|   } | ||||
| }; | ||||
|  | ||||
| #if IS_CUSPARSE11_AVAILABLE() | ||||
|  | ||||
| // RAII guard helps to support cuSparse 11 API for `A @ B` operation | ||||
| // This generic template exists because with cuSparse the `scalar_t` type could be a double or float | ||||
| template <class scalar_t> | ||||
| @ -396,284 +370,6 @@ template struct CusparseMatrixMultiplyOp<float>; | ||||
|  | ||||
| template struct CusparseMatrixMultiplyOp<double>; | ||||
|  | ||||
| #else // if not IS_CUSPARSE11_AVAILABLE() | ||||
|  | ||||
| using DcsrMatrixRef = csrMatrixRef<double>; | ||||
| using ScsrMatrixRef = csrMatrixRef<float>; | ||||
|  | ||||
| // RAII guard helps to support cuSparse 10 API for `A @ B` operation | ||||
| // This generic template exists because with cuSparse the `scalar_t` type could be a double or float | ||||
| template <class scalar_t> | ||||
| struct CusparseMatrixMultiplyOp { | ||||
|   csrOutput operator()( | ||||
|       const csrMatrixRef<scalar_t>& lhs, | ||||
|       const csrMatrixRef<scalar_t>& rhs, | ||||
|       Tensor &output_values, | ||||
|       Tensor &output_indices) | ||||
|   { | ||||
|     static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double."); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Specializacion for `A @ B` operation for double values with cuSparse | ||||
| template<> struct CusparseMatrixMultiplyOp<double> { | ||||
|   csrgemm2Info_t gemm2Info_; | ||||
|  | ||||
|   CusparseMatrixMultiplyOp() { | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_)); | ||||
|   } | ||||
|   ~CusparseMatrixMultiplyOp() { | ||||
|     cusparseDestroyCsrgemm2Info(gemm2Info_); | ||||
|   } | ||||
|  | ||||
|   csrOutput operator ()( | ||||
|       const DcsrMatrixRef& lhs, | ||||
|       const DcsrMatrixRef& rhs, | ||||
|       Tensor &output_values, | ||||
|       Tensor &output_indices) { | ||||
|     double alpha = 1.0; | ||||
|     DcsrMatrixRef empty; | ||||
|     return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices); | ||||
|   } | ||||
|  | ||||
|   csrOutput Dgemm2( | ||||
|       const DcsrMatrixRef& A, | ||||
|       const DcsrMatrixRef& B, | ||||
|       const DcsrMatrixRef& C, | ||||
|       const double* alpha, | ||||
|       const double* beta, | ||||
|       Tensor &output_values, | ||||
|       Tensor &output_indices) { | ||||
|     void* buffer_{nullptr}; | ||||
|     cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle(); | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST)); | ||||
|  | ||||
|     csrOutput out({A.size(0), B.size(1)}); | ||||
|     int innerSize = confirm_mult_size(A.size_, B.size_); | ||||
|     out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt)); | ||||
|  | ||||
|     // Compute needed buffer size | ||||
|     size_t new_bubber_sz; | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt( | ||||
|         cusparseHandle_, | ||||
|         out.size(0), | ||||
|         out.size(1), | ||||
|         innerSize, | ||||
|         alpha, | ||||
|         A.description_, | ||||
|         A.nnz_, | ||||
|         A.csr_pointers_, | ||||
|         A.csr_indices_, | ||||
|         B.description_, | ||||
|         B.nnz_, | ||||
|         B.csr_pointers_, | ||||
|         B.csr_indices_, | ||||
|         beta, | ||||
|         C.description_, | ||||
|         C.nnz_, | ||||
|         C.csr_pointers_, | ||||
|         C.csr_indices_, | ||||
|         gemm2Info_, | ||||
|         &new_bubber_sz)); | ||||
|  | ||||
|     // (Re)allocate buffer if needed | ||||
|     auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); | ||||
|     at::DataPtr data_ptr = allocator.allocate(new_bubber_sz); | ||||
|     buffer_ = data_ptr.get(); | ||||
|  | ||||
|     // Find the resulting non-zero pattern. | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz( | ||||
|         cusparseHandle_, | ||||
|         out.size(0), | ||||
|         out.size(1), | ||||
|         innerSize, | ||||
|         A.description_, | ||||
|         A.nnz_, | ||||
|         A.csr_pointers_, | ||||
|         A.csr_indices_, | ||||
|         B.description_, | ||||
|         B.nnz_, | ||||
|         B.csr_pointers_, | ||||
|         B.csr_indices_, | ||||
|         C.description_, | ||||
|         C.nnz_, | ||||
|         C.csr_pointers_, | ||||
|         C.csr_indices_, | ||||
|         out.description_, | ||||
|         out.csr_pointers_.data_ptr<int>(), | ||||
|         &out.nnz_, | ||||
|         gemm2Info_, | ||||
|         buffer_)); | ||||
|  | ||||
|     out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt)); | ||||
|     out.csr_values_ = at::empty({out.nnz_}, output_values.options()); | ||||
|  | ||||
|     // Perform the gemm2 operation for doubles | ||||
|     // out = alpha ∗ A ∗ B + beta ∗ C | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2( | ||||
|         cusparseHandle_, | ||||
|         out.size(0), | ||||
|         out.size(1), | ||||
|         innerSize, | ||||
|         alpha, | ||||
|         A.description_, | ||||
|         A.nnz_, | ||||
|         A.csr_values_, | ||||
|         A.csr_pointers_, | ||||
|         A.csr_indices_, | ||||
|         B.description_, | ||||
|         B.nnz_, | ||||
|         B.csr_values_, | ||||
|         B.csr_pointers_, | ||||
|         B.csr_indices_, | ||||
|         beta, | ||||
|         C.description_, | ||||
|         C.nnz_, | ||||
|         C.csr_values_, | ||||
|         C.csr_pointers_, | ||||
|         C.csr_indices_, | ||||
|         out.description_, | ||||
|         out.csr_values_.data_ptr<double>(), | ||||
|         out.csr_pointers_.data_ptr<int>(), | ||||
|         out.csr_indices_.data_ptr<int>(), | ||||
|         gemm2Info_, | ||||
|         buffer_)); | ||||
|     return out; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Specializacion for `A @ B` operation for float values with cuSparse | ||||
| template<> struct CusparseMatrixMultiplyOp<float> { | ||||
|   csrgemm2Info_t gemm2Info_; | ||||
|  | ||||
|   CusparseMatrixMultiplyOp() { | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_)); | ||||
|  | ||||
|   } | ||||
|   ~CusparseMatrixMultiplyOp() { | ||||
|     cusparseDestroyCsrgemm2Info(gemm2Info_); | ||||
|   } | ||||
|   csrOutput operator()( | ||||
|       const ScsrMatrixRef& lhs, | ||||
|       const ScsrMatrixRef& rhs, | ||||
|       Tensor &output_values, | ||||
|       Tensor &output_indices) { | ||||
|     float alpha = 1.0; | ||||
|     ScsrMatrixRef empty; | ||||
|     return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices); | ||||
|   } | ||||
|  | ||||
|   csrOutput Sgemm2( | ||||
|       const ScsrMatrixRef& A, | ||||
|       const ScsrMatrixRef& B, | ||||
|       const ScsrMatrixRef& C, | ||||
|       const float* alpha, | ||||
|       const float* beta, | ||||
|       Tensor &output_values, | ||||
|       Tensor &output_indices) { | ||||
|     void* buffer_{nullptr}; | ||||
|     cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle(); | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST)); | ||||
|  | ||||
|     csrOutput out({A.size(0), B.size(1)}); | ||||
|  | ||||
|     int innerSize = confirm_mult_size(A.size_, B.size_); | ||||
|  | ||||
|     out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt)); | ||||
|  | ||||
|     // Compute needed buffer size | ||||
|     size_t new_bubber_sz; | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt( | ||||
|         cusparseHandle_, | ||||
|         out.size(0), | ||||
|         out.size(1), | ||||
|         innerSize, | ||||
|         alpha, | ||||
|         A.description_, | ||||
|         A.nnz_, | ||||
|         A.csr_pointers_, | ||||
|         A.csr_indices_, | ||||
|         B.description_, | ||||
|         B.nnz_, | ||||
|         B.csr_pointers_, | ||||
|         B.csr_indices_, | ||||
|         beta, | ||||
|         C.description_, | ||||
|         C.nnz_, | ||||
|         C.csr_pointers_, | ||||
|         C.csr_indices_, | ||||
|         gemm2Info_, | ||||
|         &new_bubber_sz)); | ||||
|  | ||||
|     auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); | ||||
|     at::DataPtr data_ptr = allocator.allocate(new_bubber_sz); | ||||
|     buffer_ = data_ptr.get(); | ||||
|  | ||||
|     // Find the resulting non-zero pattern. | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz( | ||||
|         cusparseHandle_, | ||||
|         out.size(0), | ||||
|         out.size(1), | ||||
|         innerSize, | ||||
|         A.description_, | ||||
|         A.nnz_, | ||||
|         A.csr_pointers_, | ||||
|         A.csr_indices_, | ||||
|         B.description_, | ||||
|         B.nnz_, | ||||
|         B.csr_pointers_, | ||||
|         B.csr_indices_, | ||||
|         C.description_, | ||||
|         C.nnz_, | ||||
|         C.csr_pointers_, | ||||
|         C.csr_indices_, | ||||
|         out.description_, | ||||
|         out.csr_pointers_.data_ptr<int>(), | ||||
|         &out.nnz_, | ||||
|         gemm2Info_, | ||||
|         buffer_)); | ||||
|  | ||||
|     out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt)); | ||||
|     out.csr_values_ = at::empty({out.nnz_}, output_values.options()); | ||||
|  | ||||
|     // Perform the gemm2 operation for doubles | ||||
|     // out = alpha ∗ A ∗ B + beta ∗ C | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2( | ||||
|         cusparseHandle_, | ||||
|         out.size(0), | ||||
|         out.size(1), | ||||
|         innerSize, | ||||
|         alpha, | ||||
|         A.description_, | ||||
|         A.nnz_, | ||||
|         A.csr_values_, | ||||
|         A.csr_pointers_, | ||||
|         A.csr_indices_, | ||||
|         B.description_, | ||||
|         B.nnz_, | ||||
|         B.csr_values_, | ||||
|         B.csr_pointers_, | ||||
|         B.csr_indices_, | ||||
|         beta, | ||||
|         C.description_, | ||||
|         C.nnz_, | ||||
|         C.csr_values_, | ||||
|         C.csr_pointers_, | ||||
|         C.csr_indices_, | ||||
|         out.description_, | ||||
|         out.csr_values_.data_ptr<float>(), | ||||
|         out.csr_pointers_.data_ptr<int>(), | ||||
|         out.csr_indices_.data_ptr<int>(), | ||||
|         gemm2Info_, | ||||
|         buffer_)); | ||||
|     return out; | ||||
|   } | ||||
| }; | ||||
|  | ||||
|  | ||||
|  | ||||
| #endif // IS_CUSPARSE11_AVAILABLE() | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void sparse_sparse_matmul_cuda_kernel( | ||||
|     Tensor& result, | ||||
| @ -815,19 +511,15 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) { | ||||
|   auto output = at::native::empty_like(mat1_); | ||||
|   output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0); | ||||
|  | ||||
| #if IS_CUSPARSE11_AVAILABLE() && !defined(USE_ROCM) | ||||
| #if !defined(USE_ROCM) | ||||
|   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] { | ||||
|       sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce()); | ||||
|   }); | ||||
| #elif IS_CUSPARSE11_AVAILABLE() && defined(USE_ROCM) | ||||
| #else | ||||
|   // ROCm does not support half and bfloat16 types for sparse_matmul | ||||
|   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] { | ||||
|       sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce()); | ||||
|   }); | ||||
| #else | ||||
|   AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] { | ||||
|     sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce()); | ||||
|   }); | ||||
| #endif | ||||
|   return output; | ||||
| } | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||
| #include <ATen/native/SparseTensorUtils.h> | ||||
| #include <ATen/ExpandUtils.h> | ||||
| #include <ATen/native/mps/OperationUtils.h> | ||||
| #include <ATen/native/sparse/SparseStubs.h> | ||||
| #include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h> | ||||
| @ -18,6 +19,8 @@ | ||||
| #include <ATen/ops/ones_like.h> | ||||
| #include <ATen/ops/argsort.h> | ||||
| #include <ATen/ops/result_type.h> | ||||
| #include <ATen/ops/bmm_native.h> | ||||
| #include <ATen/ops/addmm_native.h> | ||||
| #include <ATen/ops/copy_sparse_to_sparse.h> | ||||
| #include <ATen/ops/mul.h> | ||||
| #endif | ||||
| @ -33,6 +36,305 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary(); | ||||
| #include <ATen/native/mps/Mul_metallib.h> | ||||
| #endif | ||||
|  | ||||
| static Tensor& s_addmm_out_sparse_dense_mps( | ||||
|     Tensor& r, | ||||
|     const Tensor& t, | ||||
|     const SparseTensor& sparse_, | ||||
|     const Tensor& dense, | ||||
|     const Scalar& beta, | ||||
|     const Scalar& alpha) { | ||||
|   TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: sparse_dim must be 2, got ", sparse_.sparse_dim()); | ||||
|   TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: sparse values must be 0-dense-dim, got ", sparse_.dense_dim()); | ||||
|   TORCH_CHECK(dense.dim() == 2, "addmm: 'dense' must be 2D, got ", dense.dim()); | ||||
|   TORCH_CHECK(t.dim() == 2, "addmm: 't' must be 2D, got ", t.dim()); | ||||
|  | ||||
|   const int64_t I = sparse_.size(0); | ||||
|   const int64_t J = sparse_.size(1); | ||||
|   const int64_t K = dense.size(1); | ||||
|  | ||||
|   TORCH_CHECK(dense.size(0) == J, | ||||
|       "addmm: dense (mat2) dim0 must be ", J, ", got ", dense.size(0)); | ||||
|   TORCH_CHECK(t.size(0) == I && t.size(1) == K, | ||||
|       "addmm: 't' shape must be (", I, ", ", K, "), got (", t.size(0), ", ", t.size(1), ")"); | ||||
|  | ||||
|   r.resize_({I, K}); | ||||
|  | ||||
|   auto sparse = sparse_.coalesce(); | ||||
|   const int64_t nnz = sparse._nnz(); | ||||
|  | ||||
|   if (nnz == 0 || I == 0 || K == 0) { | ||||
|     at::mul_out(r, t, beta); | ||||
|     return r; | ||||
|   } | ||||
|  | ||||
|   const auto v_dtype = sparse._values().scalar_type(); | ||||
|   const auto d_dtype = dense.scalar_type(); | ||||
|   const auto t_dtype = t.scalar_type(); | ||||
|   auto compute_dtype = c10::promoteTypes(c10::promoteTypes(v_dtype, d_dtype), t_dtype); | ||||
|  | ||||
|   TORCH_CHECK(canCast(compute_dtype, r.scalar_type()), | ||||
|               "Can't convert computed type ", compute_dtype, " to output ", r.scalar_type()); | ||||
|  | ||||
|   auto indices2d = sparse._indices().contiguous(); | ||||
|   auto values = sparse._values().to(compute_dtype); | ||||
|   auto dense_c = dense.to(compute_dtype).contiguous(); | ||||
|   auto t_c = t.to(compute_dtype).contiguous(); | ||||
|  | ||||
|   const bool out_needs_cast = (r.scalar_type() != compute_dtype) || !r.is_contiguous(); | ||||
|   Tensor out_buf = out_needs_cast | ||||
|       ? at::empty({I, K}, r.options().dtype(compute_dtype)) | ||||
|       : r; | ||||
|   auto out_contig = out_buf.contiguous(); | ||||
|  | ||||
|   auto device = r.device(); | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|  | ||||
|   const float alpha_f = alpha.to<float>(); | ||||
|   const float beta_f  = beta.to<float>(); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       const std::string func = "spmm_addmm_coo_" + mps::scalarToMetalTypeString(values); | ||||
|       auto pso = lib.getPipelineStateForFunc(func); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|  | ||||
|       const uint32_t tew = pso.threadExecutionWidth; | ||||
|       const uint32_t gridX = static_cast<uint32_t>(K); | ||||
|       const uint32_t gridZ = static_cast<uint32_t>(I); | ||||
|       const uint32_t tgW = std::min<uint32_t>(gridX, tew); | ||||
|  | ||||
|       MTLSize grid = MTLSizeMake(gridX, 1, gridZ); | ||||
|       MTLSize tgs = MTLSizeMake(tgW, 1, 1); | ||||
|  | ||||
|       mtl_setArgs(enc, | ||||
|                   indices2d, | ||||
|                   values, | ||||
|                   dense_c, | ||||
|                   t_c, | ||||
|                   out_contig, | ||||
|                   std::array<uint32_t, 3>{static_cast<uint32_t>(I), | ||||
|                                            static_cast<uint32_t>(J), | ||||
|                                            static_cast<uint32_t>(K)}, | ||||
|                   std::array<float, 2>{alpha_f, beta_f}, | ||||
|                   static_cast<uint32_t>(nnz)); | ||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   if (out_needs_cast) { | ||||
|     r.copy_(out_contig.to(r.scalar_type())); | ||||
|   } | ||||
|  | ||||
|   return r; | ||||
| } | ||||
|  | ||||
|  | ||||
| static void build_batch_ptr_mps( | ||||
|     const Tensor& indices_dim0, | ||||
|     int64_t B, | ||||
|     Tensor& batch_ptr | ||||
| ) { | ||||
|   // Builds an array of pointers which point to each batches elements. Example: | ||||
|   // idx_b = [0, 0, 0, 1, 1, 2, 2, 2, 2]  // 9 non-zero elements | ||||
|   //          └─────┘  └──┘  └─────────┘ | ||||
|   //          batch 0  batch 1  batch 2 | ||||
|   // batch_ptr = [0, 3, 5, 9] | ||||
|   //              │  │  │  └─ end of batch 2 (total nnz) | ||||
|   //              │  │  └──── batch 2 starts at index 5 | ||||
|   //              │  └─────── batch 1 starts at index 3 | ||||
|   //              └────────── batch 0 starts at index 0 | ||||
|   TORCH_CHECK(indices_dim0.is_mps() && batch_ptr.is_mps(), "MPS device expected"); | ||||
|   auto device = indices_dim0.device(); | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|  | ||||
|   const int64_t nnz = indices_dim0.numel(); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       auto pso = lib.getPipelineStateForFunc("build_batch_ptr_from_sorted_batches"); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|  | ||||
|       const uint32_t tew = pso.threadExecutionWidth; | ||||
|       const uint32_t Q = static_cast<uint32_t>(B + 1); | ||||
|       const uint32_t tgW = std::min<uint32_t>(Q, tew); | ||||
|       MTLSize grid = MTLSizeMake(Q, 1, 1); | ||||
|       MTLSize tgs  = MTLSizeMake(tgW, 1, 1); | ||||
|  | ||||
|       mtl_setArgs(enc, | ||||
|                   indices_dim0, | ||||
|                   batch_ptr, | ||||
|                   std::array<uint32_t, 2>{static_cast<uint32_t>(nnz), | ||||
|                                           static_cast<uint32_t>(B)}); | ||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| static void build_row_ptr_per_batch_mps( | ||||
|     const Tensor& rows, | ||||
|     const Tensor& batch_ptr, | ||||
|     int64_t B, | ||||
|     int64_t I, | ||||
|     Tensor& row_ptr | ||||
| ) { | ||||
|   // Build per-batch CSR-style row pointer arrays from row indices sorted by batch | ||||
|   // Given: | ||||
|   //   rows: 1-D array of length nnz with row ids in [0, I), sorted within each batch | ||||
|   //   batch_ptr: length B+1, where [batch_ptr[b], batch_ptr[b+1]) is the subrange for batch b | ||||
|   // Produces: | ||||
|   //   - row_ptr: shape [B, I+1] | ||||
|   // | ||||
|   // Example (B = 2, I = 4): | ||||
|   // rows       = [0,   0,   1,  3,  0,   2,    2]   // 7 non-zero elements | ||||
|   //               └─── batch 0 ──┘  └─ batch 1 ─┘ | ||||
|   // batch_ptr  = [0, 4, 7] | ||||
|   //               │  │  └─ end of batch 1 (total nnz) | ||||
|   //               │  └──── end of batch 0/start of batch 1 | ||||
|   //               └─────── start of batch 0 | ||||
|   // | ||||
|   // per-batch row pointers (I+1 entries each): | ||||
|   //   row_ptr[0] = [0, 2, 3, 3, 4] | ||||
|   //   row_ptr[1] = [0, 1, 1, 3, 3] | ||||
|   // laid out in memory: [0, 2, 3, 3, 4,  0, 1, 1, 3, 3] | ||||
|   TORCH_CHECK(rows.is_mps() && batch_ptr.is_mps() && row_ptr.is_mps(), "MPS device expected"); | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       auto pso = lib.getPipelineStateForFunc("build_row_ptr_from_sorted_rows_by_batch"); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|  | ||||
|       const uint32_t tew = pso.threadExecutionWidth; | ||||
|       const uint32_t Qx = static_cast<uint32_t>(I + 1); | ||||
|       const uint32_t Qy = static_cast<uint32_t>(B); | ||||
|       const uint32_t tgW = std::min<uint32_t>(Qx, tew); | ||||
|  | ||||
|       MTLSize grid = MTLSizeMake(Qx, Qy, 1); | ||||
|       MTLSize tgs = MTLSizeMake(tgW, 1, 1); | ||||
|  | ||||
|       mtl_setArgs(enc, | ||||
|                   rows, | ||||
|                   batch_ptr, | ||||
|                   row_ptr, | ||||
|                   std::array<uint32_t, 2>{static_cast<uint32_t>(I), | ||||
|                                            static_cast<uint32_t>(B)}); | ||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| Tensor& bmm_out_sparse_mps(const SparseTensor& self_, const Tensor& mat2_, Tensor& result_) { | ||||
|   TORCH_CHECK(result_.is_mps(), "bmm_sparse: expected 'out' to be MPS, got ", result_.device()); | ||||
|   TORCH_CHECK(self_.is_mps(),  "bmm_sparse: expected 'self' to be MPS, got ", self_.device()); | ||||
|   TORCH_CHECK(mat2_.is_mps(),  "bmm_sparse: expected 'mat2' to be MPS, got ", mat2_.device()); | ||||
|  | ||||
|   TORCH_CHECK(self_.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self_.dense_dim()); | ||||
|   TORCH_CHECK(self_.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self_.sparse_dim()); | ||||
|   TORCH_CHECK(mat2_.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2_.dim()); | ||||
|  | ||||
|   TORCH_CHECK(self_.size(0) == mat2_.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match"); | ||||
|   TORCH_CHECK(self_.size(2) == mat2_.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match"); | ||||
|  | ||||
|   const int64_t B = self_.size(0); | ||||
|   const int64_t I = self_.size(1); | ||||
|   const int64_t J = self_.size(2); | ||||
|   const int64_t K = mat2_.size(2); | ||||
|  | ||||
|   auto self = self_.coalesce(); | ||||
|   const int64_t nnz = self._nnz(); | ||||
|   if (nnz == 0) { | ||||
|     return result_.zero_(); | ||||
|   } | ||||
|  | ||||
|   const auto computeDtype = at::kFloat; | ||||
|  | ||||
|   auto indices = self._indices(); | ||||
|   auto values  = self._values(); | ||||
|  | ||||
|   auto values_c = values.scalar_type() == computeDtype ? values : values.to(computeDtype); | ||||
|   auto mat2_c = mat2_.scalar_type()   == computeDtype ? mat2_   : mat2_.to(computeDtype); | ||||
|   auto mat2_contig = mat2_c.contiguous(); | ||||
|  | ||||
|   auto idx_b = indices.select(0, 0).contiguous(); | ||||
|   auto idx_i = indices.select(0, 1).contiguous(); | ||||
|   auto idx_j = indices.select(0, 2).contiguous(); | ||||
|  | ||||
|   // builds an array of pointers of where the batch_idx's pointer starts and ends | ||||
|   // look in function for better explanation | ||||
|   auto batch_ptr = at::empty({B + 1}, at::device(result_.device()).dtype(kLong)); | ||||
|   build_batch_ptr_mps(idx_b, B, batch_ptr); | ||||
|   // build row_ptr per batch: for each (b, i) get [start, end) into rows/cols/vals | ||||
|   auto row_ptr = at::empty({B * (I + 1)}, at::device(result_.device()).dtype(kLong)); | ||||
|   build_row_ptr_per_batch_mps(idx_i, batch_ptr, B, I, row_ptr); | ||||
|  | ||||
|   const bool out_needs_cast = (result_.scalar_type() != computeDtype) || !result_.is_contiguous(); | ||||
|   Tensor out_buf = out_needs_cast | ||||
|       ? at::empty({B, I, K}, result_.options().dtype(computeDtype)) | ||||
|       : result_; | ||||
|   auto out_contig = out_buf.contiguous(); | ||||
|  | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       auto pso = lib.getPipelineStateForFunc("spmm_bmm_coo_rows_grouped_" + mps::scalarToMetalTypeString(values)); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|  | ||||
|       const uint32_t tew = pso.threadExecutionWidth; | ||||
|       const uint32_t tgW = std::min<uint32_t>((uint32_t)K, tew); | ||||
|  | ||||
|       // One threadgroup per (row i, batch b), lanes cover K | ||||
|       MTLSize grid = MTLSizeMake(tgW, (uint32_t)I, (uint32_t)B); | ||||
|       MTLSize tgs  = MTLSizeMake(tgW, 1, 1); | ||||
|  | ||||
|       mtl_setArgs(enc, | ||||
|                   idx_i, | ||||
|                   idx_j, | ||||
|                   values_c, | ||||
|                   mat2_contig, | ||||
|                   out_contig, | ||||
|                   row_ptr, | ||||
|                   std::array<uint32_t, 4>{(uint32_t)B, (uint32_t)I, (uint32_t)J, (uint32_t)K}); | ||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||
|     } | ||||
|   }); | ||||
|   if (out_needs_cast) { | ||||
|     result_.copy_(out_contig.to(result_.scalar_type())); | ||||
|   } | ||||
|   return result_; | ||||
| } | ||||
|  | ||||
| Tensor bmm_sparse_mps(const Tensor& self, const Tensor& mat2) { | ||||
|   Tensor result = at::zeros({self.size(0), self.size(1), mat2.size(2)}, mat2.options()); | ||||
|   return bmm_out_sparse_mps(self, mat2, result); | ||||
| } | ||||
|  | ||||
| Tensor& addmm_out_sparse_dense_mps( | ||||
|     const Tensor& self, | ||||
|     const SparseTensor& mat1, | ||||
|     const Tensor& mat2, | ||||
|     const Scalar& beta, | ||||
|     const Scalar& alpha, | ||||
|     Tensor& result) { | ||||
|   c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); | ||||
|   return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha); | ||||
| } | ||||
|  | ||||
| Tensor addmm_sparse_dense_mps( | ||||
|     const Tensor& self, | ||||
|     const SparseTensor& mat1, | ||||
|     const Tensor& mat2, | ||||
|     const Scalar& beta, | ||||
|     const Scalar& alpha | ||||
| ) { | ||||
|   c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); | ||||
|   Tensor result = at::empty({0}, self.options()); | ||||
|   return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha); | ||||
| } | ||||
|  | ||||
| static SparseTensor& mul_out_dense_sparse_mps( | ||||
|     const Tensor& dense, | ||||
|     const Tensor& sparse, | ||||
|  | ||||
| @ -1,10 +1,103 @@ | ||||
| #include <metal_stdlib> | ||||
| #include <c10/metal/indexing.h> | ||||
| #include <c10/metal/utils.h> | ||||
| using namespace c10::metal; | ||||
| using namespace metal; | ||||
|  | ||||
| inline uint lower_bound_i64(device const long* arr, uint lo, uint hi, long key) { | ||||
|   uint l = lo, r = hi; | ||||
|   while (l < r) { | ||||
|     uint m = (l + r) >> 1; | ||||
|     long v = arr[m]; | ||||
|     if (v < key) { | ||||
|       l = m + 1; | ||||
|     } else { | ||||
|       r = m; | ||||
|     } | ||||
|   } | ||||
|   return l; | ||||
| } | ||||
|  | ||||
| template <typename T> struct MulAccum { using type = float; }; | ||||
| template <> struct MulAccum<float2> { using type = float2; }; | ||||
| inline uint upper_bound_i64(device const long* arr, uint lo, uint hi, long key) { | ||||
|   uint l = lo, r = hi; | ||||
|   while (l < r) { | ||||
|     uint m = (l + r) >> 1; | ||||
|     long v = arr[m]; | ||||
|     if (v <= key) { | ||||
|       l = m + 1; | ||||
|     } else { | ||||
|       r = m; | ||||
|     } | ||||
|   } | ||||
|   return l; | ||||
| } | ||||
|  | ||||
| kernel void build_row_ptr_from_sorted_rows_by_batch( | ||||
|     device const long* rows        [[buffer(0)]], | ||||
|     device const long* batch_ptr   [[buffer(1)]], | ||||
|     device long*       row_ptr     [[buffer(2)]], | ||||
|     constant uint2&    dims        [[buffer(3)]], | ||||
|     uint3              tid         [[thread_position_in_grid]]) | ||||
| { | ||||
|   const uint I = dims.x; | ||||
|   const uint B = dims.y; | ||||
|  | ||||
|   const uint i = tid.x; | ||||
|   const uint b = tid.y; | ||||
|  | ||||
|   if (b >= B || i > I) return; | ||||
|  | ||||
|   const uint base = (uint)batch_ptr[b]; | ||||
|   const uint lim  = (uint)batch_ptr[b + 1]; | ||||
|  | ||||
|   const ulong out_base = (ulong)b * (ulong)(I + 1); | ||||
|  | ||||
|   if (i == I) { | ||||
|     row_ptr[out_base + (ulong)I] = (long)lim; | ||||
|   } else { | ||||
|     const long key = (long)i; | ||||
|     const uint pos = lower_bound_i64(rows, base, lim, key); | ||||
|     row_ptr[out_base + (ulong)i] = (long)pos; | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| kernel void spmm_bmm_coo_rows_grouped( | ||||
|     device const long*   cols      [[buffer(1)]], | ||||
|     device const T*      vals      [[buffer(2)]], | ||||
|     device const T*      dense     [[buffer(3)]], | ||||
|     device T*            out       [[buffer(4)]], | ||||
|     device const long*   row_ptr   [[buffer(5)]], | ||||
|     constant uint4&      dims      [[buffer(6)]], | ||||
|     uint3                tid       [[thread_position_in_grid]], | ||||
|     uint3                ltid      [[thread_position_in_threadgroup]], | ||||
|     uint3                tptg      [[threads_per_threadgroup]]) | ||||
| { | ||||
|   const uint I = dims.y; | ||||
|   const uint J = dims.z; | ||||
|   const uint K = dims.w; | ||||
|  | ||||
|   const uint b = tid.z; | ||||
|   const uint i = tid.y; | ||||
|   const uint lane = ltid.x; | ||||
|   const uint tgW  = tptg.x; | ||||
|  | ||||
|   const ulong rp_base = (ulong)b * (ulong)(I + 1); | ||||
|   const uint start = (uint)row_ptr[rp_base + (ulong)i]; | ||||
|   const uint end   = (uint)row_ptr[rp_base + (ulong)i + 1]; | ||||
|  | ||||
|   for (uint k = lane; k < K; k += tgW) { | ||||
|     auto acc = static_cast<accum_t<T>>(T(0)); | ||||
|     for (uint p = start; p < end; ++p) { | ||||
|       const uint c = (uint)cols[p]; | ||||
|       const auto v = static_cast<accum_t<T>>(vals[p]); | ||||
|       const uint d_off = ((b * J) + c) * K + k; | ||||
|       const auto d = static_cast<accum_t<T>>(dense[d_off]); | ||||
|       acc += mul(v, d); | ||||
|     } | ||||
|     const uint y_off = ((b * I) + i) * K + k; | ||||
|     out[y_off] = static_cast<T>(acc); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| kernel void dense_sparse_mul_kernel( | ||||
| @ -32,10 +125,9 @@ kernel void dense_sparse_mul_kernel( | ||||
|   ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col; | ||||
|   ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col; | ||||
|  | ||||
|   using accum_t = typename MulAccum<T>::type; | ||||
|   const accum_t a = static_cast<accum_t>(values[val_idx]); | ||||
|   const accum_t b = static_cast<accum_t>(dense[dense_idx]); | ||||
|   out_values[val_idx] = static_cast<T>(a * b); | ||||
|   const auto a = static_cast<accum_t<T>>(values[val_idx]); | ||||
|   const auto b = static_cast<accum_t<T>>(dense[dense_idx]); | ||||
|   out_values[val_idx] = static_cast<T>(mul(a, b)); | ||||
| } | ||||
|  | ||||
| kernel void intersect_binary_search( | ||||
| @ -120,6 +212,76 @@ kernel void fused_gather_mul_kernel( | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| kernel void build_batch_ptr_from_sorted_batches( | ||||
|     device const long* batches       [[buffer(0)]], | ||||
|     device long*       batch_ptr     [[buffer(1)]], | ||||
|     constant uint2&    nnz_B         [[buffer(2)]], | ||||
|     uint3              tid           [[thread_position_in_grid]]) | ||||
| { | ||||
|   uint b = tid.x; | ||||
|   uint nnz = nnz_B.x; | ||||
|   uint batch = nnz_B.y; | ||||
|  | ||||
|   if (b == batch) { | ||||
|     batch_ptr[b] = (long)nnz; | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   uint lo = 0; | ||||
|   uint hi = nnz; | ||||
|   long key = (long)b; | ||||
|   while (lo < hi) { | ||||
|     uint mid = (lo + hi) >> 1; | ||||
|     long v = batches[mid]; | ||||
|     if (v < key) lo = mid + 1; | ||||
|     else         hi = mid; | ||||
|   } | ||||
|   batch_ptr[b] = (long)lo; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| kernel void spmm_addmm_coo( | ||||
|     device const long*   indices2d   [[buffer(0)]], | ||||
|     device const T*      vals        [[buffer(1)]], | ||||
|     device const T*      dense       [[buffer(2)]], | ||||
|     device const T*      t_in        [[buffer(3)]], | ||||
|     device T*            out         [[buffer(4)]], | ||||
|     constant uint3&      dims        [[buffer(5)]], | ||||
|     constant float2&     alpha_beta  [[buffer(6)]], | ||||
|     constant uint&       nnz         [[buffer(7)]], | ||||
|     uint3                tid         [[thread_position_in_grid]]) | ||||
| { | ||||
|   const uint K = dims.z; | ||||
|   const uint k = tid.x; | ||||
|   const uint i = tid.z; | ||||
|   const float alpha = alpha_beta.x; | ||||
|   const float beta = alpha_beta.y; | ||||
|  | ||||
|   device const long* rows = indices2d; | ||||
|   device const long* cols = indices2d + nnz; | ||||
|  | ||||
|   const uint start = lower_bound_i64(rows, 0u, nnz, (long)i); | ||||
|   const uint end = upper_bound_i64(rows, 0u, nnz, (long)i); | ||||
|  | ||||
|   // accumulator is float for scalar/half/bfloat and float2 for float2 | ||||
|   auto acc = static_cast<accum_t<T>>(T(0)); | ||||
|  | ||||
|   for (uint p = start; p < end; ++p) { | ||||
|     const uint c = (uint)cols[p]; | ||||
|     const auto v = static_cast<accum_t<T>>(vals[p]); | ||||
|     const uint dense_off = c * K + k; | ||||
|     const auto d = static_cast<accum_t<T>>(dense[dense_off]); | ||||
|     acc += mul(v, d); | ||||
|   } | ||||
|  | ||||
|   const uint off = i * K + k; | ||||
|   const auto base = (beta != 0.0f) ? (static_cast<accum_t<T>>(t_in[off]) * beta) : static_cast<accum_t<T>>(T(0)); | ||||
|   const auto y = base + alpha * acc; | ||||
|   out[off] = static_cast<T>(y); | ||||
| } | ||||
|  | ||||
|  | ||||
| #define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE)                                 \ | ||||
|   template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void     \ | ||||
|   dense_sparse_mul_kernel<DTYPE>(                                           \ | ||||
| @ -151,6 +313,35 @@ INSTANTIATE_DENSE_SPARSE_MUL(float2); | ||||
|       constant uint2&     dims_output   [[buffer(8)]],                       \ | ||||
|       uint3               gid           [[thread_position_in_grid]]); | ||||
|  | ||||
| INSTANTIATE_FUSED_GATHER_MUL(float); | ||||
| INSTANTIATE_FUSED_GATHER_MUL(half); | ||||
| INSTANTIATE_FUSED_GATHER_MUL(bfloat); | ||||
| INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL); | ||||
|  | ||||
|  | ||||
| #define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE)                         \ | ||||
|   template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void    \ | ||||
|   spmm_bmm_coo_rows_grouped<DTYPE>(                                          \ | ||||
|       device const long*   cols      [[buffer(1)]],                          \ | ||||
|       device const DTYPE*  vals      [[buffer(2)]],                          \ | ||||
|       device const DTYPE*  dense     [[buffer(3)]],                          \ | ||||
|       device DTYPE*        out       [[buffer(4)]],                          \ | ||||
|       device const long*   row_ptr   [[buffer(5)]],                          \ | ||||
|       constant uint4&      dims      [[buffer(6)]],                          \ | ||||
|       uint3                tid       [[thread_position_in_grid]],            \ | ||||
|       uint3                ltid      [[thread_position_in_threadgroup]],     \ | ||||
|       uint3                tptg      [[threads_per_threadgroup]]); | ||||
|  | ||||
| INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED); | ||||
|  | ||||
| #define INSTANTIATE_SPMM_ADDMM_COO(DTYPE) \ | ||||
|   template [[host_name("spmm_addmm_coo_" #DTYPE)]] kernel void  \ | ||||
|   spmm_addmm_coo<DTYPE>(                                        \ | ||||
|     device const long*   indices2d   [[buffer(0)]],             \ | ||||
|     device const DTYPE*  vals        [[buffer(1)]],             \ | ||||
|     device const DTYPE*  dense       [[buffer(2)]],             \ | ||||
|     device const DTYPE*  t_in        [[buffer(3)]],             \ | ||||
|     device DTYPE*        out         [[buffer(4)]],             \ | ||||
|     constant uint3&      dims        [[buffer(5)]],             \ | ||||
|     constant float2&     alpha_beta  [[buffer(6)]],             \ | ||||
|     constant uint&       nnz         [[buffer(7)]],             \ | ||||
|     uint3                tid         [[thread_position_in_grid]]); | ||||
|  | ||||
| INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_ADDMM_COO); | ||||
|  | ||||
| @ -93,3 +93,7 @@ | ||||
|           This operator does not support cudagraphs. The presence of this tag on an operator will cause | ||||
|           Inductor to split the graph around this operator. Note that operators without this tag may still | ||||
|           not support CUDAGraphs. Inductor may have other hardcoded lists around that. | ||||
| - tag: reduction | ||||
|   desc: | | ||||
|           This tag indicates that an operator performs a reduction operation, computing aggregate values | ||||
|           (sum, mean, max, min, etc.) across one or more dimensions of the input tensor(s). | ||||
|  | ||||
| @ -202,7 +202,6 @@ supported: | ||||
|   - select_backward | ||||
|   - _trilinear | ||||
|   - linalg_pinv.atol_rtol_tensor | ||||
|   - svd | ||||
|   - logsumexp.out | ||||
| symint: | ||||
|   - empty.memory_format | ||||
|  | ||||
| @ -1751,8 +1751,8 @@ def maybe_snapshot_memory(should_snapshot_memory, suffix): | ||||
|                         f"{output_filename.rstrip('.csv')}_{suffix}.pickle", | ||||
|                     ) | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 log.error("Failed to save memory snapshot, %s", e) | ||||
|             except Exception: | ||||
|                 log.exception("Failed to save memory snapshot") | ||||
|  | ||||
|             torch.cuda.memory._record_memory_history(enabled=None) | ||||
|  | ||||
|  | ||||
| @ -124,7 +124,7 @@ with open(MODELS_FILENAME) as fh: | ||||
|             continue | ||||
|         batch_size = int(batch_size) | ||||
|         BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size | ||||
| assert len(BATCH_SIZE_KNOWN_MODELS) | ||||
| assert BATCH_SIZE_KNOWN_MODELS | ||||
|  | ||||
|  | ||||
| try: | ||||
|  | ||||
| @ -296,8 +296,8 @@ class OperatorInputsLoader: | ||||
|         for key in self.operator_db.keys(): | ||||
|             try: | ||||
|                 op = eval(key) | ||||
|             except AttributeError as ae: | ||||
|                 log.warning("Evaluating an op name into an OpOverload: %s", ae) | ||||
|             except AttributeError: | ||||
|                 log.warning("Evaluating an op name into an OpOverload", exc_info=True) | ||||
|                 continue | ||||
|             yield op | ||||
|  | ||||
|  | ||||
| @ -3,6 +3,7 @@ import sys | ||||
| from benchmark_base import BenchmarkBase | ||||
|  | ||||
| import torch | ||||
| from torch._dynamo.utils import CompileTimeInstructionCounter | ||||
|  | ||||
|  | ||||
| class Benchmark(BenchmarkBase): | ||||
| @ -32,7 +33,11 @@ class Benchmark(BenchmarkBase): | ||||
|     def _work(self): | ||||
|         # enable_cpp_symbolic_shape_guards has impact on this benchmark | ||||
|         # Keep using False value for consistency. | ||||
|         with torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False): | ||||
|         with ( | ||||
|             torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False), | ||||
|             torch._export.config.patch(use_new_tracer_experimental=True), | ||||
|             CompileTimeInstructionCounter.record(), | ||||
|         ): | ||||
|             torch.export.export(self.m, (self.input,), strict=True) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -1,8 +1,8 @@ | ||||
| add_loop_eager,compile_time_instruction_count,3070000000,0.1 | ||||
| add_loop_eager,compile_time_instruction_count,3184000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1 | ||||
| add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1 | ||||
| basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000, | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1 | ||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000 | ||||
|  | ||||
|  | ||||
|  | ||||
| update_hint_regression,compile_time_instruction_count,1719000000,0.1 | ||||
| update_hint_regression,compile_time_instruction_count,1645000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| sum_floordiv_regression,compile_time_instruction_count,966100000,0.1 | ||||
| sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1 | ||||
| aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1 | ||||
| aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1 | ||||
| aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1 | ||||
| aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1 | ||||
| aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1 | ||||
| aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1 | ||||
| mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1 | ||||
| basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1 | ||||
| basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1 | ||||
|  | ||||
| 
 | 
| @ -43,6 +43,7 @@ tolerance: | ||||
|     - doctr_reco_predictor | ||||
|     - drq | ||||
|     - phlippe_resnet | ||||
|     - pytorch_CycleGAN_and_pix2pix | ||||
|  | ||||
|   higher_bf16: | ||||
|     - doctr_reco_predictor | ||||
|  | ||||
| @ -127,7 +127,7 @@ def trainbench( | ||||
|         bwd_time = bwd_start_event.elapsed_time(bwd_end_event) | ||||
|         return fwd_time, bwd_time | ||||
|  | ||||
|     creator_args = creator_args = { | ||||
|     creator_args = { | ||||
|         "seqLength": seqLength, | ||||
|         "numLayers": numLayers, | ||||
|         "inputSize": inputSize, | ||||
|  | ||||
| @ -12,7 +12,7 @@ def modeldef(request, net_name, executor, fuser): | ||||
|  | ||||
|     # Given a 'net_name' provided by generate_tests, build the thing | ||||
|     name, rnn_creator, context = get_nn_runners(net_name)[0] | ||||
|     creator_args = creator_args = { | ||||
|     creator_args = { | ||||
|         "seqLength": 100, | ||||
|         "numLayers": 1, | ||||
|         "inputSize": 512, | ||||
|  | ||||
| @ -85,7 +85,7 @@ class WeightOnlyInt8QuantHandler: | ||||
|                 cur_state_dict[f"{fqn}.weight"] = int8_weight | ||||
|                 cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) | ||||
|             elif isinstance(mod, ConditionalFeedForward): | ||||
|                 for weight_idx in range(0, 3): | ||||
|                 for weight_idx in range(3): | ||||
|                     weight_name = f"w{weight_idx + 1}" | ||||
|                     scales_name = f"scales{weight_idx + 1}" | ||||
|                     weight = getattr(mod, weight_name) | ||||
|  | ||||
| @ -44,21 +44,101 @@ PyTorch,div_,div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32,sho | ||||
| PyTorch,div_,div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.241161,0.000000 | ||||
| PyTorch,div_,div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.852816,0.000000 | ||||
| PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,57.006677,0.000000 | ||||
| PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,88.167000,0.000000 | ||||
| PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.519000,0.000000 | ||||
| PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,55.606088,0.000000 | ||||
| PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,86.551000,0.000000 | ||||
| PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.864088,0.000000 | ||||
| PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000 | ||||
| PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,71.641000,0.000000 | ||||
| PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,83.073000,0.000000 | ||||
| PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000 | ||||
| PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,67.570000,0.000000 | ||||
| PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.895000,0.000000 | ||||
| PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000 | ||||
| PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000 | ||||
| PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000 | ||||
| PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000 | ||||
| PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000 | ||||
| PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000 | ||||
| PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000 | ||||
| PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000 | ||||
| PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000 | ||||
| PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000 | ||||
| PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000 | ||||
| PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000 | ||||
| PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000 | ||||
| PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000 | ||||
| PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000 | ||||
| PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000 | ||||
| PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000 | ||||
| PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000 | ||||
| PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000 | ||||
| PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000 | ||||
| PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000 | ||||
| PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000 | ||||
| PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000 | ||||
| PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000 | ||||
| PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000 | ||||
| PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000 | ||||
| PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000 | ||||
| PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000 | ||||
| PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000 | ||||
| PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000 | ||||
| PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000 | ||||
| PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000 | ||||
| PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000 | ||||
| PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000 | ||||
| PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000 | ||||
| PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000 | ||||
| PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000 | ||||
| PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000 | ||||
| PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000 | ||||
| PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000 | ||||
| PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000 | ||||
| PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000 | ||||
| PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000 | ||||
| PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000 | ||||
| PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000 | ||||
| PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000 | ||||
| PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000 | ||||
| PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000 | ||||
| PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000 | ||||
| PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000 | ||||
| PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000 | ||||
| PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000 | ||||
| PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000 | ||||
| PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000 | ||||
| PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000 | ||||
| PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000 | ||||
| PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000 | ||||
| PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000 | ||||
| PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000 | ||||
| PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000 | ||||
| PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000 | ||||
| PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000 | ||||
| PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000 | ||||
| PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000 | ||||
| PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000 | ||||
| PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000 | ||||
| PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000 | ||||
| PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000 | ||||
| PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000 | ||||
| PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000 | ||||
| PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000 | ||||
| PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000 | ||||
| PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000 | ||||
| PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000 | ||||
| PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000 | ||||
| PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000 | ||||
| PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000 | ||||
| PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000 | ||||
| PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000 | ||||
| PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000 | ||||
| PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000 | ||||
| PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000 | ||||
| PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000 | ||||
| PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000 | ||||
| PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000 | ||||
| PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000 | ||||
| PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000 | ||||
| @ -71,6 +151,9 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313, | ||||
| PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000 | ||||
| PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000 | ||||
| PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000 | ||||
| PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000 | ||||
| PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000 | ||||
| PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000 | ||||
| PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000 | ||||
| PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000 | ||||
| PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000 | ||||
|  | ||||
| 
 | 
| @ -580,6 +580,9 @@ class BenchmarkRunner: | ||||
|                 else "unknown" | ||||
|             ) | ||||
|  | ||||
|             # Extract operator name from test_name | ||||
|             operator_name = test_name.split("_")[0] | ||||
|  | ||||
|             # Create the record | ||||
|             @dataclass | ||||
|             class BenchmarkInfo: | ||||
| @ -593,6 +596,7 @@ class BenchmarkRunner: | ||||
|                 name: str | ||||
|                 type: str | ||||
|                 origins: list[str] | ||||
|                 extra_info: dict[str, Any] | ||||
|  | ||||
|             @dataclass | ||||
|             class MetricInfo: | ||||
| @ -618,10 +622,14 @@ class BenchmarkRunner: | ||||
|                         "device": device, | ||||
|                         "arch": device_arch, | ||||
|                         "use_compile": use_compile, | ||||
|                         "operator_name": operator_name, | ||||
|                     }, | ||||
|                 ), | ||||
|                 model=ModelInfo( | ||||
|                     name=test_name, type="micro-benchmark", origins=["pytorch"] | ||||
|                     name=test_name, | ||||
|                     type="micro-benchmark", | ||||
|                     origins=["pytorch"], | ||||
|                     extra_info={"operator_name": operator_name}, | ||||
|                 ), | ||||
|                 metric=MetricInfo( | ||||
|                     name="latency", | ||||
|  | ||||
| @ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list( | ||||
|     ], | ||||
|     cross_product_configs={ | ||||
|         "device": ["cpu"], | ||||
|         "dtype": [torch.float], | ||||
|         "dtype": [torch.float, torch.bfloat16, torch.float64], | ||||
|     }, | ||||
|     tags=["short"], | ||||
| ) | ||||
| @ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list( | ||||
|     ], | ||||
|     cross_product_configs={ | ||||
|         "device": ["cpu", "cuda"], | ||||
|         "dtype_one": [torch.int32], | ||||
|         "dtype_two": [torch.int32], | ||||
|         "dtype_one": [torch.int32, torch.uint8], | ||||
|         "dtype_two": [torch.int32, torch.uint8], | ||||
|     }, | ||||
|     tags=["short"], | ||||
| ) | ||||
| @ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs( | ||||
|     N=[32, 64], | ||||
|     K=[256, 512], | ||||
|     device=["cpu", "cuda"], | ||||
|     dtype_one=[torch.int8, torch.int32], | ||||
|     dtype_two=[torch.int8, torch.int32], | ||||
|     dtype_one=[torch.int8, torch.int32, torch.uint8], | ||||
|     dtype_two=[torch.int8, torch.int32, torch.uint8], | ||||
|     tags=["long"], | ||||
| ) | ||||
|  | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -176,8 +176,8 @@ THIRD_PARTY_LIBS = { | ||||
|     "omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"], | ||||
|     "pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"], | ||||
|     "psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"], | ||||
|     "pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"], | ||||
|     "pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"], | ||||
|     "pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"], | ||||
|     "pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"], | ||||
|     "moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"], | ||||
|     "pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"], | ||||
|     "rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"], | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	