mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-26 00:24:53 +08:00 
			
		
		
		
	Compare commits
	
		
			17 Commits
		
	
	
		
			ciflow/ind
			...
			main-enabl
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| e752a29afd | |||
| 36b622bb72 | |||
| 83a04f38a4 | |||
| 6579829bee | |||
| 2b856676f3 | |||
| 5746261c97 | |||
| b3c94fd0fc | |||
| 6fd366b2c7 | |||
| fe25f6ab59 | |||
| ca89e5732f | |||
| f12cb265d4 | |||
| 7dc6bf5377 | |||
| e5ba464808 | |||
| 7d95185044 | |||
| 77fb3c1cac | |||
| 11a3d1d87b | |||
| 8c6d9feb26 | 
| @ -113,7 +113,6 @@ case "$tag" in | ||||
|     UCX_COMMIT=${_UCX_COMMIT} | ||||
|     UCC_COMMIT=${_UCC_COMMIT} | ||||
|     TRITON=yes | ||||
|     INSTALL_MINGW=yes | ||||
|     ;; | ||||
|   pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11) | ||||
|     CUDA_VERSION=13.0.0 | ||||
| @ -362,7 +361,6 @@ docker build \ | ||||
|        --build-arg "OPENBLAS=${OPENBLAS:-}" \ | ||||
|        --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \ | ||||
|        --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \ | ||||
|        --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \ | ||||
|        -f $(dirname ${DOCKERFILE})/Dockerfile \ | ||||
|        -t "$tmp_tag" \ | ||||
|        "$@" \ | ||||
|  | ||||
| @ -1,10 +0,0 @@ | ||||
| #!/bin/bash | ||||
|  | ||||
| set -ex | ||||
|  | ||||
| # Install MinGW-w64 for Windows cross-compilation | ||||
| apt-get update | ||||
| apt-get install -y g++-mingw-w64-x86-64-posix | ||||
|  | ||||
| echo "MinGW-w64 installed successfully" | ||||
| x86_64-w64-mingw32-g++ --version | ||||
| @ -20,7 +20,7 @@ pip_install \ | ||||
|  | ||||
| pip_install coloredlogs packaging | ||||
| pip_install onnxruntime==1.23.0 | ||||
| pip_install onnxscript==0.5.4 | ||||
| pip_install onnxscript==0.5.3 | ||||
|  | ||||
| # Cache the transformers model to be used later by ONNX tests. We need to run the transformers | ||||
| # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ | ||||
|  | ||||
| @ -39,13 +39,9 @@ case ${DOCKER_TAG_PREFIX} in | ||||
|         DOCKER_GPU_BUILD_ARG="" | ||||
|         ;; | ||||
|     rocm*) | ||||
|         # we want the patch version of 7.0 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         # we want the patch version of 6.4 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4" | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         BASE_TARGET=rocm | ||||
|         GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete | ||||
|  | ||||
| @ -75,13 +75,9 @@ case ${image} in | ||||
|         DOCKERFILE_SUFFIX="_cuda_aarch64" | ||||
|         ;; | ||||
|     manylinux2_28-builder:rocm*) | ||||
|         # we want the patch version of 7.0 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         # we want the patch version of 6.4 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4" | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         TARGET=rocm_final | ||||
|         MANY_LINUX_VERSION="2_28" | ||||
|  | ||||
| @ -103,11 +103,6 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt | ||||
| RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi | ||||
| RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt | ||||
|  | ||||
| ARG INSTALL_MINGW | ||||
| COPY ./common/install_mingw.sh install_mingw.sh | ||||
| RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi | ||||
| RUN rm install_mingw.sh | ||||
|  | ||||
| ARG TRITON | ||||
| ARG TRITON_CPU | ||||
|  | ||||
|  | ||||
| @ -485,22 +485,6 @@ test_inductor_aoti() { | ||||
|   /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile | ||||
| } | ||||
|  | ||||
| test_inductor_aoti_cross_compile_for_windows() { | ||||
|  | ||||
|   TEST_REPORTS_DIR=$(pwd)/test/test-reports | ||||
|   mkdir -p "$TEST_REPORTS_DIR" | ||||
|  | ||||
|   # Set WINDOWS_CUDA_HOME environment variable | ||||
|   WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted" | ||||
|   export WINDOWS_CUDA_HOME | ||||
|  | ||||
|   echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME" | ||||
|   echo "Contents:" | ||||
|   ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true | ||||
|  | ||||
|   python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib" | ||||
| } | ||||
|  | ||||
| test_inductor_cpp_wrapper_shard() { | ||||
|   if [[ -z "$NUM_TEST_SHARDS" ]]; then | ||||
|     echo "NUM_TEST_SHARDS must be defined to run a Python test shard" | ||||
| @ -916,7 +900,7 @@ test_inductor_set_cpu_affinity(){ | ||||
|   export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD" | ||||
|   export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1" | ||||
|  | ||||
|   if [[ "$(uname -m)" != "aarch64" ]]; then | ||||
|   if [[ "${TEST_CONFIG}" != *aarch64* ]]; then | ||||
|     # Use Intel OpenMP for x86 | ||||
|     IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so" | ||||
|     export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD" | ||||
| @ -930,7 +914,7 @@ test_inductor_set_cpu_affinity(){ | ||||
|   cores=$((cpus / thread_per_core)) | ||||
|  | ||||
|   # Set number of cores to 16 on aarch64 for performance runs | ||||
|   if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then | ||||
|   if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then | ||||
|     cores=16 | ||||
|   fi | ||||
|   export OMP_NUM_THREADS=$cores | ||||
| @ -1631,7 +1615,6 @@ test_operator_benchmark() { | ||||
|   TEST_REPORTS_DIR=$(pwd)/test/test-reports | ||||
|   mkdir -p "$TEST_REPORTS_DIR" | ||||
|   TEST_DIR=$(pwd) | ||||
|   ARCH=$(uname -m) | ||||
|  | ||||
|   test_inductor_set_cpu_affinity | ||||
|  | ||||
| @ -1646,7 +1629,7 @@ test_operator_benchmark() { | ||||
|   pip_install pandas | ||||
|   python check_perf_csv.py \ | ||||
|       --actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \ | ||||
|       --expected "${ARCH}_expected_ci_operator_benchmark_eager_float32_cpu.csv" | ||||
|       --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv" | ||||
| } | ||||
|  | ||||
| test_operator_microbenchmark() { | ||||
| @ -1683,7 +1666,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then | ||||
|     python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 | ||||
|   fi | ||||
|   python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py | ||||
| elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then | ||||
| elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then | ||||
|   test_linux_aarch64 | ||||
| elif [[ "${TEST_CONFIG}" == *backward* ]]; then | ||||
|   test_forward_backward_compatibility | ||||
| @ -1734,8 +1717,6 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then | ||||
|   test_inductor_triton_cpu | ||||
| elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then | ||||
|   test_inductor_micro_benchmark | ||||
| elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then | ||||
|   test_inductor_aoti_cross_compile_for_windows | ||||
| elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then | ||||
|   install_torchvision | ||||
|   id=$((SHARD_NUMBER-1)) | ||||
|  | ||||
							
								
								
									
										2
									
								
								.flake8
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								.flake8
									
									
									
									
									
								
							| @ -13,6 +13,8 @@ ignore = | ||||
|     EXE001, | ||||
|     # these ignores are from flake8-bugbear; please fix! | ||||
|     B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910 | ||||
|     # these ignores are from flake8-comprehensions; please fix! | ||||
|     C407, | ||||
|     # these ignores are from flake8-logging-format; please fix! | ||||
|     G100,G101,G200 | ||||
|     # these ignores are from flake8-simplify. please fix or ignore with commented reason | ||||
|  | ||||
							
								
								
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							| @ -133,32 +133,3 @@ | ||||
|  | ||||
| "ciflow/vllm": | ||||
| - .github/ci_commit_pins/vllm.txt | ||||
|  | ||||
| "ciflow/b200": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/**/*cublas* | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
| "ciflow/h100": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/**/*cublas* | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
| "ciflow/rocm": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
							
								
								
									
										42
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										42
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							| @ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { | ||||
|         "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'" | ||||
|     ), | ||||
|     "12.9": ( | ||||
|         "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | " | ||||
|         "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | " | ||||
|         "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | " | ||||
|         "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | " | ||||
|         "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | " | ||||
|         "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | " | ||||
|         "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | " | ||||
|         "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | " | ||||
|         "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | " | ||||
|         "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | " | ||||
|         "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | " | ||||
|         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" | ||||
|         "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'" | ||||
|     ), | ||||
|     "13.0": ( | ||||
|         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | " | ||||
| @ -241,11 +241,7 @@ def generate_libtorch_matrix( | ||||
|             arches += CUDA_ARCHES | ||||
|             arches += ROCM_ARCHES | ||||
|         elif os == "windows": | ||||
|             # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up | ||||
|             # in 2.10 | ||||
|             windows_cuda_arches = CUDA_ARCHES.copy() | ||||
|             windows_cuda_arches.remove("12.9") | ||||
|             arches += windows_cuda_arches | ||||
|             arches += CUDA_ARCHES | ||||
|     if libtorch_variants is None: | ||||
|         libtorch_variants = [ | ||||
|             "shared-with-deps", | ||||
| @ -309,11 +305,7 @@ def generate_wheels_matrix( | ||||
|         if os == "linux": | ||||
|             arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES | ||||
|         elif os == "windows": | ||||
|             # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up | ||||
|             # in 2.10 | ||||
|             windows_cuda_arches = CUDA_ARCHES.copy() | ||||
|             windows_cuda_arches.remove("12.9") | ||||
|             arches += windows_cuda_arches + XPU_ARCHES | ||||
|             arches += CUDA_ARCHES + XPU_ARCHES | ||||
|         elif os == "linux-aarch64": | ||||
|             # Separate new if as the CPU type is different and | ||||
|             # uses different build/test scripts | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/scripts/trymerge.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/scripts/trymerge.py
									
									
									
									
										vendored
									
									
								
							| @ -1092,7 +1092,7 @@ class GitHubPR: | ||||
|         editor = node["editor"] | ||||
|         return GitHubComment( | ||||
|             body_text=node["bodyText"], | ||||
|             created_at=node.get("createdAt", ""), | ||||
|             created_at=node["createdAt"] if "createdAt" in node else "", | ||||
|             author_login=node["author"]["login"], | ||||
|             author_url=node["author"].get("url", None), | ||||
|             author_association=node["authorAssociation"], | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/workflows/_linux-build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/_linux-build.yml
									
									
									
									
										vendored
									
									
								
							| @ -37,7 +37,7 @@ on: | ||||
|       runner: | ||||
|         required: false | ||||
|         type: string | ||||
|         default: "linux.c7i.2xlarge" | ||||
|         default: "linux.2xlarge" | ||||
|         description: | | ||||
|           Label of the runner this job should run on. | ||||
|       test-matrix: | ||||
|  | ||||
							
								
								
									
										40
									
								
								.github/workflows/_linux-test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										40
									
								
								.github/workflows/_linux-test.yml
									
									
									
									
										vendored
									
									
								
							| @ -224,46 +224,6 @@ jobs: | ||||
|         continue-on-error: true | ||||
|         uses: ./.github/actions/download-td-artifacts | ||||
|  | ||||
|       - name: Download Windows torch wheel for cross-compilation | ||||
|         if: matrix.win_torch_wheel_artifact != '' | ||||
|         uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0 | ||||
|         with: | ||||
|           name: ${{ matrix.win_torch_wheel_artifact }} | ||||
|           path: win-torch-wheel | ||||
|  | ||||
|       - name: Extract Windows wheel and setup CUDA libraries | ||||
|         if: matrix.win_torch_wheel_artifact != '' | ||||
|         shell: bash | ||||
|         run: | | ||||
|           set -x | ||||
|  | ||||
|           # Find the wheel file | ||||
|           WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1) | ||||
|           if [ -z "$WHEEL_FILE" ]; then | ||||
|             echo "Error: No wheel file found in win-torch-wheel directory" | ||||
|             exit 1 | ||||
|           fi | ||||
|           echo "Found wheel file: $WHEEL_FILE" | ||||
|  | ||||
|           # Unzip the wheel file | ||||
|           unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted | ||||
|           echo "Extracted wheel contents" | ||||
|  | ||||
|           # Setup CUDA libraries (cuda.lib and cudart.lib) directory | ||||
|           mkdir -p win-torch-wheel-extracted/lib/x64 | ||||
|           if [ -f "win-torch-wheel/cuda.lib" ]; then | ||||
|             mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/ | ||||
|             echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/" | ||||
|           fi | ||||
|           if [ -f "win-torch-wheel/cudart.lib" ]; then | ||||
|             mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/ | ||||
|             echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/" | ||||
|           fi | ||||
|  | ||||
|           # Verify CUDA libraries are present | ||||
|           echo "CUDA libraries:" | ||||
|           ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found" | ||||
|  | ||||
|       - name: Parse ref | ||||
|         id: parse-ref | ||||
|         run: .github/scripts/parse_ref.py | ||||
|  | ||||
							
								
								
									
										25
									
								
								.github/workflows/_win-build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										25
									
								
								.github/workflows/_win-build.yml
									
									
									
									
										vendored
									
									
								
							| @ -168,31 +168,6 @@ jobs: | ||||
|         run: | | ||||
|           .ci/pytorch/win-build.sh | ||||
|  | ||||
|       # Collect Windows torch libs and CUDA libs for cross-compilation | ||||
|       - name: Collect Windows CUDA libs for cross-compilation | ||||
|         if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu' | ||||
|         shell: bash | ||||
|         run: | | ||||
|           set -ex | ||||
|  | ||||
|           # Create directory structure if does not exist | ||||
|           mkdir -p /c/${{ github.run_id }}/build-results | ||||
|  | ||||
|           # Copy CUDA libs | ||||
|           CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}" | ||||
|  | ||||
|           if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then | ||||
|             cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/ | ||||
|           fi | ||||
|  | ||||
|           if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then | ||||
|             cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/ | ||||
|           fi | ||||
|  | ||||
|           # List collected files | ||||
|           echo "Collected CUDA libs:" | ||||
|           ls -lah /c/${{ github.run_id }}/build-results/*.lib | ||||
|  | ||||
|       # Upload to github so that people can click and download artifacts | ||||
|       - name: Upload artifacts to s3 | ||||
|         if: steps.build.outcome != 'skipped' | ||||
|  | ||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -224,7 +224,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_10-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -473,7 +473,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_11-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -722,7 +722,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_12-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -971,7 +971,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_13-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1220,7 +1220,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_13t-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1469,7 +1469,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_14-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1718,7 +1718,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_14t-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|  | ||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -259,7 +259,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_10-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_10-cuda12_9-test:  # Testing | ||||
| @ -925,7 +925,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_11-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_11-cuda12_9-test:  # Testing | ||||
| @ -1591,7 +1591,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_12-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_12-cuda12_9-test:  # Testing | ||||
| @ -2257,7 +2257,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_13-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_13-cuda12_9-test:  # Testing | ||||
| @ -2923,7 +2923,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_13t-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_13t-cuda12_9-test:  # Testing | ||||
| @ -3589,7 +3589,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_14-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_14-cuda12_9-test:  # Testing | ||||
| @ -4255,7 +4255,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_14t-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_14t-cuda12_9-test:  # Testing | ||||
|  | ||||
							
								
								
									
										250
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										250
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -788,6 +788,256 @@ jobs: | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|     uses: ./.github/workflows/_binary-upload.yml | ||||
|   libtorch-cuda12_9-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|       PACKAGE_TYPE: libtorch | ||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in | ||||
|       #       favor of GPU_ARCH_VERSION | ||||
|       DESIRED_CUDA: cu129 | ||||
|       GPU_ARCH_VERSION: "12.9" | ||||
|       GPU_ARCH_TYPE: cuda | ||||
|       SKIP_ALL_TESTS: 1 | ||||
|       LIBTORCH_CONFIG: debug | ||||
|       LIBTORCH_VARIANT: shared-with-deps | ||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts | ||||
|       # without this value pip does not get installed for some reason | ||||
|       DESIRED_PYTHON: "3.10" | ||||
|     steps: | ||||
|       # NOTE: These environment variables are put here so that they can be applied on every job equally | ||||
|       #       They are also here because setting them at a workflow level doesn't give us access to the | ||||
|       #       runner.temp variable, which we need. | ||||
|       - name: Populate binary env | ||||
|         shell: bash | ||||
|         run: | | ||||
|           echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" | ||||
|           echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" | ||||
|           echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" | ||||
|       - name: Display EC2 information | ||||
|         shell: bash | ||||
|         run: | | ||||
|           set -euo pipefail | ||||
|           function get_ec2_metadata() { | ||||
|             # Pulled from instance metadata endpoint for EC2 | ||||
|             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html | ||||
|             category=$1 | ||||
|             curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" | ||||
|           } | ||||
|           echo "ami-id: $(get_ec2_metadata ami-id)" | ||||
|           echo "instance-id: $(get_ec2_metadata instance-id)" | ||||
|           echo "instance-type: $(get_ec2_metadata instance-type)" | ||||
|           echo "system info $(uname -a)" | ||||
|       - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" | ||||
|         uses: pytorch/test-infra/.github/actions/setup-ssh@main | ||||
|         continue-on-error: true | ||||
|         with: | ||||
|           github-secret: ${{ secrets.GITHUB_TOKEN }} | ||||
|       - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon | ||||
|         shell: bash | ||||
|         run: | | ||||
|           git config --global core.longpaths true | ||||
|           git config --global core.symlinks true | ||||
|  | ||||
|           # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock | ||||
|           # the directory on Windows and prevent GHA from checking out as reported | ||||
|           # in https://github.com/actions/checkout/issues/1018 | ||||
|           git config --global core.fsmonitor false | ||||
|       # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 | ||||
|       - name: Enable long paths on Windows | ||||
|         shell: powershell | ||||
|         run: | | ||||
|           Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 | ||||
|       # Since it's just a defensive command, the workflow should continue even the command fails. This step can be | ||||
|       # removed once Windows Defender is removed from the AMI | ||||
|       - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch | ||||
|         continue-on-error: true | ||||
|         shell: powershell | ||||
|         run: | | ||||
|           Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore | ||||
|           # Let's both exclude the path and disable Windows Defender completely just to be sure | ||||
|           # that it doesn't interfere | ||||
|           Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@v4 | ||||
|         with: | ||||
|           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} | ||||
|           submodules: recursive | ||||
|           path: pytorch | ||||
|           show-progress: false | ||||
|       - name: Clean PyTorch checkout | ||||
|         run: | | ||||
|           # Remove any artifacts from the previous checkouts | ||||
|           git clean -fxd | ||||
|         working-directory: pytorch | ||||
|       - name: Populate binary env | ||||
|         shell: bash | ||||
|         run: | | ||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" | ||||
|       - name: Build PyTorch binary | ||||
|         shell: bash | ||||
|         run: | | ||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" | ||||
|       - uses: actions/upload-artifact@v4.4.0 | ||||
|         if: always() | ||||
|         with: | ||||
|           name: libtorch-cuda12_9-shared-with-deps-debug | ||||
|           retention-days: 14 | ||||
|           if-no-files-found: error | ||||
|           path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" | ||||
|       - name: Wait until all sessions have drained | ||||
|         shell: powershell | ||||
|         working-directory: pytorch | ||||
|         if: always() | ||||
|         timeout-minutes: 120 | ||||
|         run: | | ||||
|           .github\scripts\wait_for_ssh_to_drain.ps1 | ||||
|       - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) | ||||
|         shell: powershell | ||||
|         working-directory: pytorch | ||||
|         if: always() | ||||
|         run: | | ||||
|           .github\scripts\kill_active_ssh_sessions.ps1 | ||||
|  | ||||
|   libtorch-cuda12_9-shared-with-deps-debug-test:  # Testing | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: | ||||
|       - libtorch-cuda12_9-shared-with-deps-debug-build | ||||
|       - get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|       PACKAGE_TYPE: libtorch | ||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in | ||||
|       #       favor of GPU_ARCH_VERSION | ||||
|       DESIRED_CUDA: cu129 | ||||
|       GPU_ARCH_VERSION: "12.9" | ||||
|       GPU_ARCH_TYPE: cuda | ||||
|       SKIP_ALL_TESTS: 1 | ||||
|       LIBTORCH_CONFIG: debug | ||||
|       LIBTORCH_VARIANT: shared-with-deps | ||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts | ||||
|       # without this value pip does not get installed for some reason | ||||
|       DESIRED_PYTHON: "3.10" | ||||
|     steps: | ||||
|       - name: Display EC2 information | ||||
|         shell: bash | ||||
|         run: | | ||||
|           set -euo pipefail | ||||
|           function get_ec2_metadata() { | ||||
|             # Pulled from instance metadata endpoint for EC2 | ||||
|             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html | ||||
|             category=$1 | ||||
|             curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" | ||||
|           } | ||||
|           echo "ami-id: $(get_ec2_metadata ami-id)" | ||||
|           echo "instance-id: $(get_ec2_metadata instance-id)" | ||||
|           echo "instance-type: $(get_ec2_metadata instance-type)" | ||||
|           echo "system info $(uname -a)" | ||||
|       - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" | ||||
|         uses: pytorch/test-infra/.github/actions/setup-ssh@main | ||||
|         continue-on-error: true | ||||
|         with: | ||||
|           github-secret: ${{ secrets.GITHUB_TOKEN }} | ||||
|       - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon | ||||
|         shell: bash | ||||
|         run: | | ||||
|           git config --global core.longpaths true | ||||
|           git config --global core.symlinks true | ||||
|  | ||||
|           # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock | ||||
|           # the directory on Windows and prevent GHA from checking out as reported | ||||
|           # in https://github.com/actions/checkout/issues/1018 | ||||
|           git config --global core.fsmonitor false | ||||
|       # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 | ||||
|       - name: Enable long paths on Windows | ||||
|         shell: powershell | ||||
|         run: | | ||||
|           Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 | ||||
|       # Since it's just a defensive command, the workflow should continue even the command fails. This step can be | ||||
|       # removed once Windows Defender is removed from the AMI | ||||
|       - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch | ||||
|         continue-on-error: true | ||||
|         shell: powershell | ||||
|         run: | | ||||
|           Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore | ||||
|           # Let's both exclude the path and disable Windows Defender completely just to be sure | ||||
|           # that it doesn't interfere | ||||
|           Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@v4 | ||||
|         with: | ||||
|           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} | ||||
|           submodules: recursive | ||||
|           path: pytorch | ||||
|           show-progress: false | ||||
|       - name: Clean PyTorch checkout | ||||
|         run: | | ||||
|           # Remove any artifacts from the previous checkouts | ||||
|           git clean -fxd | ||||
|         working-directory: pytorch | ||||
|       # NOTE: These environment variables are put here so that they can be applied on every job equally | ||||
|       #       They are also here because setting them at a workflow level doesn't give us access to the | ||||
|       #       runner.temp variable, which we need. | ||||
|       - name: Populate binary env | ||||
|         shell: bash | ||||
|         run: | | ||||
|           echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" | ||||
|           echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" | ||||
|           echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" | ||||
|       - uses: actions/download-artifact@v4.1.7 | ||||
|         name: Download Build Artifacts | ||||
|         with: | ||||
|           name: libtorch-cuda12_9-shared-with-deps-debug | ||||
|           path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" | ||||
|       - name: Populate binary env | ||||
|         shell: bash | ||||
|         run: | | ||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" | ||||
|       - name: Test PyTorch binary | ||||
|         shell: bash | ||||
|         run: | | ||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" | ||||
|       - name: Wait until all sessions have drained | ||||
|         shell: powershell | ||||
|         working-directory: pytorch | ||||
|         if: always() | ||||
|         timeout-minutes: 120 | ||||
|         run: | | ||||
|           .github\scripts\wait_for_ssh_to_drain.ps1 | ||||
|       - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) | ||||
|         shell: powershell | ||||
|         working-directory: pytorch | ||||
|         if: always() | ||||
|         run: | | ||||
|           .github\scripts\kill_active_ssh_sessions.ps1 | ||||
|   libtorch-cuda12_9-shared-with-deps-debug-upload:  # Uploading | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|     needs: libtorch-cuda12_9-shared-with-deps-debug-test | ||||
|     with: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|       PACKAGE_TYPE: libtorch | ||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in | ||||
|       #       favor of GPU_ARCH_VERSION | ||||
|       DESIRED_CUDA: cu129 | ||||
|       GPU_ARCH_VERSION: "12.9" | ||||
|       GPU_ARCH_TYPE: cuda | ||||
|       LIBTORCH_CONFIG: debug | ||||
|       LIBTORCH_VARIANT: shared-with-deps | ||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts | ||||
|       # without this value pip does not get installed for some reason | ||||
|       DESIRED_PYTHON: "3.10" | ||||
|       build_name: libtorch-cuda12_9-shared-with-deps-debug | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|     uses: ./.github/workflows/_binary-upload.yml | ||||
|   libtorch-cuda13_0-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|  | ||||
							
								
								
									
										250
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										250
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -788,6 +788,256 @@ jobs: | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|     uses: ./.github/workflows/_binary-upload.yml | ||||
|   libtorch-cuda12_9-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|       PACKAGE_TYPE: libtorch | ||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in | ||||
|       #       favor of GPU_ARCH_VERSION | ||||
|       DESIRED_CUDA: cu129 | ||||
|       GPU_ARCH_VERSION: "12.9" | ||||
|       GPU_ARCH_TYPE: cuda | ||||
|       SKIP_ALL_TESTS: 1 | ||||
|       LIBTORCH_CONFIG: release | ||||
|       LIBTORCH_VARIANT: shared-with-deps | ||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts | ||||
|       # without this value pip does not get installed for some reason | ||||
|       DESIRED_PYTHON: "3.10" | ||||
|     steps: | ||||
|       # NOTE: These environment variables are put here so that they can be applied on every job equally | ||||
|       #       They are also here because setting them at a workflow level doesn't give us access to the | ||||
|       #       runner.temp variable, which we need. | ||||
|       - name: Populate binary env | ||||
|         shell: bash | ||||
|         run: | | ||||
|           echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" | ||||
|           echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" | ||||
|           echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" | ||||
|       - name: Display EC2 information | ||||
|         shell: bash | ||||
|         run: | | ||||
|           set -euo pipefail | ||||
|           function get_ec2_metadata() { | ||||
|             # Pulled from instance metadata endpoint for EC2 | ||||
|             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html | ||||
|             category=$1 | ||||
|             curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" | ||||
|           } | ||||
|           echo "ami-id: $(get_ec2_metadata ami-id)" | ||||
|           echo "instance-id: $(get_ec2_metadata instance-id)" | ||||
|           echo "instance-type: $(get_ec2_metadata instance-type)" | ||||
|           echo "system info $(uname -a)" | ||||
|       - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" | ||||
|         uses: pytorch/test-infra/.github/actions/setup-ssh@main | ||||
|         continue-on-error: true | ||||
|         with: | ||||
|           github-secret: ${{ secrets.GITHUB_TOKEN }} | ||||
|       - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon | ||||
|         shell: bash | ||||
|         run: | | ||||
|           git config --global core.longpaths true | ||||
|           git config --global core.symlinks true | ||||
|  | ||||
|           # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock | ||||
|           # the directory on Windows and prevent GHA from checking out as reported | ||||
|           # in https://github.com/actions/checkout/issues/1018 | ||||
|           git config --global core.fsmonitor false | ||||
|       # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 | ||||
|       - name: Enable long paths on Windows | ||||
|         shell: powershell | ||||
|         run: | | ||||
|           Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 | ||||
|       # Since it's just a defensive command, the workflow should continue even the command fails. This step can be | ||||
|       # removed once Windows Defender is removed from the AMI | ||||
|       - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch | ||||
|         continue-on-error: true | ||||
|         shell: powershell | ||||
|         run: | | ||||
|           Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore | ||||
|           # Let's both exclude the path and disable Windows Defender completely just to be sure | ||||
|           # that it doesn't interfere | ||||
|           Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@v4 | ||||
|         with: | ||||
|           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} | ||||
|           submodules: recursive | ||||
|           path: pytorch | ||||
|           show-progress: false | ||||
|       - name: Clean PyTorch checkout | ||||
|         run: | | ||||
|           # Remove any artifacts from the previous checkouts | ||||
|           git clean -fxd | ||||
|         working-directory: pytorch | ||||
|       - name: Populate binary env | ||||
|         shell: bash | ||||
|         run: | | ||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" | ||||
|       - name: Build PyTorch binary | ||||
|         shell: bash | ||||
|         run: | | ||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" | ||||
|       - uses: actions/upload-artifact@v4.4.0 | ||||
|         if: always() | ||||
|         with: | ||||
|           name: libtorch-cuda12_9-shared-with-deps-release | ||||
|           retention-days: 14 | ||||
|           if-no-files-found: error | ||||
|           path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" | ||||
|       - name: Wait until all sessions have drained | ||||
|         shell: powershell | ||||
|         working-directory: pytorch | ||||
|         if: always() | ||||
|         timeout-minutes: 120 | ||||
|         run: | | ||||
|           .github\scripts\wait_for_ssh_to_drain.ps1 | ||||
|       - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) | ||||
|         shell: powershell | ||||
|         working-directory: pytorch | ||||
|         if: always() | ||||
|         run: | | ||||
|           .github\scripts\kill_active_ssh_sessions.ps1 | ||||
|  | ||||
|   libtorch-cuda12_9-shared-with-deps-release-test:  # Testing | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: | ||||
|       - libtorch-cuda12_9-shared-with-deps-release-build | ||||
|       - get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|       PACKAGE_TYPE: libtorch | ||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in | ||||
|       #       favor of GPU_ARCH_VERSION | ||||
|       DESIRED_CUDA: cu129 | ||||
|       GPU_ARCH_VERSION: "12.9" | ||||
|       GPU_ARCH_TYPE: cuda | ||||
|       SKIP_ALL_TESTS: 1 | ||||
|       LIBTORCH_CONFIG: release | ||||
|       LIBTORCH_VARIANT: shared-with-deps | ||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts | ||||
|       # without this value pip does not get installed for some reason | ||||
|       DESIRED_PYTHON: "3.10" | ||||
|     steps: | ||||
|       - name: Display EC2 information | ||||
|         shell: bash | ||||
|         run: | | ||||
|           set -euo pipefail | ||||
|           function get_ec2_metadata() { | ||||
|             # Pulled from instance metadata endpoint for EC2 | ||||
|             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html | ||||
|             category=$1 | ||||
|             curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" | ||||
|           } | ||||
|           echo "ami-id: $(get_ec2_metadata ami-id)" | ||||
|           echo "instance-id: $(get_ec2_metadata instance-id)" | ||||
|           echo "instance-type: $(get_ec2_metadata instance-type)" | ||||
|           echo "system info $(uname -a)" | ||||
|       - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" | ||||
|         uses: pytorch/test-infra/.github/actions/setup-ssh@main | ||||
|         continue-on-error: true | ||||
|         with: | ||||
|           github-secret: ${{ secrets.GITHUB_TOKEN }} | ||||
|       - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon | ||||
|         shell: bash | ||||
|         run: | | ||||
|           git config --global core.longpaths true | ||||
|           git config --global core.symlinks true | ||||
|  | ||||
|           # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock | ||||
|           # the directory on Windows and prevent GHA from checking out as reported | ||||
|           # in https://github.com/actions/checkout/issues/1018 | ||||
|           git config --global core.fsmonitor false | ||||
|       # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 | ||||
|       - name: Enable long paths on Windows | ||||
|         shell: powershell | ||||
|         run: | | ||||
|           Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 | ||||
|       # Since it's just a defensive command, the workflow should continue even the command fails. This step can be | ||||
|       # removed once Windows Defender is removed from the AMI | ||||
|       - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch | ||||
|         continue-on-error: true | ||||
|         shell: powershell | ||||
|         run: | | ||||
|           Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore | ||||
|           # Let's both exclude the path and disable Windows Defender completely just to be sure | ||||
|           # that it doesn't interfere | ||||
|           Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@v4 | ||||
|         with: | ||||
|           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} | ||||
|           submodules: recursive | ||||
|           path: pytorch | ||||
|           show-progress: false | ||||
|       - name: Clean PyTorch checkout | ||||
|         run: | | ||||
|           # Remove any artifacts from the previous checkouts | ||||
|           git clean -fxd | ||||
|         working-directory: pytorch | ||||
|       # NOTE: These environment variables are put here so that they can be applied on every job equally | ||||
|       #       They are also here because setting them at a workflow level doesn't give us access to the | ||||
|       #       runner.temp variable, which we need. | ||||
|       - name: Populate binary env | ||||
|         shell: bash | ||||
|         run: | | ||||
|           echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" | ||||
|           echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" | ||||
|           echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" | ||||
|       - uses: actions/download-artifact@v4.1.7 | ||||
|         name: Download Build Artifacts | ||||
|         with: | ||||
|           name: libtorch-cuda12_9-shared-with-deps-release | ||||
|           path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" | ||||
|       - name: Populate binary env | ||||
|         shell: bash | ||||
|         run: | | ||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" | ||||
|       - name: Test PyTorch binary | ||||
|         shell: bash | ||||
|         run: | | ||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" | ||||
|       - name: Wait until all sessions have drained | ||||
|         shell: powershell | ||||
|         working-directory: pytorch | ||||
|         if: always() | ||||
|         timeout-minutes: 120 | ||||
|         run: | | ||||
|           .github\scripts\wait_for_ssh_to_drain.ps1 | ||||
|       - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) | ||||
|         shell: powershell | ||||
|         working-directory: pytorch | ||||
|         if: always() | ||||
|         run: | | ||||
|           .github\scripts\kill_active_ssh_sessions.ps1 | ||||
|   libtorch-cuda12_9-shared-with-deps-release-upload:  # Uploading | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|     needs: libtorch-cuda12_9-shared-with-deps-release-test | ||||
|     with: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|       PACKAGE_TYPE: libtorch | ||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in | ||||
|       #       favor of GPU_ARCH_VERSION | ||||
|       DESIRED_CUDA: cu129 | ||||
|       GPU_ARCH_VERSION: "12.9" | ||||
|       GPU_ARCH_TYPE: cuda | ||||
|       LIBTORCH_CONFIG: release | ||||
|       LIBTORCH_VARIANT: shared-with-deps | ||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts | ||||
|       # without this value pip does not get installed for some reason | ||||
|       DESIRED_PYTHON: "3.10" | ||||
|       build_name: libtorch-cuda12_9-shared-with-deps-release | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|     uses: ./.github/workflows/_binary-upload.yml | ||||
|   libtorch-cuda13_0-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|  | ||||
							
								
								
									
										1666
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1666
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -88,27 +88,27 @@ jobs: | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|  | ||||
							
								
								
									
										4
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							| @ -118,9 +118,9 @@ jobs: | ||||
|         CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" | ||||
|         echo "Running all other linters" | ||||
|         if [ "$CHANGED_FILES" = '*' ]; then | ||||
|           ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh | ||||
|           ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh | ||||
|         else | ||||
|           ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh | ||||
|           ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh | ||||
|         fi | ||||
|  | ||||
|   quick-checks: | ||||
|  | ||||
							
								
								
									
										38
									
								
								.github/workflows/operator_benchmark.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										38
									
								
								.github/workflows/operator_benchmark.yml
									
									
									
									
										vendored
									
									
								
							| @ -30,9 +30,9 @@ permissions: | ||||
|   contents: read | ||||
|  | ||||
| jobs: | ||||
|   x86-opbenchmark-build: | ||||
|   opbenchmark-build: | ||||
|     if: github.repository_owner == 'pytorch' | ||||
|     name: x86-opbenchmark-build | ||||
|     name: opbenchmark-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     with: | ||||
|       build-environment: linux-jammy-py3.10-gcc11-build | ||||
| @ -43,36 +43,12 @@ jobs: | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   x86-opbenchmark-test: | ||||
|     name: x86-opbenchmark-test | ||||
|   opbenchmark-test: | ||||
|     name: opbenchmark-test | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: x86-opbenchmark-build | ||||
|     needs: opbenchmark-build | ||||
|     with: | ||||
|       build-environment: linux-jammy-py3.10-gcc11-build | ||||
|       docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
|   aarch64-opbenchmark-build: | ||||
|     if: github.repository_owner == 'pytorch' | ||||
|     name: aarch64-opbenchmark-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     with: | ||||
|       build-environment: linux-jammy-aarch64-py3.10 | ||||
|       runner: linux.arm64.m7g.4xlarge | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   aarch64-opbenchmark-test: | ||||
|     name: aarch64-opbenchmark-test | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: aarch64-opbenchmark-build | ||||
|     with: | ||||
|       build-environment: linux-jammy-aarch64-py3.10 | ||||
|       docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }} | ||||
|       docker-image: ${{ needs.opbenchmark-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.opbenchmark-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
							
								
								
									
										12
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							| @ -45,12 +45,12 @@ jobs: | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|  | ||||
							
								
								
									
										17
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										17
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							| @ -200,23 +200,6 @@ jobs: | ||||
|       cuda-arch-list: '8.0' | ||||
|     secrets: inherit | ||||
|  | ||||
|   # Test cross-compiled models with Windows libs extracted from wheel | ||||
|   cross-compile-linux-test: | ||||
|     name: cross-compile-linux-test | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-cuda12_8-py3_10-gcc11-build | ||||
|       - get-label-type | ||||
|       - win-vs2022-cuda12_8-py3-build | ||||
|     with: | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc11 | ||||
|       docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   verify-cachebench-cpu-build: | ||||
|     name: verify-cachebench-cpu-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|  | ||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -374,7 +374,6 @@ third_party/ruy/ | ||||
| third_party/glog/ | ||||
|  | ||||
| # Virtualenv | ||||
| .venv/ | ||||
| venv/ | ||||
|  | ||||
| # Log files | ||||
|  | ||||
| @ -209,46 +209,6 @@ command = [ | ||||
|     '@{{PATHSFILE}}' | ||||
| ] | ||||
|  | ||||
|  | ||||
| [[linter]] | ||||
| code = 'PYREFLY' | ||||
| include_patterns = [ | ||||
|     'torch/**/*.py', | ||||
|     'torch/**/*.pyi', | ||||
|     'torchgen/**/*.py', | ||||
|     'torchgen/**/*.pyi', | ||||
|     'functorch/**/*.py', | ||||
|     'functorch/**/*.pyi', | ||||
| ] | ||||
| exclude_patterns = [] | ||||
| command = [ | ||||
|     'python3', | ||||
|     'tools/linter/adapters/pyrefly_linter.py', | ||||
|     '--config=pyrefly.toml', | ||||
| ] | ||||
| init_command = [ | ||||
|     'python3', | ||||
|     'tools/linter/adapters/pip_init.py', | ||||
|     '--dry-run={{DRYRUN}}', | ||||
|     'numpy==2.1.0 ; python_version >= "3.12"', | ||||
|     'expecttest==0.3.0', | ||||
|     'pyrefly==0.36.2', | ||||
|     'sympy==1.13.3', | ||||
|     'types-requests==2.27.25', | ||||
|     'types-pyyaml==6.0.2', | ||||
|     'types-tabulate==0.8.8', | ||||
|     'types-protobuf==5.29.1.20250403', | ||||
|     'types-setuptools==79.0.0.20250422', | ||||
|     'types-jinja2==2.11.9', | ||||
|     'types-colorama==0.4.6', | ||||
|     'filelock==3.18.0', | ||||
|     'junitparser==2.1.1', | ||||
|     'rich==14.1.0', | ||||
|     'optree==0.17.0', | ||||
|     'types-openpyxl==3.1.5.20250919', | ||||
|     'types-python-dateutil==2.9.0.20251008' | ||||
| ] | ||||
|  | ||||
| [[linter]] | ||||
| code = 'CLANGTIDY' | ||||
| include_patterns = [ | ||||
|  | ||||
							
								
								
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							| @ -201,17 +201,3 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A | ||||
| /torch/csrc/stable/ @janeyx99 @mikaylagawarecki | ||||
| /torch/headeronly/ @janeyx99 | ||||
| /torch/header_only_apis.txt @janeyx99 | ||||
|  | ||||
| # FlexAttention | ||||
| /torch/nn/attention/flex_attention.py @drisspg | ||||
| /torch/_higher_order_ops/flex_attention.py @drisspg | ||||
| /torch/_inductor/kernel/flex/ @drisspg | ||||
| /torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg | ||||
| /test/inductor/test_flex_attention.py @drisspg | ||||
| /test/inductor/test_flex_decoding.py @drisspg | ||||
|  | ||||
| # Low Precision GEMMs | ||||
| /aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58 | ||||
| /aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58 | ||||
| /aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58 | ||||
| /test/test_scaled_matmul_cuda.py @drisspg @slayton58 | ||||
|  | ||||
| @ -289,15 +289,14 @@ IF(USE_FBGEMM_GENAI) | ||||
|  | ||||
|     set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON) | ||||
|  | ||||
|     set(fbgemm_genai_cuh | ||||
|     set(fbgemm_genai_mx8mx8bf16_grouped | ||||
|       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/" | ||||
|       "${FBGEMM_GENAI_SRCS}/" | ||||
|     ) | ||||
|  | ||||
|     target_include_directories(fbgemm_genai PRIVATE | ||||
|       ${FBGEMM_THIRD_PARTY}/cutlass/include | ||||
|       ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include | ||||
|       ${fbgemm_genai_cuh} | ||||
|       ${fbgemm_genai_mx8mx8bf16_grouped} | ||||
|       ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp | ||||
|       ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h | ||||
|     ) | ||||
|  | ||||
| @ -229,10 +229,10 @@ private: | ||||
|   } | ||||
|  | ||||
|  | ||||
|   static constexpr uint32_t kPhilox10A = 0x9E3779B9; | ||||
|   static constexpr uint32_t kPhilox10B = 0xBB67AE85; | ||||
|   static constexpr uint32_t kPhiloxSA = 0xD2511F53; | ||||
|   static constexpr uint32_t kPhiloxSB = 0xCD9E8D57; | ||||
|   static const uint32_t kPhilox10A = 0x9E3779B9; | ||||
|   static const uint32_t kPhilox10B = 0xBB67AE85; | ||||
|   static const uint32_t kPhiloxSA = 0xD2511F53; | ||||
|   static const uint32_t kPhiloxSB = 0xCD9E8D57; | ||||
| }; | ||||
|  | ||||
| typedef philox_engine Philox4_32; | ||||
|  | ||||
| @ -8,7 +8,6 @@ | ||||
| #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_float_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_half_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_int_aarch64.h> | ||||
| #endif | ||||
|  | ||||
| #include <ATen/cpu/vec/vec128/vec128_convert.h> | ||||
|  | ||||
| @ -1,794 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/cpu/vec/intrinsics.h> | ||||
| #include <ATen/cpu/vec/vec_base.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/irange.h> | ||||
|  | ||||
| namespace at::vec { | ||||
| // Note [CPU_CAPABILITY namespace] | ||||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| // This header, and all of its subheaders, will be compiled with | ||||
| // different architecture flags for each supported set of vector | ||||
| // intrinsics. So we need to make sure they aren't inadvertently | ||||
| // linked together. We do this by declaring objects in an `inline | ||||
| // namespace` which changes the name mangling, but can still be | ||||
| // accessed as `at::vec`. | ||||
| inline namespace CPU_CAPABILITY { | ||||
|  | ||||
| #define VEC_INT_NEON_TEMPLATE(vl, bit)                                        \ | ||||
|   template <>                                                                 \ | ||||
|   struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {};  \ | ||||
|                                                                               \ | ||||
|   template <>                                                                 \ | ||||
|   class Vectorized<int##bit##_t> {                                            \ | ||||
|     using neon_type = int##bit##x##vl##_t;                                    \ | ||||
|                                                                               \ | ||||
|    private:                                                                   \ | ||||
|     neon_type values;                                                         \ | ||||
|                                                                               \ | ||||
|    public:                                                                    \ | ||||
|     using value_type = int##bit##_t;                                          \ | ||||
|     using size_type = int;                                                    \ | ||||
|     static constexpr size_type size() {                                       \ | ||||
|       return vl;                                                              \ | ||||
|     }                                                                         \ | ||||
|     Vectorized() {                                                            \ | ||||
|       values = vdupq_n_s##bit(0);                                             \ | ||||
|     }                                                                         \ | ||||
|     Vectorized(neon_type v) : values(v) {}                                    \ | ||||
|     Vectorized(int##bit##_t val);                                             \ | ||||
|     template <                                                                \ | ||||
|         typename... Args,                                                     \ | ||||
|         typename = std::enable_if_t<(sizeof...(Args) == size())>>             \ | ||||
|     Vectorized(Args... vals) {                                                \ | ||||
|       __at_align__ int##bit##_t buffer[size()] = {vals...};                   \ | ||||
|       values = vld1q_s##bit(buffer);                                          \ | ||||
|     }                                                                         \ | ||||
|     operator neon_type() const {                                              \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     static Vectorized<int##bit##_t> loadu(                                    \ | ||||
|         const void* ptr,                                                      \ | ||||
|         int64_t count = size());                                              \ | ||||
|     void store(void* ptr, int64_t count = size()) const;                      \ | ||||
|     template <int64_t mask>                                                   \ | ||||
|     static Vectorized<int##bit##_t> blend(                                    \ | ||||
|         const Vectorized<int##bit##_t>& a,                                    \ | ||||
|         const Vectorized<int##bit##_t>& b);                                   \ | ||||
|     static Vectorized<int##bit##_t> blendv(                                   \ | ||||
|         const Vectorized<int##bit##_t>& a,                                    \ | ||||
|         const Vectorized<int##bit##_t>& b,                                    \ | ||||
|         const Vectorized<int##bit##_t>& mask_) {                              \ | ||||
|       return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \ | ||||
|     }                                                                         \ | ||||
|     template <typename step_t>                                                \ | ||||
|     static Vectorized<int##bit##_t> arange(                                   \ | ||||
|         value_type base = 0,                                                  \ | ||||
|         step_t step = static_cast<step_t>(1));                                \ | ||||
|     static Vectorized<int##bit##_t> set(                                      \ | ||||
|         const Vectorized<int##bit##_t>& a,                                    \ | ||||
|         const Vectorized<int##bit##_t>& b,                                    \ | ||||
|         int64_t count = size());                                              \ | ||||
|     const int##bit##_t& operator[](int idx) const = delete;                   \ | ||||
|     int##bit##_t& operator[](int idx) = delete;                               \ | ||||
|     Vectorized<int##bit##_t> abs() const {                                    \ | ||||
|       return vabsq_s##bit(values);                                            \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> real() const {                                   \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> imag() const {                                   \ | ||||
|       return vdupq_n_s##bit(0);                                               \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> conj() const {                                   \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> neg() const {                                    \ | ||||
|       return vnegq_s##bit(values);                                            \ | ||||
|     }                                                                         \ | ||||
|     int##bit##_t reduce_add() const {                                         \ | ||||
|       return vaddvq_s##bit(values);                                           \ | ||||
|     }                                                                         \ | ||||
|     int##bit##_t reduce_max() const;                                          \ | ||||
|     Vectorized<int##bit##_t> operator==(                                      \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> operator!=(                                      \ | ||||
|         const Vectorized<int##bit##_t>& other) const;                         \ | ||||
|     Vectorized<int##bit##_t> operator<(                                       \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> operator<=(                                      \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> operator>(                                       \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> operator>=(                                      \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \ | ||||
|   };                                                                          \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator+(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return vaddq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator-(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return vsubq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator&(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return vandq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator|(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return vorrq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator^(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return veorq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this == other) & Vectorized<int##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this != other) & Vectorized<int##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this > other) & Vectorized<int##bit##_t>(1);                     \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this >= other) & Vectorized<int##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this < other) & Vectorized<int##bit##_t>(1);                     \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this <= other) & Vectorized<int##bit##_t>(1);                    \ | ||||
|   } | ||||
|  | ||||
| VEC_INT_NEON_TEMPLATE(2, 64) | ||||
| VEC_INT_NEON_TEMPLATE(4, 32) | ||||
| VEC_INT_NEON_TEMPLATE(8, 16) | ||||
| VEC_INT_NEON_TEMPLATE(16, 8) | ||||
|  | ||||
| inline int32_t Vectorized<int32_t>::reduce_max() const { | ||||
|   return vmaxvq_s32(values); | ||||
| } | ||||
|  | ||||
| inline int16_t Vectorized<int16_t>::reduce_max() const { | ||||
|   return vmaxvq_s16(values); | ||||
| } | ||||
|  | ||||
| inline int8_t Vectorized<int8_t>::reduce_max() const { | ||||
|   return vmaxvq_s8(values); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline operator*( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   return vmulq_s32(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline operator*( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   return vmulq_s16(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline operator*( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   return vmulq_s8(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) { | ||||
|   int64x2_t val = a; | ||||
|   return ~val; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) { | ||||
|   return vmvnq_s32(a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) { | ||||
|   return vmvnq_s16(a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) { | ||||
|   return vmvnq_s8(a); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int64_t> Vectorized<int64_t>::operator!=( | ||||
|     const Vectorized<int64_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int32_t> Vectorized<int32_t>::operator!=( | ||||
|     const Vectorized<int32_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int16_t> Vectorized<int16_t>::operator!=( | ||||
|     const Vectorized<int16_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int8_t> Vectorized<int8_t>::operator!=( | ||||
|     const Vectorized<int8_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline minimum( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   return vminq_s32(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline minimum( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   return vminq_s16(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline minimum( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   return vminq_s8(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline maximum( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   return vmaxq_s32(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline maximum( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   return vmaxq_s16(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline maximum( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   return vmaxq_s8(a, b); | ||||
| } | ||||
|  | ||||
| template <int64_t mask> | ||||
| Vectorized<int64_t> Vectorized<int64_t>::blend( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint64x2_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0, | ||||
|       (mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_s64(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| template <int64_t mask> | ||||
| Vectorized<int32_t> Vectorized<int32_t>::blend( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint32x4_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFFFFFFFF : 0, | ||||
|       (mask & 2LL) ? 0xFFFFFFFF : 0, | ||||
|       (mask & 4LL) ? 0xFFFFFFFF : 0, | ||||
|       (mask & 8LL) ? 0xFFFFFFFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_s32(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| template <int64_t mask> | ||||
| Vectorized<int16_t> Vectorized<int16_t>::blend( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint16x8_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFFFF : 0, | ||||
|       (mask & 2LL) ? 0xFFFF : 0, | ||||
|       (mask & 4LL) ? 0xFFFF : 0, | ||||
|       (mask & 8LL) ? 0xFFFF : 0, | ||||
|       (mask & 16LL) ? 0xFFFF : 0, | ||||
|       (mask & 32LL) ? 0xFFFF : 0, | ||||
|       (mask & 64LL) ? 0xFFFF : 0, | ||||
|       (mask & 128LL) ? 0xFFFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_s16(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| template <int64_t mask> | ||||
| Vectorized<int8_t> Vectorized<int8_t>::blend( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint8x16_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFF : 0, | ||||
|       (mask & 2LL) ? 0xFF : 0, | ||||
|       (mask & 4LL) ? 0xFF : 0, | ||||
|       (mask & 8LL) ? 0xFF : 0, | ||||
|       (mask & 16LL) ? 0xFF : 0, | ||||
|       (mask & 32LL) ? 0xFF : 0, | ||||
|       (mask & 64LL) ? 0xFF : 0, | ||||
|       (mask & 128LL) ? 0xFF : 0, | ||||
|       (mask & 256LL) ? 0xFF : 0, | ||||
|       (mask & 512LL) ? 0xFF : 0, | ||||
|       (mask & 1024LL) ? 0xFF : 0, | ||||
|       (mask & 2048LL) ? 0xFF : 0, | ||||
|       (mask & 4096LL) ? 0xFF : 0, | ||||
|       (mask & 8192LL) ? 0xFF : 0, | ||||
|       (mask & 16384LL) ? 0xFF : 0, | ||||
|       (mask & 32768LL) ? 0xFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_s8(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| #define VEC_INT_NEON_OPS(vl, bit)                                             \ | ||||
|   inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) {             \ | ||||
|     values = vdupq_n_s##bit(val);                                             \ | ||||
|   }                                                                           \ | ||||
|   inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu(            \ | ||||
|       const void* ptr, int64_t count) {                                       \ | ||||
|     if (count == size()) {                                                    \ | ||||
|       return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr));        \ | ||||
|     } else {                                                                  \ | ||||
|       __at_align__ int##bit##_t tmp_values[size()];                           \ | ||||
|       for (const auto i : c10::irange(size())) {                              \ | ||||
|         tmp_values[i] = 0;                                                    \ | ||||
|       }                                                                       \ | ||||
|       std::memcpy(                                                            \ | ||||
|           tmp_values,                                                         \ | ||||
|           reinterpret_cast<const int##bit##_t*>(ptr),                         \ | ||||
|           count * sizeof(int##bit##_t));                                      \ | ||||
|       return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \ | ||||
|     }                                                                         \ | ||||
|   }                                                                           \ | ||||
|   inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count)       \ | ||||
|       const {                                                                 \ | ||||
|     if (count == size()) {                                                    \ | ||||
|       vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values);             \ | ||||
|     } else {                                                                  \ | ||||
|       int##bit##_t tmp_values[size()];                                        \ | ||||
|       vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values);      \ | ||||
|       std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t));             \ | ||||
|     }                                                                         \ | ||||
|   } | ||||
|  | ||||
| VEC_INT_NEON_OPS(2, 64) | ||||
| VEC_INT_NEON_OPS(4, 32) | ||||
| VEC_INT_NEON_OPS(8, 16) | ||||
| VEC_INT_NEON_OPS(16, 8) | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline operator*( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   return x * y; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline operator/( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   return x / y; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline operator/( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   int32x4_t x = a; | ||||
|   int32x4_t y = b; | ||||
|   return x / y; | ||||
| } | ||||
|  | ||||
| inline int64_t Vectorized<int64_t>::reduce_max() const { | ||||
|   return std::max(values[0], values[1]); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline minimum( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   return {std::min(x[0], y[0]), std::min(x[1], y[1])}; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline maximum( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   return {std::max(x[0], y[0]), std::max(x[1], y[1])}; | ||||
| } | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<int64_t> Vectorized<int64_t>::arange( | ||||
|     int64_t base, | ||||
|     step_t step) { | ||||
|   const Vectorized<int64_t> base_vec(base); | ||||
|   const Vectorized<int64_t> step_vec(step); | ||||
|   const int64x2_t step_sizes = {0, 1}; | ||||
|   return base_vec.values + step_sizes * step_vec.values; | ||||
| } | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<int32_t> Vectorized<int32_t>::arange( | ||||
|     int32_t base, | ||||
|     step_t step) { | ||||
|   const Vectorized<int32_t> base_vec(base); | ||||
|   const Vectorized<int32_t> step_vec(step); | ||||
|   const int32x4_t step_sizes = {0, 1, 2, 3}; | ||||
|   return vmlaq_s32(base_vec, step_sizes, step_vec); | ||||
| } | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<int16_t> Vectorized<int16_t>::arange( | ||||
|     int16_t base, | ||||
|     step_t step) { | ||||
|   const Vectorized<int16_t> base_vec(base); | ||||
|   const Vectorized<int16_t> step_vec(step); | ||||
|   const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7}; | ||||
|   return vmlaq_s16(base_vec, step_sizes, step_vec); | ||||
| } | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) { | ||||
|   const Vectorized<int8_t> base_vec(base); | ||||
|   const Vectorized<int8_t> step_vec(step); | ||||
|   const int8x16_t step_sizes = { | ||||
|       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; | ||||
|   return vmlaq_s8(base_vec, step_sizes, step_vec); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline operator>>( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   uint64x2_t u = vreinterpretq_u64_s64(y); | ||||
|   uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)}; | ||||
|   return x >> vreinterpretq_s64_u64(z); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline operator>>( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   int32x4_t x = a; | ||||
|   int32x4_t y = b; | ||||
|   uint32x4_t bound = vdupq_n_u32(31); | ||||
|   uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound); | ||||
|   return x >> vreinterpretq_s32_u32(z); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline operator>>( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   int16x8_t x = a; | ||||
|   int16x8_t y = b; | ||||
|   uint16x8_t bound = vdupq_n_u16(15); | ||||
|   uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound); | ||||
|   return x >> vreinterpretq_s16_u16(z); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline operator>>( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   int8x16_t x = a; | ||||
|   int8x16_t y = b; | ||||
|   uint8x16_t bound = vdupq_n_u8(7); | ||||
|   int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound)); | ||||
|   return x >> z; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline operator<<( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t y = b; | ||||
|   uint64x2_t u = vreinterpretq_u64_s64(y); | ||||
|   uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)}; | ||||
|   return vshlq_s64(a, vreinterpretq_s64_u64(z)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline operator<<( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   int32x4_t y = b; | ||||
|   uint32x4_t bound = vdupq_n_u32(32); | ||||
|   uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound); | ||||
|   return vshlq_s32(a, vreinterpretq_s32_u32(z)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline operator<<( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   int16x8_t y = b; | ||||
|   uint16x8_t bound = vdupq_n_u16(16); | ||||
|   uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound); | ||||
|   return vshlq_s16(a, vreinterpretq_s16_u16(z)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline operator<<( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   int8x16_t y = b; | ||||
|   uint8x16_t bound = vdupq_n_u8(8); | ||||
|   int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound)); | ||||
|   return vshlq_s8(a, z); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int64_t> Vectorized<int64_t>::set( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b, | ||||
|     int64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 2) { | ||||
|     return b; | ||||
|   } else { | ||||
|     int64x2_t c = {b.values[0], a.values[1]}; | ||||
|     return c; | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline Vectorized<int32_t> Vectorized<int32_t>::set( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b, | ||||
|     int64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 4) { | ||||
|     return b; | ||||
|   } else { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint32x4_t maskArray = { | ||||
|         (count >= 1LL) ? 0xFFFFFFFF : 0, | ||||
|         (count >= 2LL) ? 0xFFFFFFFF : 0, | ||||
|         (count >= 3LL) ? 0xFFFFFFFF : 0, | ||||
|         0}; | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_s32(maskArray, b.values, a.values); | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline Vectorized<int16_t> Vectorized<int16_t>::set( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b, | ||||
|     int64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 8) { | ||||
|     return b; | ||||
|   } else { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint16x8_t maskArray = { | ||||
|         static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0), | ||||
|         0}; | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_s16(maskArray, b.values, a.values); | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline Vectorized<int8_t> Vectorized<int8_t>::set( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b, | ||||
|     int64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 16) { | ||||
|     return b; | ||||
|   } else { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint8x16_t maskArray = { | ||||
|         static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0), | ||||
|         0}; | ||||
|  | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_s8(maskArray, b.values, a.values); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline operator/( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   Vectorized<int32_t> highBitsA = vmovl_high_s16(a); | ||||
|   Vectorized<int32_t> highBitsB = vmovl_high_s16(b); | ||||
|   Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a)); | ||||
|   Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b)); | ||||
|   int32x4_t highBitsResult = highBitsA / highBitsB; | ||||
|   int32x4_t lowBitsResult = lowBitsA / lowBitsB; | ||||
|   return vuzp1q_s16( | ||||
|       vreinterpretq_s16_s32(lowBitsResult), | ||||
|       vreinterpretq_s16_s32(highBitsResult)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline operator/( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   Vectorized<int16_t> highBitsA = vmovl_high_s8(a); | ||||
|   Vectorized<int16_t> highBitsB = vmovl_high_s8(b); | ||||
|   Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a)); | ||||
|   Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b)); | ||||
|   int16x8_t highBitsResult = highBitsA / highBitsB; | ||||
|   int16x8_t lowBitsResult = lowBitsA / lowBitsB; | ||||
|   return vuzp1q_s8( | ||||
|       vreinterpretq_s8_s16(lowBitsResult), | ||||
|       vreinterpretq_s8_s16(highBitsResult)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline clamp( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& min, | ||||
|     const Vectorized<int64_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline clamp( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& min, | ||||
|     const Vectorized<int32_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline clamp( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& min, | ||||
|     const Vectorized<int16_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline clamp( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& min, | ||||
|     const Vectorized<int8_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline clamp_max( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline clamp_max( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline clamp_max( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline clamp_max( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline clamp_min( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline clamp_min( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline clamp_min( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline clamp_min( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| } // namespace CPU_CAPABILITY | ||||
| } // namespace at::vec | ||||
| @ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum( | ||||
| #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) | ||||
| std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||
|     at::vec::Vectorized<int8_t> src) { | ||||
|   auto s8x8 = vget_low_s8(src); | ||||
|   auto s8x8 = vld1_s8(src.operator const int8_t*()); | ||||
|   auto s16x8 = vmovl_s8(s8x8); | ||||
|  | ||||
|   auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); | ||||
| @ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||
|  | ||||
| Vectorized<float> inline convert_int8_half_register_to_float( | ||||
|     at::vec::Vectorized<int8_t> src) { | ||||
|   auto s8x8 = vget_low_s8(src); | ||||
|   auto s8x8 = vld1_s8(src.operator const int8_t*()); | ||||
|   auto s16x8 = vmovl_s8(s8x8); | ||||
|  | ||||
|   auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); | ||||
|  | ||||
| @ -16,8 +16,6 @@ | ||||
| #include <c10/util/irange.h> | ||||
| #include <c10/core/ScalarType.h> | ||||
|  | ||||
| #include <ATen/cuda/detail/BLASConstants.h> | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| #include <c10/cuda/CUDAStream.h> | ||||
| #include <hipblaslt/hipblaslt-ext.hpp> | ||||
| @ -1956,15 +1954,13 @@ void scaled_gemm( | ||||
|     const void *result_scale_ptr, | ||||
|     int64_t result_ld, | ||||
|     ScalarType result_dtype, | ||||
|     bool use_fast_accum, | ||||
|     const std::optional<Tensor>& alpha) { | ||||
|     bool use_fast_accum) { | ||||
|   // Note: see `cublasCommonArgs` for various non-intuitive manupulations | ||||
|   // of input arguments to this function. | ||||
|   const auto computeType = CUBLAS_COMPUTE_32F; | ||||
|   const auto scaleType = CUDA_R_32F; | ||||
|   // Note: alpha_val may change later depending on user-passed argument | ||||
|   float alpha_val = 1.0; | ||||
|   float beta_val = 0.0; | ||||
|   const float alpha_val = 1.0; | ||||
|   const float beta_val = 0.0; | ||||
|   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); | ||||
|   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); | ||||
|   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); | ||||
| @ -2035,33 +2031,6 @@ void scaled_gemm( | ||||
|     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS); | ||||
|     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); | ||||
|   } | ||||
|  | ||||
|   // Handle user-passed alpha | ||||
|   float *alpha_ptr = &alpha_val; | ||||
|   float *beta_ptr = &beta_val; | ||||
|  | ||||
|   if (alpha.has_value()) { | ||||
|     auto& a = alpha.value(); | ||||
|  | ||||
|     // if device-tensor | ||||
|     if (a.is_cuda()) { | ||||
|       // NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be | ||||
|       //       valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus | ||||
|       //       we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically | ||||
|       //       managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory | ||||
|       //       for beta respectively. | ||||
|       float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr(); | ||||
|       at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat)); | ||||
|       user_alpha.copy_(a); | ||||
|       // Tell cublasLt we're using device-side pointers for alpha/beta | ||||
|       auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; | ||||
|       computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode); | ||||
|       alpha_ptr = user_alpha.data_ptr<float>(); | ||||
|       beta_ptr = at::cuda::detail::get_cublas_device_zero(); | ||||
|     } else { | ||||
|       alpha_val = a.item<float>(); | ||||
|     } | ||||
|   } | ||||
|     // For other data types, use the get_scale_mode function based on scaling type | ||||
|     // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt, | ||||
|     // but we must invoke get_scale_mode anyways to trigger the version checks. | ||||
| @ -2079,7 +2048,6 @@ void scaled_gemm( | ||||
|   cublasLtMatmulHeuristicResult_t heuristicResult = {}; | ||||
|   int returnedResult = 0; | ||||
|   cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); | ||||
|  | ||||
|   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( | ||||
|       ltHandle, | ||||
|       computeDesc.descriptor(), | ||||
| @ -2120,10 +2088,10 @@ void scaled_gemm( | ||||
|         auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported( | ||||
|                 ltHandle, | ||||
|                 computeDesc.descriptor(), | ||||
|                 alpha_ptr, | ||||
|                 &alpha_val, | ||||
|                 Adesc.descriptor(), | ||||
|                 Bdesc.descriptor(), | ||||
|                 beta_ptr, | ||||
|                 &beta_val, | ||||
|                 Cdesc.descriptor(), | ||||
|                 Ddesc.descriptor(), | ||||
|                 all_algos[i].algo, | ||||
| @ -2142,14 +2110,17 @@ void scaled_gemm( | ||||
|   cublasStatus_t cublasStatus = cublasLtMatmul( | ||||
|       ltHandle, | ||||
|       computeDesc.descriptor(), | ||||
|       alpha_ptr, | ||||
|       &alpha_val, | ||||
|       mat1_ptr, | ||||
|       Adesc.descriptor(), | ||||
|       mat2_ptr, | ||||
|       Bdesc.descriptor(), | ||||
|       beta_ptr, | ||||
|       // NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either | ||||
|       &beta_val, | ||||
| #ifdef USE_ROCM | ||||
|       result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr | ||||
| #else | ||||
|       nullptr, | ||||
| #endif // ifdef USE_ROCM | ||||
|       Cdesc.descriptor(), | ||||
|       result_ptr, | ||||
|       Ddesc.descriptor(), | ||||
|  | ||||
| @ -161,8 +161,7 @@ void scaled_gemm( | ||||
|     const void* result_scale_ptr, | ||||
|     int64_t result_ld, | ||||
|     ScalarType result_dtype, | ||||
|     bool use_fast_accum, | ||||
|     const std::optional<Tensor>& alpha); | ||||
|     bool use_fast_accum); | ||||
|  | ||||
| #define CUDABLAS_BGEMM_ARGTYPES(Dtype)  CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) | ||||
|  | ||||
|  | ||||
| @ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() { | ||||
|  */ | ||||
| c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { | ||||
|   // The RNG state comprises the seed, and an offset used for Philox. | ||||
|   constexpr size_t seed_size = sizeof(uint64_t); | ||||
|   constexpr size_t offset_size = sizeof(int64_t); | ||||
|   constexpr size_t total_size = seed_size + offset_size; | ||||
|   static const size_t seed_size = sizeof(uint64_t); | ||||
|   static const size_t offset_size = sizeof(int64_t); | ||||
|   static const size_t total_size = seed_size + offset_size; | ||||
|  | ||||
|   auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); | ||||
|   auto rng_state = state_tensor.data_ptr<uint8_t>(); | ||||
| @ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { | ||||
|  * and size of the internal state. | ||||
|  */ | ||||
| void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { | ||||
|   constexpr size_t seed_size = sizeof(uint64_t); | ||||
|   constexpr size_t offset_size = sizeof(int64_t); | ||||
|   constexpr size_t total_size = seed_size + offset_size; | ||||
|   static const size_t seed_size = sizeof(uint64_t); | ||||
|   static const size_t offset_size = sizeof(int64_t); | ||||
|   static const size_t total_size = seed_size + offset_size; | ||||
|  | ||||
|   detail::check_rng_state(new_state); | ||||
|  | ||||
|  | ||||
| @ -177,6 +177,7 @@ inline void segmented_sort_pairs( | ||||
|   } | ||||
| } | ||||
|  | ||||
| #if CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT> | ||||
| inline void unique_by_key( | ||||
|   KeysInputIteratorT keys_in, ValuesInputIteratorT values_in, | ||||
| @ -192,6 +193,7 @@ inline void unique_by_key( | ||||
|   CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey, | ||||
|     keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream()); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| namespace impl { | ||||
|  | ||||
| @ -577,6 +579,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|  | ||||
| template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT> | ||||
| inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) { | ||||
| @ -604,6 +607,7 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #endif | ||||
|  | ||||
| template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT> | ||||
| void unique(InputIteratorT input, OutputIteratorT output, | ||||
|  | ||||
| @ -28,6 +28,22 @@ | ||||
| #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | ||||
| #endif | ||||
|  | ||||
| // cub support for UniqueByKey is added to cub 1.16 in: | ||||
| // https://github.com/NVIDIA/cub/pull/405 | ||||
| #if CUB_VERSION >= 101600 | ||||
| #define CUB_SUPPORTS_UNIQUE_BY_KEY() true | ||||
| #else | ||||
| #define CUB_SUPPORTS_UNIQUE_BY_KEY() false | ||||
| #endif | ||||
|  | ||||
| // cub support for scan by key is added to cub 1.15 | ||||
| // in https://github.com/NVIDIA/cub/pull/376 | ||||
| #if CUB_VERSION >= 101500 | ||||
| #define CUB_SUPPORTS_SCAN_BY_KEY() 1 | ||||
| #else | ||||
| #define CUB_SUPPORTS_SCAN_BY_KEY() 0 | ||||
| #endif | ||||
|  | ||||
| // cub support for cub::FutureValue is added to cub 1.15 in: | ||||
| // https://github.com/NVIDIA/cub/pull/305 | ||||
| #if CUB_VERSION >= 101500 | ||||
|  | ||||
| @ -1,54 +0,0 @@ | ||||
| #include <ATen/Functions.h> | ||||
| #include <ATen/Tensor.h> | ||||
| #include <ATen/cuda/Exceptions.h> | ||||
|  | ||||
| #include <mutex> | ||||
|  | ||||
| namespace at { | ||||
| namespace cuda { | ||||
| namespace detail { | ||||
|  | ||||
| __device__ __constant__ float cublas_one_device; | ||||
| __device__ __constant__ float cublas_zero_device; | ||||
|  | ||||
| float *get_cublas_device_one() { | ||||
|   static c10::once_flag init_flag; | ||||
|  | ||||
|   c10::call_once(init_flag, []() { | ||||
|     const float one = 1.f; | ||||
|     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float))); | ||||
|   }); | ||||
|  | ||||
|   float *ptr; | ||||
|   AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device)); | ||||
|   return ptr; | ||||
| } | ||||
|  | ||||
| float *get_cublas_device_zero() { | ||||
|   static c10::once_flag init_flag; | ||||
|  | ||||
|   c10::call_once(init_flag, []() { | ||||
|     const float zero = 0.f; | ||||
|     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float))); | ||||
|   }); | ||||
|  | ||||
|   float *ptr; | ||||
|   AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device)); | ||||
|   return ptr; | ||||
| } | ||||
|  | ||||
| float *get_user_alpha_ptr() { | ||||
|   static float *alpha_ptr; | ||||
|  | ||||
|   static c10::once_flag init_flag; | ||||
|  | ||||
|   c10::call_once(init_flag, []() { | ||||
|     AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float))); | ||||
|   }); | ||||
|  | ||||
|   return alpha_ptr; | ||||
| } | ||||
|  | ||||
| } // namespace detail | ||||
| } // namespace cuda | ||||
| } // namespace at | ||||
| @ -1,11 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/core/TensorBase.h> | ||||
|  | ||||
| namespace at::cuda::detail { | ||||
|  | ||||
| float *get_cublas_device_one(); | ||||
| float *get_cublas_device_zero(); | ||||
| float *get_user_alpha_ptr(); | ||||
|  | ||||
| } // namespace at::cuda::detail | ||||
| @ -109,8 +109,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> { | ||||
|           params->c_scale_ptr, | ||||
|           params->ldc, | ||||
|           params->c_dtype, | ||||
|           params->use_fast_accum, | ||||
|           std::nullopt /* alpha */); | ||||
|           params->use_fast_accum); | ||||
|       return OK; | ||||
|     } | ||||
| }; | ||||
|  | ||||
| @ -160,10 +160,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({ | ||||
|   DispatchKey::CUDA, | ||||
|   DispatchKey::CPU, | ||||
|   DispatchKey::PrivateUse1, | ||||
|   DispatchKey::SparseCPU, | ||||
|   DispatchKey::SparseCUDA, | ||||
|   DispatchKey::SparseCsrCPU, | ||||
|   DispatchKey::SparseCsrCUDA, | ||||
| }); | ||||
|  | ||||
| inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { | ||||
|  | ||||
| @ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) ( | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717; | ||||
| static constexpr double SELU_SCALE = 1.0507009873554804934193349852946; | ||||
| static const double SELU_ALPHA = 1.6732632423543772848170429916717; | ||||
| static const double SELU_SCALE = 1.0507009873554804934193349852946; | ||||
|  | ||||
| DEFINE_DISPATCH(elu_stub); | ||||
| DEFINE_DISPATCH(elu_backward_stub); | ||||
|  | ||||
| @ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in | ||||
| #if AT_BUILD_WITH_BLAS() | ||||
| template <> | ||||
| bool scal_use_fast_path<double>(int64_t n, int64_t incx) { | ||||
|   auto constexpr intmax = std::numeric_limits<int>::max(); | ||||
|   auto intmax = std::numeric_limits<int>::max(); | ||||
|   return n <= intmax && incx <= intmax; | ||||
| } | ||||
|  | ||||
| @ -315,7 +315,7 @@ bool gemv_use_fast_path<float>( | ||||
|     int64_t incx, | ||||
|     [[maybe_unused]] float beta, | ||||
|     int64_t incy) { | ||||
|   auto constexpr intmax = std::numeric_limits<int>::max(); | ||||
|   auto intmax = std::numeric_limits<int>::max(); | ||||
|   return (m <= intmax) && (n <= intmax) && (lda <= intmax) && | ||||
|          (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax); | ||||
| } | ||||
|  | ||||
| @ -658,7 +658,6 @@ static void check_shape_forward(const at::Tensor& input, | ||||
|   TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported"); | ||||
|   TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported"); | ||||
|   TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero"); | ||||
|   TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups); | ||||
|  | ||||
|   TORCH_CHECK(weight_dim == k, | ||||
|            "Expected ", weight_dim, "-dimensional input for ", weight_dim, | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <array> | ||||
| #include <ATen/native/Math.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/MathConstants.h> | ||||
| @ -128,7 +127,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor | ||||
|  | ||||
| template<typename scalar_t> | ||||
| C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) { | ||||
|   constexpr static scalar_t kTailValues[] = { | ||||
|   const static scalar_t kTailValues[] = { | ||||
|     0.0810614667953272, | ||||
|     0.0413406959554092, | ||||
|     0.0276779256849983, | ||||
| @ -140,7 +139,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) { | ||||
|     0.00925546218271273, | ||||
|     0.00833056343336287 | ||||
|   }; | ||||
|   if (k < std::size(kTailValues)) { | ||||
|   if (k <= 9) { | ||||
|     return kTailValues[static_cast<size_t>(k)]; | ||||
|   } | ||||
|   scalar_t kp1sq = (k + 1) * (k + 1); | ||||
|  | ||||
| @ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, | ||||
| template <typename scalar_t> | ||||
| static scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||
|   // lanczos approximation | ||||
|   static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = { | ||||
|   static const scalar_t lanczos_sum_expg_scaled_num[13] = { | ||||
|     0.006061842346248906525783753964555936883222, | ||||
|     0.5098416655656676188125178644804694509993, | ||||
|     19.51992788247617482847860966235652136208, | ||||
| @ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||
|     103794043.1163445451906271053616070238554, | ||||
|     56906521.91347156388090791033559122686859 | ||||
|   }; | ||||
|   static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||
|   static const scalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||
|     1., | ||||
|     66., | ||||
|     1925., | ||||
| @ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { | ||||
| template <typename scalar_t> | ||||
| static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { | ||||
|   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] | ||||
|   static constexpr scalar_t d[25][25] = | ||||
|   static const scalar_t d[25][25] = | ||||
|     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, | ||||
|       1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, | ||||
|       3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, | ||||
|  | ||||
| @ -62,7 +62,7 @@ | ||||
| #include <utility> | ||||
| #include <vector> | ||||
|  | ||||
| static constexpr int MIOPEN_DIM_MAX = 5; | ||||
| static const int MIOPEN_DIM_MAX = 5; | ||||
|  | ||||
| namespace at::meta { | ||||
|  | ||||
|  | ||||
| @ -1906,9 +1906,11 @@ Tensor& index_fill_( | ||||
|         "This also applies to advanced indexing e.g. tensor[mask] = scalar"); | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       self.is_complex() || !source.isComplex(), | ||||
|       "index_fill_(): Converting complex Scalar to non-complex type is not supported"); | ||||
|   if (!self.is_complex() && source.isComplex()) { | ||||
|     TORCH_CHECK( | ||||
|         false, | ||||
|         "index_fill_(): Converting complex Scalar to non-complex type is not supported"); | ||||
|   } | ||||
|  | ||||
|   // Handle the case when `self` is 0-dim | ||||
|   Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self; | ||||
|  | ||||
| @ -77,7 +77,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { | ||||
|   // next broadcast all index tensors together | ||||
|   try { | ||||
|     indices = expand_outplace(indices); | ||||
|   } catch (std::exception&) { | ||||
|   } catch (std::exception& e) { | ||||
|     TORCH_CHECK_INDEX( | ||||
|         false, | ||||
|         "shape mismatch: indexing tensors could not be broadcast together" | ||||
|  | ||||
| @ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase { | ||||
|   // We keep this structure for BC and consider as deprecated. | ||||
|   // See HelperInterpNearestExact as replacement | ||||
|  | ||||
|   static constexpr int interp_size = 1; | ||||
|   static const int interp_size = 1; | ||||
|  | ||||
|   static inline void init_indices_weights( | ||||
|     at::ScalarType output_type, | ||||
| @ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest { | ||||
|  | ||||
| struct HelperInterpLinear : public HelperInterpBase { | ||||
|  | ||||
|   static constexpr int interp_size = 2; | ||||
|   static const int interp_size = 2; | ||||
|  | ||||
|   // Compute indices and weights for each interpolated dimension | ||||
|   // indices_weights = { | ||||
| @ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase { | ||||
|  | ||||
| struct HelperInterpCubic : public HelperInterpBase { | ||||
|  | ||||
|   static constexpr int interp_size = 4; | ||||
|   static const int interp_size = 4; | ||||
|  | ||||
|   // Compute indices and weights for each interpolated dimension | ||||
|   // indices_weights = { | ||||
|  | ||||
| @ -1359,8 +1359,7 @@ _scaled_gemm( | ||||
|           const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const bool use_fast_accum, | ||||
|           Tensor& out, | ||||
|           const std::optional<Tensor>& alpha = std::nullopt) { | ||||
|           Tensor& out) { | ||||
|   cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b); | ||||
|   const auto out_dtype_ = args.result->scalar_type(); | ||||
|   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); | ||||
| @ -1411,8 +1410,7 @@ _scaled_gemm( | ||||
|           args.scale_result_ptr, | ||||
|           args.result_ld, | ||||
|           out_dtype_, | ||||
|           use_fast_accum, | ||||
|           alpha); | ||||
|           use_fast_accum); | ||||
|       return out; | ||||
|   } | ||||
| } | ||||
| @ -1761,7 +1759,6 @@ enum class ScaledGemmImplementation { | ||||
|   MXFP8_MXFP8 = 6, | ||||
|   NVFP4_NVFP4 = 7, | ||||
|   NVFP4_NVFP4_SINGLE_SCALE = 8, | ||||
|   MXFP4_MXFP4 = 9, | ||||
| }; | ||||
|  | ||||
| /** | ||||
| @ -1958,39 +1955,10 @@ bool check_mxfp8_recipe(c10::ScalarType type_a, | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| /** | ||||
|  * Both inputs must be fp4 | ||||
|  * A, B must have 1 scale each, {Blockwise_1x32, e8m0} | ||||
|  */ | ||||
| bool check_mxfp4_recipe(c10::ScalarType type_a, | ||||
|                         std::vector<ScalingType>& recipe_a, | ||||
|                         ArrayRef<Tensor>& scales_a, | ||||
|                         c10::ScalarType type_b, | ||||
|                         std::vector<ScalingType>& recipe_b, | ||||
|                         ArrayRef<Tensor>& scales_b) { | ||||
|   // both types must be fp4 | ||||
|   if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   // 1 scales, 1 recipes for each input | ||||
|   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   // Need {Blockwise_1x32, e8m0} for A & B | ||||
|   if (recipe_a[0] != ScalingType::BlockWise1x32) return false; | ||||
|   if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||
|   if (recipe_b[0] != ScalingType::BlockWise1x32) return false; | ||||
|   if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||
|  | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>; | ||||
| using namespace std::placeholders; | ||||
|  | ||||
| std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> scale_kernel_dispatch = {{ | ||||
| std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> scale_kernel_dispatch = {{ | ||||
|   { "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE }, | ||||
|   { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE}, | ||||
|   { "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6), | ||||
| @ -2001,8 +1969,7 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> | ||||
|     ScaledGemmImplementation::BLOCK_1x128_1x128}, | ||||
|   { "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}, | ||||
|   { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE }, | ||||
|   { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}, | ||||
|   { "mxfp4_mxfp4", check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}}; | ||||
|   { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}}; | ||||
|  | ||||
| Tensor& | ||||
| _scaled_tensorwise_tensorwise( | ||||
| @ -2220,22 +2187,15 @@ _scaled_mxfp8_mxfp8( | ||||
|   TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ", | ||||
|       mat_a.scalar_type(), mat_b.scalar_type()); | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|   auto scale_a_elems = ceil_div<int64_t>(mat_a.size(0), 32) * mat_a.size(1); | ||||
|   auto scale_b_elems = ceil_div<int64_t>(mat_b.size(1), 32) * mat_b.size(0); | ||||
| #else | ||||
|   auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1), 32), 4); | ||||
|   auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0), 32), 4); | ||||
| #endif | ||||
|   TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(), | ||||
|          "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel()); | ||||
|   TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(), | ||||
|          "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel()); | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|   TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format"); | ||||
|   TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format"); | ||||
| #endif | ||||
|  | ||||
|   TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), | ||||
|         "For Blockwise scaling both scales should be contiguous"); | ||||
| @ -2265,56 +2225,6 @@ _scaled_mxfp8_mxfp8( | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
| } | ||||
|  | ||||
|  | ||||
| Tensor& | ||||
| _scaled_mxfp4_mxfp4( | ||||
|           const Tensor& mat_a, const Tensor& mat_b, | ||||
|           const Tensor& scale_a, const SwizzleType swizzle_a, | ||||
|           const Tensor& scale_b, const SwizzleType swizzle_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const c10::ScalarType out_dtype, | ||||
|           Tensor& out) { | ||||
| #ifndef USE_ROCM | ||||
|   TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only"); | ||||
| #endif | ||||
|   // Restrictions: | ||||
|   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 | ||||
|   TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ", | ||||
|       mat_a.scalar_type(), mat_b.scalar_type()); | ||||
|  | ||||
|   auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1); | ||||
|   auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0); | ||||
|   TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(), | ||||
|          "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel()); | ||||
|   TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(), | ||||
|          "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel()); | ||||
|  | ||||
|   TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), | ||||
|         "For Blockwise scaling both scales should be contiguous"); | ||||
|  | ||||
|   TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype); | ||||
|  | ||||
|   auto scaling_choice_a = ScalingType::BlockWise1x32; | ||||
|   auto scaling_choice_b = ScalingType::BlockWise1x32; | ||||
|  | ||||
| #if ROCM_VERSION >= 70000 | ||||
|   TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), | ||||
|               "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); | ||||
|  | ||||
|   TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 && | ||||
|               mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0, | ||||
|               "Matrix dimensions must be multiples of 32 for block-wise scaling"); | ||||
|  | ||||
|   TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 || | ||||
|               out.scalar_type() == ScalarType::Half, | ||||
|               "Block-wise scaling only supports BFloat16 or Half output types"); | ||||
| #else | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); | ||||
| #endif | ||||
|  | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
| } | ||||
|  | ||||
| Tensor& | ||||
| _scaled_nvfp4_nvfp4( | ||||
|           const Tensor& mat_a, const Tensor& mat_b, | ||||
| @ -2322,23 +2232,12 @@ _scaled_nvfp4_nvfp4( | ||||
|           const Tensor& scale_b, const SwizzleType swizzle_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const c10::ScalarType out_dtype, | ||||
|           Tensor& out, | ||||
|           const std::optional<Tensor>& global_scale_a = std::nullopt, | ||||
|           const std::optional<Tensor>& global_scale_b = std::nullopt) { | ||||
|           const bool single_scale, | ||||
|           Tensor& out) { | ||||
| #ifdef USE_ROCM | ||||
|   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM"); | ||||
| #endif | ||||
|   std::optional<Tensor> alpha = std::nullopt; | ||||
|   // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise, | ||||
|   //       if this is "And" we would silently do nothing in the case where one global scale is | ||||
|   //       passed and not the other. | ||||
|   if (global_scale_a.has_value() || global_scale_b.has_value()) { | ||||
|     TORCH_CHECK_VALUE(global_scale_a.has_value(), | ||||
|         "For two-level-scaled NVFP4, global_scale_a must have a value"); | ||||
|     TORCH_CHECK_VALUE(global_scale_b.has_value(), | ||||
|         "For two-level-scaled NVFP4, global_scale_b must have a value"); | ||||
|     alpha = global_scale_a.value().mul(global_scale_b.value()); | ||||
|   } | ||||
|   TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported"); | ||||
|   // Restrictions: | ||||
|   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 | ||||
|   // Scales must be swizzled | ||||
| @ -2360,7 +2259,7 @@ _scaled_nvfp4_nvfp4( | ||||
|  | ||||
|   auto scaling_choice_a = ScalingType::BlockWise1x16; | ||||
|   auto scaling_choice_b = ScalingType::BlockWise1x16; | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha); | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -2566,12 +2465,9 @@ _scaled_mm_cuda_v2_out( | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) { | ||||
|     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) { | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out, | ||||
|                                scale_a[1], scale_b[1]); | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported"); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) { | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) { | ||||
|     return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out); | ||||
|   } else { | ||||
|     TORCH_CHECK_VALUE(false, "Invalid state - found an implementation, but not really"); | ||||
|   } | ||||
|  | ||||
| @ -38,41 +38,12 @@ __device__ inline int min(int a, int b) { | ||||
| #define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched | ||||
| #endif | ||||
|  | ||||
| template <typename index_t> | ||||
| static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) { | ||||
|   const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1); | ||||
|   return (size + pad < kernel_extent) ? index_t(0) : (size + pad - kernel_extent) / stride + 1; | ||||
| static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) { | ||||
|   return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1; | ||||
| } | ||||
|  | ||||
| template <typename index_t> | ||||
| static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) { | ||||
|   return std::min((size + pad) / stride + 1, pooled_size); | ||||
| } | ||||
|  | ||||
| static inline bool can_use_int32_nhwc( | ||||
|     int64_t nbatch, int64_t channels, | ||||
|     int64_t height, int64_t width, | ||||
|     int64_t pooled_height, int64_t pooled_width, | ||||
|     int64_t in_stride_n, int64_t in_stride_c, | ||||
|     int64_t in_stride_h, int64_t in_stride_w) | ||||
| { | ||||
|   constexpr int64_t int_max = std::numeric_limits<int>::max(); | ||||
|  | ||||
|   int64_t max_intra_batch = | ||||
|       (height ? (height - 1) * in_stride_h : 0) + | ||||
|       (width ? (width - 1) * in_stride_w : 0) + | ||||
|       (channels? (channels - 1) * in_stride_c : 0); | ||||
|  | ||||
|   int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch; | ||||
|  | ||||
|   if (max_input_offset > int_max) return false; | ||||
|  | ||||
|   int64_t out_batch_stride = pooled_height * pooled_width * channels; | ||||
|   if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false; | ||||
|  | ||||
|   if (height * width > int_max) return false; | ||||
|  | ||||
|   return true; | ||||
| static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) { | ||||
|   return min((size + pad) / stride + 1, pooled_size); | ||||
| } | ||||
|  | ||||
| // kernels borrowed from Caffe | ||||
| @ -114,25 +85,21 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename index_t> | ||||
| template <typename scalar_t> | ||||
| C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS) | ||||
| __global__ void max_pool_forward_nhwc( | ||||
|     const scalar_t* bottom_data, | ||||
|     const int nbatch, | ||||
|     const index_t channels, const index_t height, const index_t width, | ||||
|     const index_t pooled_height, const index_t pooled_width, | ||||
|     const int kernel_h, const int kernel_w, const int stride_h, | ||||
|     const int stride_w, const int pad_h, const int pad_w, | ||||
|     const int dilation_h, const int dilation_w, | ||||
|     const index_t in_stride_n, const index_t in_stride_c, | ||||
|     const index_t in_stride_h, const index_t in_stride_w, | ||||
|     const int kernel_stride_C, const int kernel_size_C, | ||||
|     scalar_t* top_data, int64_t* top_mask) { | ||||
|  | ||||
|   extern __shared__ unsigned char smem_raw[]; | ||||
|   index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw); | ||||
|   scalar_t *out_cached = reinterpret_cast<scalar_t*>( | ||||
|       out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z); | ||||
| __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch, | ||||
|                                    const int64_t channels, const int64_t height, | ||||
|                                    const int64_t width, const int pooled_height, const int pooled_width, | ||||
|                                    const int kernel_h, const int kernel_w, const int stride_h, | ||||
|                                    const int stride_w, const int pad_h, const int pad_w, | ||||
|                                    const int dilation_h, const int dilation_w, | ||||
|                                    const int in_stride_n, const int in_stride_c, | ||||
|                                    const int in_stride_h, const int in_stride_w, | ||||
|                                    const int kernel_stride_C, const int kernel_size_C, | ||||
|                                    scalar_t* top_data, int64_t* top_mask) { | ||||
|   extern __shared__ int smem[]; | ||||
|   int *out_mask_cached = smem; | ||||
|   scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]); | ||||
|  | ||||
|   // flattening cta for pre-computation & smem initialization; | ||||
|   int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); | ||||
| @ -151,26 +118,26 @@ __global__ void max_pool_forward_nhwc( | ||||
|   int channel_id = blockIdx.x / nbatch; | ||||
|   int channel_offset = threadIdx.x + channel_id * blockDim.x; | ||||
|  | ||||
|   top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels); | ||||
|   top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels); | ||||
|   bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n; | ||||
|   top_data = top_data + batch_id * pooled_height * pooled_width * channels; | ||||
|   top_mask = top_mask + batch_id * pooled_height * pooled_width * channels; | ||||
|   bottom_data = bottom_data + batch_id * in_stride_n; | ||||
|  | ||||
|   out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x; | ||||
|   out_mask_cached  += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x; | ||||
|   out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x]; | ||||
|   out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x]; | ||||
|  | ||||
|   int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z; | ||||
|   int oW = (static_cast<int>(pooled_width)  + gridDim.y - 1) / gridDim.y; | ||||
|   int oH = (pooled_height + gridDim.z-1) / gridDim.z; | ||||
|   int oW = (pooled_width + gridDim.y-1) / gridDim.y; | ||||
|   int ostartH = threadIdx.z + blockIdx.z*oH; | ||||
|   int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height)); | ||||
|   int oendH = ::min(ostartH+oH, pooled_height); | ||||
|   int ostartW = threadIdx.y + blockIdx.y*oW; | ||||
|   int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width)); | ||||
|   int oendW = ::min(ostartW+oW, pooled_width); | ||||
|  | ||||
|   for (int oh = ostartH; oh < oendH; oh+=blockDim.z) { | ||||
|     index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h; | ||||
|     index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height); | ||||
|     int hstart = oh * stride_h - pad_h; | ||||
|     int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height); | ||||
|     for (int ow = ostartW; ow < oendW; ow+=blockDim.y) { | ||||
|       index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w; | ||||
|       index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width); | ||||
|       int wstart = ow * stride_w - pad_w; | ||||
|       int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width); | ||||
|       while(hstart < 0) | ||||
|         hstart += dilation_h; | ||||
|       while(wstart < 0) | ||||
| @ -218,12 +185,12 @@ __global__ void max_pool_forward_nhwc( | ||||
|       // Else do it Non-Prefetch... | ||||
|       else | ||||
| #endif | ||||
|       for (index_t ih = hstart; ih < hend; ih += dilation_h) { | ||||
|         for (index_t iw = wstart; iw < wend; iw += dilation_w) { | ||||
|       for (int ih = hstart; ih < hend; ih += dilation_h) { | ||||
|         for (int iw = wstart; iw < wend; iw += dilation_w) { | ||||
|           int cached_index = threadIdx.x; | ||||
|           const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w; | ||||
|           for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) { | ||||
|             scalar_t val = ptr_input[c * in_stride_c]; | ||||
|           for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { | ||||
|             scalar_t val = ptr_input[c*in_stride_c]; | ||||
|             if ((val > out_cached[cached_index]) || at::_isnan(val)) { | ||||
|               out_cached[cached_index] = val; | ||||
|               out_mask_cached[cached_index] = ih * width + iw; | ||||
| @ -233,15 +200,15 @@ __global__ void max_pool_forward_nhwc( | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels; | ||||
|       int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels; | ||||
|       scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels; | ||||
|       int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels; | ||||
|  | ||||
|       int cached_index = threadIdx.x; | ||||
|       for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) { | ||||
|       for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { | ||||
|         ptr_output_data[c] = out_cached[cached_index]; | ||||
|         ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]); | ||||
|         ptr_output_mask[c] = out_mask_cached[cached_index]; | ||||
|         out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound(); | ||||
|         out_mask_cached[cached_index] = index_t(0); | ||||
|         out_mask_cached[cached_index] = 0; | ||||
|         cached_index += blockDim.x; | ||||
|       } | ||||
|     } | ||||
| @ -249,7 +216,7 @@ __global__ void max_pool_forward_nhwc( | ||||
| } | ||||
|  | ||||
|  | ||||
| static constexpr int BLOCK_THREADS = 256; | ||||
| static const int BLOCK_THREADS = 256; | ||||
|  | ||||
| template <typename scalar_t, typename accscalar_t> | ||||
| #if defined (USE_ROCM) | ||||
| @ -495,11 +462,6 @@ const Tensor& indices) { | ||||
|               maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z)); | ||||
|           const dim3 block(block_x, block_y, block_z); | ||||
|  | ||||
|           bool use_int32 = can_use_int32_nhwc( | ||||
|               nbatch, nInputPlane, inputHeight, inputWidth, | ||||
|               outputHeight, outputWidth, | ||||
|               in_stride_n, in_stride_c, in_stride_h, in_stride_w); | ||||
|  | ||||
|           int kernel_stride_C = ceil_div( | ||||
|               safe_downcast<int, int64_t>(nInputPlane), block_x * 4); | ||||
|           int kernel_size_C = ceil_div( | ||||
| @ -514,41 +476,18 @@ const Tensor& indices) { | ||||
|               ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD)); | ||||
|           const dim3 grid(grid_x, grid_y, grid_z); | ||||
|  | ||||
|           size_t shmem_size; | ||||
|           size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z; | ||||
|           size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t)); | ||||
|           AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); | ||||
|  | ||||
|           if (use_int32) { | ||||
|             shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t)); | ||||
|             TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock, | ||||
|                         "shared memory too small"); | ||||
|             max_pool_forward_nhwc<scalar_t, int32_t> | ||||
|               <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>( | ||||
|                 input_data, static_cast<int>(nbatch), | ||||
|                 static_cast<int32_t>(nInputPlane), | ||||
|                 static_cast<int32_t>(inputHeight), | ||||
|                 static_cast<int32_t>(inputWidth), | ||||
|                 static_cast<int32_t>(outputHeight), | ||||
|                 static_cast<int32_t>(outputWidth), | ||||
|                 kH, kW, dH, dW, padH, padW, dilationH, dilationW, | ||||
|                 static_cast<int32_t>(in_stride_n), | ||||
|                 static_cast<int32_t>(in_stride_c), | ||||
|                 static_cast<int32_t>(in_stride_h), | ||||
|                 static_cast<int32_t>(in_stride_w), | ||||
|                 kernel_stride_C, kernel_size_C, | ||||
|                 output_data, indices_data); | ||||
|           } else { | ||||
|             shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t)); | ||||
|             TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock, | ||||
|                         "shared memory too small"); | ||||
|             max_pool_forward_nhwc<scalar_t, int64_t> | ||||
|               <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>( | ||||
|                 input_data, static_cast<int>(nbatch), | ||||
|                 nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, | ||||
|                 kH, kW, dH, dW, padH, padW, dilationH, dilationW, | ||||
|                 in_stride_n, in_stride_c, in_stride_h, in_stride_w, | ||||
|                 kernel_stride_C, kernel_size_C, | ||||
|                 output_data, indices_data); | ||||
|           } | ||||
|           max_pool_forward_nhwc<scalar_t> | ||||
|           <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>( | ||||
|               input_data, nbatch, | ||||
|                   nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, | ||||
|                   kH, kW, dH, dW, padH, padW, dilationH, dilationW, | ||||
|                   in_stride_n, in_stride_c, | ||||
|                   in_stride_h, in_stride_w, | ||||
|                   kernel_stride_C, kernel_size_C, | ||||
|                   output_data, indices_data); | ||||
|           C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|           break; | ||||
|         } | ||||
|  | ||||
| @ -15,7 +15,9 @@ | ||||
| #include <ATen/native/cuda/block_reduce.cuh> | ||||
| #include <ATen/native/cuda/thread_constants.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| #include <thrust/iterator/reverse_iterator.h> | ||||
| #endif | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -34,9 +36,9 @@ namespace at::native { | ||||
| namespace { | ||||
|  | ||||
| #if defined(USE_ROCM) | ||||
| static constexpr int BLOCKDIMY = 16; | ||||
| static const int BLOCKDIMY = 16; | ||||
| #else | ||||
| static constexpr int BLOCKDIMY = 32; | ||||
| static const int BLOCKDIMY = 32; | ||||
| #endif | ||||
|  | ||||
| template | ||||
| @ -238,6 +240,10 @@ __global__ void renorm_kernel( | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); | ||||
| #endif | ||||
|  | ||||
| Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_, | ||||
|                                int64_t num_weights, int64_t padding_idx, | ||||
| @ -300,6 +306,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice | ||||
|  | ||||
|   if (scale_grad_by_freq) { | ||||
|     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { | ||||
|       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
| @ -326,6 +333,11 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice | ||||
|         num_indices | ||||
|       ); | ||||
|     }); | ||||
| #else | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { | ||||
|       embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); | ||||
|     }); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   return embedding_backward_cuda_kernel(grad, orig_indices, | ||||
|  | ||||
| @ -10,7 +10,9 @@ | ||||
|  | ||||
| #include <c10/macros/Macros.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| #include <thrust/iterator/counting_iterator.h> | ||||
| #endif | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -194,9 +196,18 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm | ||||
|             partials_per_segment_offset[num_of_segments-1]; | ||||
| } | ||||
|  | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| __global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) { | ||||
|   *num_of_segments_ptr = num_of_segments; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| } // anon namespace | ||||
|  | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| template<typename index_t> | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
| #endif | ||||
|  | ||||
| Tensor embedding_backward_cuda_kernel( | ||||
|         const Tensor &grad, | ||||
| @ -223,12 +234,20 @@ Tensor embedding_backward_cuda_kernel( | ||||
|   auto segment_offsets = at::empty({numel}, orig_indices.options()); | ||||
|   auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong)); | ||||
|   int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>(); | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
|   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { | ||||
|     int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets); | ||||
|     write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   }); | ||||
| #else | ||||
|   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { | ||||
|     cuda::cub::unique_by_key( | ||||
|       sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0), | ||||
|       segment_offsets.mutable_data_ptr<index_t>(), | ||||
|       num_of_segments_ptr, sorted_indices.numel()); | ||||
|   }); | ||||
| #endif | ||||
|  | ||||
|   int64_t max_segments = std::min<int64_t>(numel, num_weights); | ||||
|  | ||||
|  | ||||
| @ -31,10 +31,16 @@ | ||||
|  | ||||
| #include <c10/macros/Macros.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| #include <thrust/iterator/reverse_iterator.h> | ||||
| #endif | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); | ||||
| #endif | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| @ -193,6 +199,7 @@ Tensor embedding_bag_backward_cuda_sum_avg( | ||||
|  | ||||
|   if (scale_grad_by_freq) { | ||||
|     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { | ||||
|       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
| @ -219,6 +226,11 @@ Tensor embedding_bag_backward_cuda_sum_avg( | ||||
|         num_indices | ||||
|       ); | ||||
|     }); | ||||
| #else | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { | ||||
|       embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); | ||||
|     }); | ||||
| #endif | ||||
|   } | ||||
|   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, | ||||
|       count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag, | ||||
|  | ||||
| @ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||
|   // lanczos approximation | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|  | ||||
|   constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = { | ||||
|   static const accscalar_t lanczos_sum_expg_scaled_num[13] = { | ||||
|     0.006061842346248906525783753964555936883222, | ||||
|     0.5098416655656676188125178644804694509993, | ||||
|     19.51992788247617482847860966235652136208, | ||||
| @ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||
|     103794043.1163445451906271053616070238554, | ||||
|     56906521.91347156388090791033559122686859 | ||||
|   }; | ||||
|   constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||
|   static const accscalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||
|     1., | ||||
|     66., | ||||
|     1925., | ||||
| @ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { | ||||
|  | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   accscalar_t ax, fac, res, num, numfac; | ||||
|   constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ? | ||||
|   static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ? | ||||
|     7.09782712893383996843E2 : 88.72283905206835; | ||||
|   constexpr accscalar_t EXP1 = 2.718281828459045; | ||||
|   constexpr accscalar_t lanczos_g = 6.024680040776729583740234375; | ||||
|   static const accscalar_t EXP1 = 2.718281828459045; | ||||
|   static const accscalar_t lanczos_g = 6.024680040776729583740234375; | ||||
|  | ||||
|   if (::fabs(a - x) > 0.4 * ::fabs(a)) { | ||||
|     ax = a * ::log(x) - x - ::lgamma(a); | ||||
| @ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) { | ||||
|   // Compute igam using DLMF 8.11.4. [igam1] | ||||
|  | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||
|   constexpr int MAXITER = 2000; | ||||
|   static const int MAXITER = 2000; | ||||
|  | ||||
|   int i; | ||||
|   accscalar_t ans, ax, c, r; | ||||
| @ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { | ||||
|   accscalar_t fac = 1; | ||||
|   accscalar_t sum = 0; | ||||
|   accscalar_t term, logx; | ||||
|   constexpr int MAXITER = 2000; | ||||
|   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|   static const int MAXITER = 2000; | ||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||
|  | ||||
|   for (n = 1; n < MAXITER; n++) { | ||||
| @ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t | ||||
|   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] | ||||
|  | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   constexpr accscalar_t d[25][25] = | ||||
|   static const accscalar_t d[25][25] = | ||||
|     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15}, | ||||
|     {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15}, | ||||
|     {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15}, | ||||
| @ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t | ||||
|  | ||||
|   int k, n, sgn; | ||||
|   int maxpow = 0; | ||||
|   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||
|   accscalar_t lambda = x / a; | ||||
|   accscalar_t sigma = (x - a) / a; | ||||
| @ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar | ||||
|   int i; | ||||
|   accscalar_t ans, ax, c, yc, r, t, y, z; | ||||
|   accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; | ||||
|   constexpr int MAXITER = 2000; | ||||
|   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|   static const int MAXITER = 2000; | ||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||
|   constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ? | ||||
|   static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ? | ||||
|     4.503599627370496e15 : 16777216.; | ||||
|   constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ? | ||||
|   static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ? | ||||
|     2.22044604925031308085e-16 : 5.9604644775390625E-8; | ||||
|  | ||||
|   ax = _igam_helper_fac(a, x); | ||||
| @ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   accscalar_t absxma_a; | ||||
|  | ||||
|   constexpr accscalar_t SMALL = 20.0; | ||||
|   constexpr accscalar_t LARGE = 200.0; | ||||
|   constexpr accscalar_t SMALLRATIO = 0.3; | ||||
|   constexpr accscalar_t LARGERATIO = 4.5; | ||||
|   static const accscalar_t SMALL = 20.0; | ||||
|   static const accscalar_t LARGE = 200.0; | ||||
|   static const accscalar_t SMALLRATIO = 0.3; | ||||
|   static const accscalar_t LARGERATIO = 4.5; | ||||
|  | ||||
|   if ((x < 0) || (a < 0)) { | ||||
|     // out of defined-region of the function | ||||
| @ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) { | ||||
|  | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   accscalar_t absxma_a; | ||||
|   constexpr accscalar_t SMALL = 20.0; | ||||
|   constexpr accscalar_t LARGE = 200.0; | ||||
|   constexpr accscalar_t SMALLRATIO = 0.3; | ||||
|   constexpr accscalar_t LARGERATIO = 4.5; | ||||
|   static const accscalar_t SMALL = 20.0; | ||||
|   static const accscalar_t LARGE = 200.0; | ||||
|   static const accscalar_t SMALLRATIO = 0.3; | ||||
|   static const accscalar_t LARGERATIO = 4.5; | ||||
|  | ||||
|   // boundary values following SciPy | ||||
|   if ((x < 0) || (a < 0)) { | ||||
|  | ||||
							
								
								
									
										90
									
								
								aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,90 @@ | ||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||
| #include <ATen/core/Tensor.h> | ||||
| #include <ATen/native/cuda/SortingCommon.cuh> | ||||
| #include <ATen/cuda/cub_definitions.cuh> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| #else | ||||
| #include <ATen/ops/empty_like.h> | ||||
| #endif | ||||
|  | ||||
| #include <ATen/cuda/ThrustAllocator.h> | ||||
| #include <thrust/device_ptr.h> | ||||
| #include <thrust/execution_policy.h> | ||||
| #include <thrust/sort.h> | ||||
| #include <thrust/unique.h> | ||||
| #include <thrust/device_ptr.h> | ||||
| #include <thrust/iterator/constant_iterator.h> | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|  | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) { | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   at::cuda::ThrustAllocator allocator; | ||||
|   auto policy = thrust::cuda::par(allocator).on(stream); | ||||
|  | ||||
|   auto num_indices = count.numel(); | ||||
|  | ||||
|   // Compute an increasing sequence per unique item in sortedIndices: | ||||
|   // sorted: 2 5 5 5 7 7 8 9 9 | ||||
|   //  count: 1 1 2 3 1 2 1 1 2 | ||||
|   auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>()); | ||||
|   auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>()); | ||||
|   thrust::inclusive_scan_by_key( | ||||
|     policy, | ||||
|     sorted_data, | ||||
|     sorted_data + num_indices, | ||||
|     thrust::make_constant_iterator(1), | ||||
|     count_data | ||||
|   ); | ||||
|  | ||||
|   // Take the maximum of each count per unique key in reverse: | ||||
|   // sorted: 2 5 5 5 7 7 8 9 9 | ||||
|   //  count: 1 3 3 3 2 2 1 2 2 | ||||
|   thrust::inclusive_scan_by_key( | ||||
|     policy, | ||||
|     thrust::make_reverse_iterator(sorted_data + num_indices), | ||||
|     thrust::make_reverse_iterator(sorted_data), | ||||
|     thrust::make_reverse_iterator(count_data + num_indices), | ||||
|     thrust::make_reverse_iterator(count_data + num_indices), | ||||
|     thrust::equal_to<index_t>(), | ||||
|     thrust::maximum<index_t>() | ||||
|   ); | ||||
| } | ||||
|  | ||||
| template | ||||
| void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count); | ||||
| template | ||||
| void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count); | ||||
|  | ||||
| #endif | ||||
|  | ||||
| template<typename index_t> | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) { | ||||
|   auto stream = at::cuda::getCurrentCUDAStream(); | ||||
|   at::cuda::ThrustAllocator allocator; | ||||
|   auto policy = thrust::cuda::par(allocator).on(stream); | ||||
|   const ptrdiff_t numel = sorted_indices.numel(); | ||||
|   auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>()); | ||||
|   auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
|   auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>()); | ||||
|   auto ends = thrust::unique_by_key_copy( | ||||
|           policy, | ||||
|           sorted_indices_dev, | ||||
|           sorted_indices_dev + numel, | ||||
|           thrust::make_counting_iterator(0), | ||||
|           dummy_dev, | ||||
|           thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>())); | ||||
|   return thrust::get<0>(ends) - dummy_dev; | ||||
| } | ||||
|  | ||||
| template | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
| template | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
|  | ||||
| } // namespace at::native | ||||
| @ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify( | ||||
| const auto digamma_string = jiterator_stringify( | ||||
|   template <typename T> | ||||
|   T digamma(T x) { | ||||
|     static constexpr double PI_f64 = 3.14159265358979323846; | ||||
|     static const double PI_f64 = 3.14159265358979323846; | ||||
|  | ||||
|     // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard | ||||
|     if (x == 0) { | ||||
| @ -3072,9 +3072,9 @@ template <typename scalar_t> | ||||
| static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { | ||||
|   // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   static constexpr double PI_f64 = 3.14159265358979323846; | ||||
|   constexpr accscalar_t PSI_10 = 2.25175258906672110764; | ||||
|   constexpr accscalar_t A[] = { | ||||
|   static const double PI_f64 = 3.14159265358979323846; | ||||
|   const accscalar_t PSI_10 = 2.25175258906672110764; | ||||
|   const accscalar_t A[] = { | ||||
|       8.33333333333333333333E-2, | ||||
|       -2.10927960927960927961E-2, | ||||
|       7.57575757575757575758E-3, | ||||
|  | ||||
| @ -1097,7 +1097,11 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ | ||||
|   // threads with different threadIdx.x are independent and will produce results for different outputs. | ||||
|   // In such case, values in each loaded vector always correspond to different outputs. | ||||
|   if (fastest_moving_stride == sizeof(scalar_t)) { | ||||
| #ifdef USE_ROCM | ||||
|     if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) { | ||||
| #else | ||||
|     if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) { | ||||
| #endif | ||||
|       // Case 1: "vectorize along input" | ||||
|       // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case, | ||||
|       // we should avoid vectorization. | ||||
|  | ||||
| @ -39,14 +39,9 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta | ||||
| template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> | ||||
| void mean_kernel_impl(TensorIterator& iter) { | ||||
|   //  returns acc_t for all non-complex dtypes and returns T for c10::complex<T> | ||||
|   constexpr bool is_16_bits = sizeof(scalar_t) == 2; | ||||
|   using factor_t = typename c10::scalar_value_type<acc_t>::type; | ||||
|   factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel(); | ||||
|   if constexpr (is_16_bits) { | ||||
|     gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor}); | ||||
|   } else { | ||||
|     gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor}); | ||||
|   } | ||||
|   gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor}); | ||||
| } | ||||
|  | ||||
| static void mean_kernel_cuda(TensorIterator& iter) { | ||||
|  | ||||
| @ -13,19 +13,24 @@ namespace at::native { | ||||
| template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t> | ||||
| struct sum_functor { | ||||
|   void operator()(TensorIterator& iter) { | ||||
|     const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { | ||||
|       return a + b; | ||||
|     }; | ||||
|     constexpr bool is_16_bits = sizeof(scalar_t) == 2; | ||||
|     if constexpr (is_16_bits) { | ||||
| #ifdef USE_ROCM | ||||
|     // Half and BFloat16 can be packed in groups of up to 8 elements and | ||||
|     // can use *_DWORDX4 instructions to achieve that. | ||||
|     const bool is_16_bits = | ||||
|       ( (std::is_same<at::Half, scalar_t>::value) || | ||||
|         (std::is_same<at::BFloat16, scalar_t>::value) ); | ||||
|     if (is_16_bits) { | ||||
|       gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>( | ||||
|         iter, func_wrapper<out_t>(sum_combine) | ||||
|       ); | ||||
|     } else { | ||||
|       gpu_reduce_kernel<scalar_t, out_t>( | ||||
|         iter, func_wrapper<out_t>(sum_combine) | ||||
|       ); | ||||
|         iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { | ||||
|           return a + b; | ||||
|         })); | ||||
|       return; | ||||
|     } | ||||
| #endif | ||||
|     gpu_reduce_kernel<scalar_t, out_t>( | ||||
|         iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { | ||||
|           return a + b; | ||||
|         })); | ||||
|   } | ||||
| }; | ||||
|  | ||||
|  | ||||
| @ -19,6 +19,7 @@ | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| // TODO: remove this when CUDA <11.6 is no longer supported | ||||
| void topk_out_with_sort( | ||||
|   const Tensor& self, | ||||
|   int64_t k, int64_t dim, bool largest, | ||||
| @ -30,12 +31,21 @@ void topk_out_with_sort( | ||||
|   indices.copy_(sorted_indices.narrow(dim, 0, k)); | ||||
| } | ||||
|  | ||||
| // TODO: remove this when CUDA <11.6 is no longer supported | ||||
| bool disable_sort_for_topk(); | ||||
| bool should_use_sort(const Tensor& self, int64_t dim) { | ||||
| #if defined(USE_ROCM) | ||||
|   if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972 | ||||
|   return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387 | ||||
| #else | ||||
|   return false; | ||||
|   if (disable_sort_for_topk()) return false; | ||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632 | ||||
|   if (self.dim() == 0) return false; | ||||
|   if (self.dtype() == kBool) return false; // Bool is not support by topk | ||||
|   int64_t slice_size = self.size(dim); | ||||
|   if (slice_size == 0) return false; | ||||
|   int64_t num_slices = self.numel() / slice_size; | ||||
|   return num_slices <= 10 && slice_size >= 100000; | ||||
| #endif | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -21,6 +21,11 @@ using namespace at::native; | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| // TODO: remove this when CUDA <11.6 is no longer supported | ||||
| bool disable_sort_for_topk() { | ||||
|   return CUB_SUPPORTS_SCAN_BY_KEY(); | ||||
| } | ||||
|  | ||||
| namespace sbtopk { // single_block_topk | ||||
|  | ||||
| template <typename T> | ||||
| @ -413,6 +418,10 @@ __global__ void computeBlockwiseWithinKCounts( | ||||
|   } | ||||
|   __syncthreads(); | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|   return; | ||||
| #endif | ||||
|  | ||||
|   Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS); | ||||
|  | ||||
|   // if largest, then only threads that has tidx > desired_digit are active | ||||
| @ -468,6 +477,7 @@ __global__ void computeBlockwiseWithinKCounts( | ||||
|   } | ||||
| } | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| // Assumption: slice_size can not be larger than UINT32_MAX | ||||
| template <typename Bitwise> | ||||
| __global__ void computeBlockwiseKthCounts( | ||||
| @ -599,6 +609,7 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> inpu | ||||
|     } | ||||
|   } | ||||
| } | ||||
| #endif | ||||
|  | ||||
| int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) { | ||||
|   // occupancy of this kernel is limited by registers per threads | ||||
| @ -676,12 +687,16 @@ void launch( | ||||
|   uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get()); | ||||
|   AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream)); | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|   auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); | ||||
|   uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get()); | ||||
|   AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream)); | ||||
|  | ||||
|   auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); | ||||
|   uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get()); | ||||
| #else | ||||
|   uint32_t* withinKCounts = nullptr; | ||||
| #endif | ||||
|  | ||||
|   Bitwise desiredMask = 0; | ||||
|   dim3 grid; | ||||
| @ -728,6 +743,7 @@ void launch( | ||||
|   } | ||||
|   desired = desired_in; | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|   computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>( | ||||
|     desired, counts, num_blocks, blocks_per_slice, kthCounts); | ||||
|   C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
| @ -743,6 +759,28 @@ void launch( | ||||
|     topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread, | ||||
|     blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks); | ||||
|   C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
| #else | ||||
|   // Find topk values based on kth values | ||||
|   { | ||||
|     dim3 grid; | ||||
|     TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk"); | ||||
|     int warp_size = at::cuda::warp_size(); | ||||
|     dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024)); | ||||
|     sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>( | ||||
|         input, | ||||
|         inputSliceSize, | ||||
|         outputSliceSize, | ||||
|         largest, | ||||
|         numInputSlices, | ||||
|         inputWithinSliceStride, | ||||
|         topK, | ||||
|         topKWithinSliceStride, | ||||
|         indices, | ||||
|         indicesWithinSliceStride, | ||||
|         kthValues); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| } // namespace mbtopk | ||||
| @ -750,6 +788,7 @@ void launch( | ||||
| bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { | ||||
|   if (num_slices > std::numeric_limits<uint32_t>::max() || | ||||
|       slice_size > std::numeric_limits<uint32_t>::max()) return false; | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267 | ||||
|   return (num_slices <= 20 && slice_size >= 20000) || | ||||
|       (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) || | ||||
| @ -758,6 +797,12 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { | ||||
|       (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) || | ||||
|       (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) || | ||||
|       (num_slices > 4000 && slice_size >= 400); | ||||
| #else | ||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081 | ||||
|   return (num_slices <= 400 && slice_size >= 5000) || | ||||
|       (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) || | ||||
|       (num_slices >= 4000 && slice_size >= 300); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| void launch_gather_topk_kernel( | ||||
|  | ||||
| @ -277,7 +277,7 @@ struct BilinearFilterFunctor { | ||||
|     return 0; | ||||
|   } | ||||
|  | ||||
|   static constexpr int size = 2; | ||||
|   static const int size = 2; | ||||
| }; | ||||
|  | ||||
| // taken from | ||||
| @ -301,7 +301,7 @@ struct BicubicFilterFunctor { | ||||
|     return 0; | ||||
|   } | ||||
|  | ||||
|   static constexpr int size = 4; | ||||
|   static const int size = 4; | ||||
| }; | ||||
|  | ||||
| template <typename accscalar_t> | ||||
|  | ||||
| @ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame( | ||||
|   } | ||||
| } | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| // Helper function to compute output pixel range that can contribute to input pixel | ||||
| template <typename accscalar_t> | ||||
| __device__ __forceinline__ void compute_output_range( | ||||
|     int input_pos, | ||||
|     accscalar_t scale, | ||||
|     int output_size, | ||||
|     bool align_corners, | ||||
|     int& min_output, | ||||
|     int& max_output) { | ||||
|   accscalar_t lo, hi; | ||||
|   if (align_corners) { | ||||
|       lo = static_cast<accscalar_t>(input_pos - 1) / scale; | ||||
|       hi = static_cast<accscalar_t>(input_pos + 1) / scale; | ||||
|   } else { | ||||
|       lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5); | ||||
|       hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5); | ||||
|   } | ||||
|   min_output = max(0, static_cast<int>(ceil(lo))); | ||||
|   max_output = min(output_size - 1, static_cast<int>(floor(hi))); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| // Backward (adjoint) operation 1 <- 2 (accumulates) | ||||
| template <typename scalar_t, typename accscalar_t> | ||||
| C10_LAUNCH_BOUNDS_1(1024) | ||||
| @ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame( | ||||
|     const bool align_corners, | ||||
|     scalar_t* __restrict__ idata, | ||||
|     const scalar_t* __restrict__ odata) { | ||||
|   const size_t o_numel = nc * width2 * height2; | ||||
|   // In C++, integer multiplication, like in standard arithmetic, is generally commutative. | ||||
|   const size_t i_numel = nc * width1 * height1; | ||||
| #ifdef USE_ROCM | ||||
|   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel; | ||||
|        index += blockDim.x * gridDim.x) { | ||||
|     // Decode input pixel coordinates | ||||
|     size_t index_temp = index; | ||||
|     const int w1 = index_temp % width1; | ||||
|     index_temp /= width1; | ||||
|     const int h1 = index_temp % height1; | ||||
|     const size_t nc_idx = index_temp / height1; | ||||
|  | ||||
|     accscalar_t grad_sum = 0; | ||||
|  | ||||
|     // Find range of output pixels that could interpolate from this input pixel | ||||
|     int h2_min, h2_max, w2_min, w2_max; | ||||
|     compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max); | ||||
|     compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max); | ||||
|  | ||||
|     // Iterate over potential output pixels | ||||
|     for (int h2 = h2_min; h2 <= h2_max; h2++) { | ||||
|       for (int w2 = w2_min; w2 <= w2_max; w2++) { | ||||
|         // Compute source coordinates for this output pixel | ||||
|         const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>( | ||||
|             rheight, h2, align_corners, /*cubic=*/false); | ||||
|         const int h1_base = (int)h1r; | ||||
|         const int h1p = (h1_base < height1 - 1) ? 1 : 0; | ||||
|         const accscalar_t h1lambda = h1r - h1_base; | ||||
|         const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda; | ||||
|  | ||||
|         const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>( | ||||
|             rwidth, w2, align_corners, /*cubic=*/false); | ||||
|         const int w1_base = (int)w1r; | ||||
|         const int w1p = (w1_base < width1 - 1) ? 1 : 0; | ||||
|         const accscalar_t w1lambda = w1r - w1_base; | ||||
|         const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda; | ||||
|  | ||||
|         // Check if our input pixel participates in this interpolation and accumulate all weights | ||||
|         // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse | ||||
|         // to the same pixel, so we need to accumulate weights from all matching positions | ||||
|         accscalar_t weight = 0; | ||||
|  | ||||
|         // Check all four interpolation positions and accumulate weights | ||||
|         if (h1 == h1_base && w1 == w1_base) { | ||||
|           weight += h0lambda * w0lambda;  // top-left | ||||
|         } | ||||
|         if (h1 == h1_base && w1 == w1_base + w1p) { | ||||
|           weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0) | ||||
|         } | ||||
|         if (h1 == h1_base + h1p && w1 == w1_base) { | ||||
|           weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0) | ||||
|         } | ||||
|         if (h1 == h1_base + h1p && w1 == w1_base + w1p) { | ||||
|           weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions) | ||||
|         } | ||||
|  | ||||
|         if (weight > 0) { | ||||
|           const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2; | ||||
|           grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     // Write accumulated gradient (no atomics needed) | ||||
|     idata[index] = static_cast<scalar_t>(grad_sum); | ||||
|   } | ||||
| #else | ||||
|   const size_t o_numel = nc * width2 * height2; | ||||
|   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; | ||||
|        index += blockDim.x * gridDim.x) { | ||||
|     size_t index_temp = index; | ||||
| @ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame( | ||||
|         static_cast<scalar_t>(h1lambda * w1lambda * d2val), | ||||
|         true); | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename accscalar_t> | ||||
| @ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( | ||||
|   // threads are not covering the whole input tensor. | ||||
|   grad_input.zero_(); | ||||
|  | ||||
|   const size_t num_kernels = nbatch * channels * output_height * output_width; | ||||
|   const int num_threads = std::min( | ||||
|       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
| @ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template( | ||||
|     return; | ||||
|   } | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|   constexpr bool use_input = true; | ||||
| #else | ||||
|   constexpr bool use_input = false; | ||||
| #endif | ||||
|  | ||||
|   AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|       at::ScalarType::Half, at::ScalarType::BFloat16, | ||||
|       grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] { | ||||
| @ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template( | ||||
|       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( | ||||
|           input_width, output_width, align_corners, scales_w); | ||||
|  | ||||
|       const size_t num_kernels = nbatch * channels * output_height * output_width; | ||||
|  | ||||
|       upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t> | ||||
|           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>( | ||||
|               input_height, | ||||
| @ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template( | ||||
|       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( | ||||
|           input_width, output_width, align_corners, scales_w); | ||||
|  | ||||
|       const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width); | ||||
|  | ||||
|       upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t> | ||||
|           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), | ||||
|              num_threads, | ||||
|  | ||||
| @ -141,11 +141,7 @@ WelfordDataLN cuWelfordOnlineSum( | ||||
|   if constexpr (!rms_norm){ | ||||
|     U delta = val - curr_sum.mean; | ||||
|     U new_count = curr_sum.count + 1.f; | ||||
| #if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL) | ||||
|     U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count); | ||||
| #else | ||||
|     U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster | ||||
| #endif | ||||
|     return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; | ||||
|   } else{ | ||||
|     return {0.f, curr_sum.sigma2 + val * val, 0}; | ||||
| @ -163,11 +159,7 @@ WelfordDataLN cuWelfordCombine( | ||||
|     U count = dataA.count + dataB.count; | ||||
|     U mean, sigma2; | ||||
|     if (count > decltype(dataB.count){0}) { | ||||
| #if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL) | ||||
|       auto coef = __builtin_amdgcn_rcpf(count); | ||||
| #else | ||||
|       auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division | ||||
| #endif | ||||
|       auto nA = dataA.count * coef; | ||||
|       auto nB = dataB.count * coef; | ||||
|       mean = nA*dataA.mean + nB*dataB.mean; | ||||
|  | ||||
| @ -487,7 +487,9 @@ std::unique_ptr<fe::graph::Graph> build_graph( | ||||
|   auto scaled_dot_product_flash_attention_options = | ||||
|       fe::graph::SDPA_attributes() | ||||
|           .set_name("CUDNN_SDPA") | ||||
|           .set_generate_stats(return_softmaxstats) | ||||
|           .set_is_inference(return_softmaxstats == false) | ||||
|           // TODO(eqy): switch to this API once cuDNN FE is upgraded | ||||
|           // .set_generate_stats(return_softmaxstats) | ||||
|           .set_causal_mask(is_causal) | ||||
|           .set_attn_scale(attn_scale); | ||||
|   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { | ||||
| @ -705,7 +707,9 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor( | ||||
|   auto scaled_dot_product_flash_attention_options = | ||||
|       fe::graph::SDPA_attributes() | ||||
|           .set_name("CUDNN_SDPA_NESTEDTENSOR") | ||||
|           .set_generate_stats(return_softmaxstats) | ||||
|           .set_is_inference(return_softmaxstats == false) | ||||
|           // TODO(eqy): switch to this API once cuDNN FE is upgraded | ||||
|           // .set_generate_stats(return_softmaxstats) | ||||
|           .set_causal_mask(is_causal) | ||||
|           .set_attn_scale(attn_scale) | ||||
|           .set_seq_len_q(SEQ_LEN_Q_) | ||||
|  | ||||
| @ -160,12 +160,8 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){ | ||||
| } | ||||
|  | ||||
| static bool mkldnn_conv_enabled_fpmath_mode_tf32(){ | ||||
| #if defined(__x86_64__) || defined(_M_X64) | ||||
|     return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 && | ||||
|         cpuinfo_has_x86_amx_fp16(); | ||||
| #else | ||||
|     return false;   //TF32 not supported on power system | ||||
| #endif | ||||
|   return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 && | ||||
|       cpuinfo_has_x86_amx_fp16(); | ||||
| } | ||||
|  | ||||
| static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) { | ||||
|  | ||||
| @ -74,12 +74,8 @@ static bool use_mkldnn_bf32_linear() { | ||||
| } | ||||
|  | ||||
| static bool use_mkldnn_tf32_linear() { | ||||
| #if defined(__x86_64__) || defined(_M_X64) | ||||
|     return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && | ||||
|   return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && | ||||
|       cpuinfo_has_x86_amx_fp16(); | ||||
| #else | ||||
|   return false;  // TF32 not supported on power system | ||||
| #endif | ||||
| } | ||||
|  | ||||
| Tensor mkldnn_linear( | ||||
|  | ||||
| @ -114,13 +114,8 @@ static bool use_mkldnn_bf32_matmul() { | ||||
|   return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16; | ||||
| } | ||||
|  | ||||
|  | ||||
| static bool use_mkldnn_tf32_matmul() { | ||||
| #if defined(__x86_64__) || defined(_M_X64) | ||||
|     return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32; | ||||
| #else | ||||
|     return false;  // TF32 not supported on power system | ||||
| #endif | ||||
|   return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32; | ||||
| } | ||||
|  | ||||
| // returns an ideep::tensor | ||||
| @ -416,7 +411,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){ | ||||
|   // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k) | ||||
|   // else called from aten::mv, mat1.size = (m * n), mat2.size = (n) | ||||
|   // only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel | ||||
|   constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16; | ||||
|   static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16; | ||||
|   if (mat1.dim() == 1 && mat2.dim() == 1) { | ||||
|     // aten::dot | ||||
|     return mat1.size(0) > mkldnn_gemm_min_size; | ||||
|  | ||||
| @ -712,7 +712,7 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) { | ||||
|   } else if (scalar.isBoolean()) { | ||||
|     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool)); | ||||
|   } else if (scalar.isComplex()) { | ||||
|     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexFloat)); | ||||
|     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexDouble)); | ||||
|   } else { | ||||
|     TORCH_INTERNAL_ASSERT(scalar.isIntegral(false)); | ||||
|     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong)); | ||||
|  | ||||
| @ -1,16 +1,16 @@ | ||||
| #pragma once | ||||
| #include <c10/metal/common.h> | ||||
|  | ||||
| template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim> | ||||
| struct CatSharedParams { | ||||
| template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t> | ||||
| struct CatLargeSharedParams { | ||||
|   int32_t ndim; | ||||
|   int32_t cat_dim; | ||||
|   ::c10::metal::array<idx_type_t, N> output_strides; | ||||
|   ::c10::metal::array<idx_type_t, N> output_sizes; | ||||
| }; | ||||
|  | ||||
| template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim> | ||||
| struct CatInputParams { | ||||
| template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t> | ||||
| struct CatLargeInputParams { | ||||
|   idx_type_t cat_dim_offset; | ||||
|   idx_type_t input_element_offset; | ||||
|   ::c10::metal::array<idx_type_t, N> input_strides; | ||||
|  | ||||
| @ -6,25 +6,26 @@ | ||||
| using namespace metal; | ||||
| using namespace c10::metal; | ||||
|  | ||||
| template <typename I, typename T_in, typename T_out> | ||||
| kernel void cat( | ||||
| template <typename T_in, typename T_out> | ||||
| kernel void cat_large( | ||||
|     constant T_in* input [[buffer(0)]], | ||||
|     device T_out* output [[buffer(1)]], | ||||
|     constant CatSharedParams<I>& shared_params [[buffer(2)]], | ||||
|     constant CatInputParams<I>& input_params [[buffer(3)]], | ||||
|     constant CatLargeSharedParams<>& shared_params [[buffer(2)]], | ||||
|     constant CatLargeInputParams<>& input_params [[buffer(3)]], | ||||
|     uint tid [[thread_position_in_grid]]) { | ||||
|   auto ndim = shared_params.ndim; | ||||
|   auto cat_dim = shared_params.cat_dim; | ||||
|   constant auto& output_strides = shared_params.output_strides; | ||||
|   constant auto& output_sizes = shared_params.output_sizes; | ||||
|  | ||||
|   auto cat_dim_offset = input_params.cat_dim_offset; | ||||
|   auto input_element_offset = input_params.input_element_offset; | ||||
|   constant auto& input_strides = input_params.input_strides; | ||||
|   constant auto& input_sizes = input_params.input_sizes; | ||||
|  | ||||
|   auto input_element_idx = static_cast<I>(tid) + input_element_offset; | ||||
|   I input_offset = 0; | ||||
|   I output_offset = 0; | ||||
|   auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset; | ||||
|   int64_t input_offset = 0; | ||||
|   int64_t output_offset = 0; | ||||
|  | ||||
|   for (auto dim = ndim - 1; dim >= 0; dim--) { | ||||
|     auto dim_size = input_sizes[dim]; | ||||
| @ -41,45 +42,41 @@ kernel void cat( | ||||
|   output[output_offset] = static_cast<T_out>(input[input_offset]); | ||||
| } | ||||
|  | ||||
| #define REGISTER_CAT_OP(I, T_in, T_out)                          \ | ||||
|   template [[host_name("cat_" #I "_" #T_in "_" #T_out)]]         \ | ||||
|   kernel void cat<I, T_in, T_out>(                               \ | ||||
|       constant T_in * input [[buffer(0)]],                       \ | ||||
|       device T_out * output [[buffer(1)]],                       \ | ||||
|       constant CatSharedParams<I> & shared_params [[buffer(2)]], \ | ||||
|       constant CatInputParams<I> & input_params [[buffer(3)]],   \ | ||||
| #define REGISTER_CAT_LARGE_OP(T_in, T_out)                           \ | ||||
|   template [[host_name("cat_large_" #T_in "_" #T_out)]]              \ | ||||
|   kernel void cat_large<T_in, T_out>(                                \ | ||||
|       constant T_in * input [[buffer(0)]],                           \ | ||||
|       device T_out * output [[buffer(1)]],                           \ | ||||
|       constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \ | ||||
|       constant CatLargeInputParams<> & input_params [[buffer(3)]],   \ | ||||
|       uint tid [[thread_position_in_grid]]); | ||||
|  | ||||
| #define REGISTER_CAT_OP_ALL_INPUT_TYPES(I, T_out) \ | ||||
|   REGISTER_CAT_OP(I, float, T_out);               \ | ||||
|   REGISTER_CAT_OP(I, half, T_out);                \ | ||||
|   REGISTER_CAT_OP(I, bfloat, T_out);              \ | ||||
|   REGISTER_CAT_OP(I, int, T_out);                 \ | ||||
|   REGISTER_CAT_OP(I, uint, T_out);                \ | ||||
|   REGISTER_CAT_OP(I, long, T_out);                \ | ||||
|   REGISTER_CAT_OP(I, ulong, T_out);               \ | ||||
|   REGISTER_CAT_OP(I, short, T_out);               \ | ||||
|   REGISTER_CAT_OP(I, ushort, T_out);              \ | ||||
|   REGISTER_CAT_OP(I, char, T_out);                \ | ||||
|   REGISTER_CAT_OP(I, uchar, T_out);               \ | ||||
|   REGISTER_CAT_OP(I, bool, T_out); | ||||
| #define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \ | ||||
|   REGISTER_CAT_LARGE_OP(float, T_out);               \ | ||||
|   REGISTER_CAT_LARGE_OP(half, T_out);                \ | ||||
|   REGISTER_CAT_LARGE_OP(bfloat, T_out);              \ | ||||
|   REGISTER_CAT_LARGE_OP(int, T_out);                 \ | ||||
|   REGISTER_CAT_LARGE_OP(uint, T_out);                \ | ||||
|   REGISTER_CAT_LARGE_OP(long, T_out);                \ | ||||
|   REGISTER_CAT_LARGE_OP(ulong, T_out);               \ | ||||
|   REGISTER_CAT_LARGE_OP(short, T_out);               \ | ||||
|   REGISTER_CAT_LARGE_OP(ushort, T_out);              \ | ||||
|   REGISTER_CAT_LARGE_OP(char, T_out);                \ | ||||
|   REGISTER_CAT_LARGE_OP(uchar, T_out);               \ | ||||
|   REGISTER_CAT_LARGE_OP(bool, T_out); | ||||
|  | ||||
| #define REGISTER_CAT_FOR_INDEX_TYPE(I)        \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float);  \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, half);   \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bfloat); \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, int);    \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uint);   \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, long);   \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ulong);  \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, short);  \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ushort); \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, char);   \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uchar);  \ | ||||
|   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bool);   \ | ||||
|                                               \ | ||||
|   REGISTER_CAT_OP(I, float2, float2);         \ | ||||
|   REGISTER_CAT_OP(I, half2, half2); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar); | ||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool); | ||||
|  | ||||
| REGISTER_CAT_FOR_INDEX_TYPE(int64_t); | ||||
| REGISTER_CAT_FOR_INDEX_TYPE(int32_t); | ||||
| REGISTER_CAT_LARGE_OP(float2, float2); | ||||
| REGISTER_CAT_LARGE_OP(half2, half2); | ||||
|  | ||||
| @ -907,8 +907,6 @@ Tensor& index_fill_mps_(Tensor& self, int64_t dim, const Tensor& index, const Te | ||||
|   TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, | ||||
|               "index_fill_(): Expected dtype int32 or int64 for index"); | ||||
|   TORCH_CHECK(dim == 0 || dim < self.dim(), "index_fill_(): Indexing dim ", dim, " is out of bounds of tensor"); | ||||
|   TORCH_CHECK(self.is_complex() || !source.is_complex(), | ||||
|               "index_fill_(): Converting complex Scalar to non-complex type is not supported"); | ||||
|   // MPS.scatter crashes if used with complex dtypes | ||||
|   TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "index_fill_(): Complex types are yet not supported"); | ||||
|  | ||||
|  | ||||
| @ -196,28 +196,6 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) | ||||
|        other.size(0) > max_stride_size || other.size(1) > max_stride_size); | ||||
| } | ||||
|  | ||||
| void map_mps_decomposition_error_code_to_blas(const Tensor& status) { | ||||
|   const auto& status_flat = status.view(-1); | ||||
|  | ||||
|   for (const auto i : c10::irange(status_flat.size(0))) { | ||||
|     int code = status_flat[i].item<int>(); | ||||
|     switch (code) { | ||||
|       case MPSMatrixDecompositionStatusSuccess: | ||||
|         status_flat[i] = 0; | ||||
|         break; | ||||
|       case MPSMatrixDecompositionStatusNonPositiveDefinite: | ||||
|       case MPSMatrixDecompositionStatusSingular: | ||||
|         status_flat[i] = 2; | ||||
|         break; | ||||
|       case MPSMatrixDecompositionStatusFailure: | ||||
|         status_flat[i] = -1; | ||||
|         break; | ||||
|       default: | ||||
|         TORCH_INTERNAL_ASSERT(false, "Unknown MPSMatrixDecompositionStatus enum value: ", code); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, | ||||
| @ -509,9 +487,6 @@ static void linalg_solve_out_mps_impl(const Tensor& A, | ||||
|                   "mpsmatrixdecompositionstatus for details."); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   map_mps_decomposition_error_code_to_blas(info); | ||||
|  | ||||
|   if (!left) { | ||||
|     // If this was a right solve, transpose the result back | ||||
|     result.copy_(result_t.transpose(-2, -1).contiguous()); | ||||
|  | ||||
| @ -3,7 +3,6 @@ | ||||
| #include <ATen/MemoryOverlap.h> | ||||
| #include <ATen/WrapDimUtils.h> | ||||
| #include <ATen/mps/MPSProfiler.h> | ||||
| #include <ATen/native/Pool.h> | ||||
| #include <ATen/native/TensorShape.h> | ||||
| #include <ATen/native/TypeProperties.h> | ||||
| #include <ATen/native/mps/OperationUtils.h> | ||||
| @ -70,40 +69,29 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| std::string get_type_str(); | ||||
|  | ||||
| template <> | ||||
| std::string get_type_str<int64_t>() { | ||||
|   return "int64_t"; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| std::string get_type_str<int32_t>() { | ||||
|   return "int32_t"; | ||||
| } | ||||
|  | ||||
| // This implementation of cat is used only if one of the inputs or the output is | ||||
| // too large to use MPSGraph. | ||||
| // NOTE: `output` is expected to already have the correct size. | ||||
| template <typename idx_type_t> | ||||
| static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) { | ||||
|   CatSharedParams<idx_type_t> shared_params; | ||||
| static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) { | ||||
|   CatLargeSharedParams shared_params; | ||||
|  | ||||
|   shared_params.ndim = output.dim(); | ||||
|   shared_params.cat_dim = dimension; | ||||
|  | ||||
|   for (const auto dim : c10::irange(output.dim())) { | ||||
|     shared_params.output_strides[dim] = safe_downcast<idx_type_t, int64_t>(output.stride(dim)); | ||||
|     shared_params.output_sizes[dim] = safe_downcast<idx_type_t, int64_t>(output.size(dim)); | ||||
|     shared_params.output_strides[dim] = output.stride(dim); | ||||
|     shared_params.output_sizes[dim] = output.size(dim); | ||||
|   } | ||||
|  | ||||
|   idx_type_t cat_dim_offset = 0; | ||||
|   int64_t cat_dim_offset = 0; | ||||
|   size_t input_idx = 0; | ||||
|   MPSStream* stream = getCurrentMPSStream(); | ||||
|  | ||||
|   // Launch a separate kernels for each input. This will produce some overhead. | ||||
|   // In order to launch only one kernel to process all inputs, we would have to | ||||
|   // copy all the input tensor data into a packed buffer, which would not be | ||||
|   // ideal. | ||||
|   // Launch a separate kernels for each input. This will produce some overhead, | ||||
|   // but that should be relatively minimal since at least one of the inputs is | ||||
|   // very large. In order to launch only one kernel to process all inputs, we | ||||
|   // would have to copy all the input tensor data into a packed buffer, which | ||||
|   // would not be ideal. | ||||
|   for (const Tensor& input : inputs) { | ||||
|     if (input.numel() == 0) { | ||||
|       continue; | ||||
| @ -116,23 +104,21 @@ static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, co | ||||
|  | ||||
|     for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) { | ||||
|       auto num_threads = std::min(max_num_threads, numel_remaining); | ||||
|       CatInputParams<idx_type_t> input_params; | ||||
|       CatLargeInputParams input_params; | ||||
|  | ||||
|       input_params.cat_dim_offset = safe_downcast<idx_type_t, int64_t>(cat_dim_offset); | ||||
|       input_params.input_element_offset = safe_downcast<idx_type_t, int64_t>(input.numel() - numel_remaining); | ||||
|       input_params.cat_dim_offset = cat_dim_offset; | ||||
|       input_params.input_element_offset = input.numel() - numel_remaining; | ||||
|  | ||||
|       for (const auto dim : c10::irange(input.dim())) { | ||||
|         input_params.input_strides[dim] = safe_downcast<idx_type_t, int64_t>(input.stride(dim)); | ||||
|         input_params.input_sizes[dim] = safe_downcast<idx_type_t, int64_t>(input.size(dim)); | ||||
|         input_params.input_strides[dim] = input.stride(dim); | ||||
|         input_params.input_sizes[dim] = input.size(dim); | ||||
|       } | ||||
|  | ||||
|       dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|         @autoreleasepool { | ||||
|           id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder(); | ||||
|           auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("cat_{}_{}_{}", | ||||
|                                                                         get_type_str<idx_type_t>(), | ||||
|                                                                         scalarToMetalTypeString(input), | ||||
|                                                                         scalarToMetalTypeString(output))); | ||||
|           auto pipeline_state = lib.getPipelineStateForFunc( | ||||
|               fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output))); | ||||
|           getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input}); | ||||
|           [computeEncoder setComputePipelineState:pipeline_state]; | ||||
|           mtl_setArgs(computeEncoder, input, output, shared_params, input_params); | ||||
| @ -308,6 +294,13 @@ TORCH_IMPL_FUNC(cat_out_mps) | ||||
|               " and out is on ", | ||||
|               out.device()); | ||||
|  | ||||
|   // TODO: For better performance by eliminating input tensor gathering and post transpose, | ||||
|   // TODO: it is better to keep the out tensor's memory format. | ||||
|   // TODO: dimension needs to be recomputed as: | ||||
|   // TODO: dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim()- dim; otherwise dim = dim-1 | ||||
|   if (needsGather(out)) { | ||||
|     out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); | ||||
|   } | ||||
|   std::vector<int64_t> size(notSkippedTensor.sizes().vec()); | ||||
|  | ||||
|   // Compute size of the result in the cat dimension | ||||
| @ -338,9 +331,82 @@ TORCH_IMPL_FUNC(cat_out_mps) | ||||
|   has_large_tensor |= isTooLargeForMPSGraph(out); | ||||
|  | ||||
|   if (has_large_tensor) { | ||||
|     return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out); | ||||
|   } else { | ||||
|     return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out); | ||||
|     return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out); | ||||
|   } | ||||
|  | ||||
|   struct CachedGraph : public MPSCachedGraph { | ||||
|     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} | ||||
|     std::vector<MPSGraphTensor*> inputTensors_; | ||||
|     MPSGraphTensor* outputTensor_ = nil; | ||||
|   }; | ||||
|  | ||||
|   @autoreleasepool { | ||||
|     std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" + | ||||
|         (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); | ||||
|     if (!all_same_dtype) { | ||||
|       key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride); | ||||
|     } else { | ||||
|       key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size()); | ||||
|     } | ||||
|     for (auto idx : skipped_tensor_indices) { | ||||
|       key += "," + std::to_string(idx); | ||||
|     } | ||||
|  | ||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||
|       auto len_tensor_array = inputs.size() - skipped_tensor_indices.size(); | ||||
|       std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array); | ||||
|       newCachedGraph->inputTensors_.reserve(len_tensor_array); | ||||
|  | ||||
|       for (const auto idx : c10::irange(len_tensor_array)) { | ||||
|         const Tensor& tensor = input_tensors[idx]; | ||||
|         auto scalar_type = getMPSScalarType(tensor.scalar_type()); | ||||
|         if (tensor.scalar_type() == kBool) { | ||||
|           scalar_type = MPSDataTypeInt8; | ||||
|         } | ||||
|         newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type); | ||||
|         if (tensor.scalar_type() != out_dtype) { | ||||
|           castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx] | ||||
|                                                 toType:getMPSDataType(out_dtype) | ||||
|                                                   name:@"castInput"]; | ||||
|         } else { | ||||
|           castInputTensors[idx] = newCachedGraph->inputTensors_[idx]; | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array]; | ||||
|       MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray | ||||
|                                                    dimension:dimension // Maybe convert this from int64_t -> int32 | ||||
|                                                         name:nil]; | ||||
|       if (getMPSDataType(out_dtype) == MPSDataTypeBool) { | ||||
|         outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"]; | ||||
|       } | ||||
|       newCachedGraph->outputTensor_ = outputTensor; | ||||
|     }); | ||||
|  | ||||
|     std::vector<Placeholder> inputPlaceholders; | ||||
|     int i = 0; | ||||
|     int t_idx = 0; | ||||
|     for (const Tensor& tensor : materialized_inputs) { | ||||
|       if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) { | ||||
|         auto scalar_type = getMPSScalarType(tensor.scalar_type()); | ||||
|         if (tensor.scalar_type() == kBool) { | ||||
|           scalar_type = MPSDataTypeInt8; | ||||
|         } | ||||
|         inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, nullptr, true, scalar_type); | ||||
|         t_idx++; | ||||
|       } | ||||
|       i++; | ||||
|     } | ||||
|  | ||||
|     auto outputDataType = getMPSScalarType(out.scalar_type()); | ||||
|     Placeholder outputPlaceholder = | ||||
|         Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType); | ||||
|  | ||||
|     NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; | ||||
|     for (auto& inputPlaceholder : inputPlaceholders) { | ||||
|       feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); | ||||
|     } | ||||
|     runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder); | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -1370,7 +1370,6 @@ | ||||
|   dispatch: | ||||
|     SparseCPU: bmm_sparse_cpu | ||||
|     SparseCUDA: bmm_sparse_cuda | ||||
|     SparseMPS: bmm_sparse_mps | ||||
|     NestedTensorCPU: bmm_nested | ||||
|     NestedTensorCUDA: bmm_nested_cuda | ||||
|   tags: core | ||||
| @ -1386,7 +1385,6 @@ | ||||
|     MTIA: bmm_out_mtia | ||||
|     SparseCPU: bmm_out_sparse_cpu | ||||
|     SparseCUDA: bmm_out_sparse_cuda | ||||
|     SparseMPS: bmm_out_sparse_mps | ||||
|     SparseCsrCUDA: bmm_out_sparse_csr_cuda | ||||
|  | ||||
| - func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor | ||||
| @ -4175,7 +4173,7 @@ | ||||
|   structured_delegate: mm.out | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     SparseCPU, SparseCUDA, SparseMPS: _sparse_mm | ||||
|     SparseCPU, SparseCUDA: _sparse_mm | ||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm | ||||
|   tags: core | ||||
|  | ||||
| @ -6533,7 +6531,6 @@ | ||||
|   dispatch: | ||||
|     CPU, CUDA: var | ||||
|     MPS: var_mps | ||||
|     MTIA: var_mtia | ||||
|   tags: core | ||||
|  | ||||
| - func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||
| @ -7114,7 +7111,6 @@ | ||||
|     MTIA: addmm_out_mtia | ||||
|     SparseCPU: addmm_out_sparse_dense_cpu | ||||
|     SparseCUDA: addmm_out_sparse_dense_cuda | ||||
|     SparseMPS: addmm_out_sparse_dense_mps | ||||
|     SparseCsrCPU: addmm_out_sparse_compressed_cpu | ||||
|     SparseCsrCUDA: addmm_out_sparse_compressed_cuda | ||||
|  | ||||
| @ -7124,7 +7120,6 @@ | ||||
|   dispatch: | ||||
|     SparseCPU: addmm_sparse_dense_cpu | ||||
|     SparseCUDA: addmm_sparse_dense_cuda | ||||
|     SparseMPS: addmm_sparse_dense_mps | ||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense | ||||
|   tags: core | ||||
|  | ||||
| @ -7389,7 +7384,7 @@ | ||||
| - func: sparse_mask(Tensor self, Tensor mask) -> Tensor | ||||
|   variants: method | ||||
|   dispatch: | ||||
|     SparseCPU, SparseCUDA, SparseMPS: sparse_mask | ||||
|     SparseCPU, SparseCUDA: sparse_mask | ||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed | ||||
|   autogen: sparse_mask.out | ||||
|  | ||||
|  | ||||
| @ -184,23 +184,15 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba | ||||
|           0 & \text{ else } | ||||
|         \end{cases} | ||||
|   */ | ||||
|  | ||||
|   bool is_bfloat16 = (X.scalar_type() == at::kBFloat16); | ||||
|  | ||||
|   at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X; | ||||
|   at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY; | ||||
|   at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale; | ||||
|   at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point; | ||||
|  | ||||
|   float scale_val = scale_[0].item<float>(); | ||||
|   float scale_val = scale[0].item<float>(); | ||||
|   float inv_scale_val = 1.0f / scale_val; | ||||
|   int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point_, quant_min, quant_max, false); | ||||
|   int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false); | ||||
|  | ||||
|   TORCH_CHECK(dY_.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(X_.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(scale_.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(X_.numel() == dY_.numel(), "`X` and `dY` are not the same size"); | ||||
|   TORCH_CHECK(dY.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(X.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(scale.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float); | ||||
|   TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size"); | ||||
|   TORCH_CHECK( | ||||
|       quant_min <= 0 && quant_max >= 0, | ||||
|       "`quant_min` should be less than or \ | ||||
| @ -208,28 +200,28 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba | ||||
|   TORCH_CHECK( | ||||
|       zero_point_val >= quant_min && zero_point_val <= quant_max, | ||||
|       "`zero_point` must be between `quant_min` and `quant_max`."); | ||||
|   if (X_.numel() <= 0) { | ||||
|   if (X.numel() <= 0) { | ||||
|     return std::make_tuple(X, scale, zero_point); | ||||
|   } | ||||
|  | ||||
|   auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve); | ||||
|   auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve); | ||||
|   auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve); | ||||
|   auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve); | ||||
|   auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve); | ||||
|   auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve); | ||||
|  | ||||
|   auto iter = TensorIteratorConfig() | ||||
|     .add_output(dX) | ||||
|     .add_output(dScale_vec) | ||||
|     .add_output(dZeroPoint_vec) | ||||
|     .add_input(X_) | ||||
|     .add_input(dY_) | ||||
|     .add_input(X) | ||||
|     .add_input(dY) | ||||
|     .build(); | ||||
|  | ||||
|   fake_quant_grad_learnable_tensor_stub( | ||||
|     X_.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor); | ||||
|     X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor); | ||||
|  | ||||
|   // The total sums over the scale and zero point gradient vectors are what will be returned in the end. | ||||
|   auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device()); | ||||
|   auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device()); | ||||
|   auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device()); | ||||
|   auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device()); | ||||
|  | ||||
|   return std::make_tuple(dX, dScale, dZeroPoint); | ||||
| } | ||||
|  | ||||
| @ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu( | ||||
|  | ||||
| #if defined(__ARM_NEON__) || defined(__aarch64__) | ||||
|  | ||||
| constexpr static int PARALLEL_THRESHOLD = 1 << 20; | ||||
| const static int PARALLEL_THRESHOLD = 1 << 20; | ||||
|  | ||||
| // Generic template defaults to naive quantize implementation | ||||
| template <typename T> | ||||
|  | ||||
| @ -1388,7 +1388,7 @@ namespace at::native { | ||||
|     TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1, | ||||
|         "onednn int8 linear: act scale/zp size should be 1/<=1"); | ||||
|     static std::optional<at::Tensor> other = std::nullopt; | ||||
|     constexpr std::string_view binary_post_op = "none"; | ||||
|     static const std::string_view binary_post_op = "none"; | ||||
|     int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0; | ||||
|     return linear_int8_with_onednn_weight( | ||||
|         act, act_scale.item().toDouble(), act_zp, | ||||
|  | ||||
| @ -16,8 +16,8 @@ namespace { | ||||
|  | ||||
| #ifdef USE_PYTORCH_QNNPACK | ||||
|  | ||||
| constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f; | ||||
| constexpr static int qnnpack_softmax_output_zero_point = 0; | ||||
| const static float qnnpack_softmax_output_scale = 0x1.0p-8f; | ||||
| const static int qnnpack_softmax_output_zero_point = 0; | ||||
|  | ||||
| bool is_qnnpack_compatible( | ||||
|     const Tensor& qx, | ||||
|  | ||||
| @ -1,9 +1,6 @@ | ||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||
| #include <ATen/native/SparseTensorUtils.h> | ||||
| #include <ATen/ExpandUtils.h> | ||||
| #include <ATen/native/mps/OperationUtils.h> | ||||
| #include <ATen/native/sparse/SparseStubs.h> | ||||
| #include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -16,11 +13,7 @@ | ||||
| #include <ATen/ops/mul_native.h> | ||||
| #include <ATen/ops/empty_native.h> | ||||
| #include <ATen/ops/zeros_native.h> | ||||
| #include <ATen/ops/ones_like.h> | ||||
| #include <ATen/ops/argsort.h> | ||||
| #include <ATen/ops/result_type.h> | ||||
| #include <ATen/ops/bmm_native.h> | ||||
| #include <ATen/ops/addmm_native.h> | ||||
| #include <ATen/ops/copy_sparse_to_sparse.h> | ||||
| #include <ATen/ops/mul.h> | ||||
| #endif | ||||
| @ -36,305 +29,6 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary(); | ||||
| #include <ATen/native/mps/Mul_metallib.h> | ||||
| #endif | ||||
|  | ||||
| static Tensor& s_addmm_out_sparse_dense_mps( | ||||
|     Tensor& r, | ||||
|     const Tensor& t, | ||||
|     const SparseTensor& sparse_, | ||||
|     const Tensor& dense, | ||||
|     const Scalar& beta, | ||||
|     const Scalar& alpha) { | ||||
|   TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: sparse_dim must be 2, got ", sparse_.sparse_dim()); | ||||
|   TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: sparse values must be 0-dense-dim, got ", sparse_.dense_dim()); | ||||
|   TORCH_CHECK(dense.dim() == 2, "addmm: 'dense' must be 2D, got ", dense.dim()); | ||||
|   TORCH_CHECK(t.dim() == 2, "addmm: 't' must be 2D, got ", t.dim()); | ||||
|  | ||||
|   const int64_t I = sparse_.size(0); | ||||
|   const int64_t J = sparse_.size(1); | ||||
|   const int64_t K = dense.size(1); | ||||
|  | ||||
|   TORCH_CHECK(dense.size(0) == J, | ||||
|       "addmm: dense (mat2) dim0 must be ", J, ", got ", dense.size(0)); | ||||
|   TORCH_CHECK(t.size(0) == I && t.size(1) == K, | ||||
|       "addmm: 't' shape must be (", I, ", ", K, "), got (", t.size(0), ", ", t.size(1), ")"); | ||||
|  | ||||
|   r.resize_({I, K}); | ||||
|  | ||||
|   auto sparse = sparse_.coalesce(); | ||||
|   const int64_t nnz = sparse._nnz(); | ||||
|  | ||||
|   if (nnz == 0 || I == 0 || K == 0) { | ||||
|     at::mul_out(r, t, beta); | ||||
|     return r; | ||||
|   } | ||||
|  | ||||
|   const auto v_dtype = sparse._values().scalar_type(); | ||||
|   const auto d_dtype = dense.scalar_type(); | ||||
|   const auto t_dtype = t.scalar_type(); | ||||
|   auto compute_dtype = c10::promoteTypes(c10::promoteTypes(v_dtype, d_dtype), t_dtype); | ||||
|  | ||||
|   TORCH_CHECK(canCast(compute_dtype, r.scalar_type()), | ||||
|               "Can't convert computed type ", compute_dtype, " to output ", r.scalar_type()); | ||||
|  | ||||
|   auto indices2d = sparse._indices().contiguous(); | ||||
|   auto values = sparse._values().to(compute_dtype); | ||||
|   auto dense_c = dense.to(compute_dtype).contiguous(); | ||||
|   auto t_c = t.to(compute_dtype).contiguous(); | ||||
|  | ||||
|   const bool out_needs_cast = (r.scalar_type() != compute_dtype) || !r.is_contiguous(); | ||||
|   Tensor out_buf = out_needs_cast | ||||
|       ? at::empty({I, K}, r.options().dtype(compute_dtype)) | ||||
|       : r; | ||||
|   auto out_contig = out_buf.contiguous(); | ||||
|  | ||||
|   auto device = r.device(); | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|  | ||||
|   const float alpha_f = alpha.to<float>(); | ||||
|   const float beta_f  = beta.to<float>(); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       const std::string func = "spmm_addmm_coo_" + mps::scalarToMetalTypeString(values); | ||||
|       auto pso = lib.getPipelineStateForFunc(func); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|  | ||||
|       const uint32_t tew = pso.threadExecutionWidth; | ||||
|       const uint32_t gridX = static_cast<uint32_t>(K); | ||||
|       const uint32_t gridZ = static_cast<uint32_t>(I); | ||||
|       const uint32_t tgW = std::min<uint32_t>(gridX, tew); | ||||
|  | ||||
|       MTLSize grid = MTLSizeMake(gridX, 1, gridZ); | ||||
|       MTLSize tgs = MTLSizeMake(tgW, 1, 1); | ||||
|  | ||||
|       mtl_setArgs(enc, | ||||
|                   indices2d, | ||||
|                   values, | ||||
|                   dense_c, | ||||
|                   t_c, | ||||
|                   out_contig, | ||||
|                   std::array<uint32_t, 3>{static_cast<uint32_t>(I), | ||||
|                                            static_cast<uint32_t>(J), | ||||
|                                            static_cast<uint32_t>(K)}, | ||||
|                   std::array<float, 2>{alpha_f, beta_f}, | ||||
|                   static_cast<uint32_t>(nnz)); | ||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   if (out_needs_cast) { | ||||
|     r.copy_(out_contig.to(r.scalar_type())); | ||||
|   } | ||||
|  | ||||
|   return r; | ||||
| } | ||||
|  | ||||
|  | ||||
| static void build_batch_ptr_mps( | ||||
|     const Tensor& indices_dim0, | ||||
|     int64_t B, | ||||
|     Tensor& batch_ptr | ||||
| ) { | ||||
|   // Builds an array of pointers which point to each batches elements. Example: | ||||
|   // idx_b = [0, 0, 0, 1, 1, 2, 2, 2, 2]  // 9 non-zero elements | ||||
|   //          └─────┘  └──┘  └─────────┘ | ||||
|   //          batch 0  batch 1  batch 2 | ||||
|   // batch_ptr = [0, 3, 5, 9] | ||||
|   //              │  │  │  └─ end of batch 2 (total nnz) | ||||
|   //              │  │  └──── batch 2 starts at index 5 | ||||
|   //              │  └─────── batch 1 starts at index 3 | ||||
|   //              └────────── batch 0 starts at index 0 | ||||
|   TORCH_CHECK(indices_dim0.is_mps() && batch_ptr.is_mps(), "MPS device expected"); | ||||
|   auto device = indices_dim0.device(); | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|  | ||||
|   const int64_t nnz = indices_dim0.numel(); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       auto pso = lib.getPipelineStateForFunc("build_batch_ptr_from_sorted_batches"); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|  | ||||
|       const uint32_t tew = pso.threadExecutionWidth; | ||||
|       const uint32_t Q = static_cast<uint32_t>(B + 1); | ||||
|       const uint32_t tgW = std::min<uint32_t>(Q, tew); | ||||
|       MTLSize grid = MTLSizeMake(Q, 1, 1); | ||||
|       MTLSize tgs  = MTLSizeMake(tgW, 1, 1); | ||||
|  | ||||
|       mtl_setArgs(enc, | ||||
|                   indices_dim0, | ||||
|                   batch_ptr, | ||||
|                   std::array<uint32_t, 2>{static_cast<uint32_t>(nnz), | ||||
|                                           static_cast<uint32_t>(B)}); | ||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| static void build_row_ptr_per_batch_mps( | ||||
|     const Tensor& rows, | ||||
|     const Tensor& batch_ptr, | ||||
|     int64_t B, | ||||
|     int64_t I, | ||||
|     Tensor& row_ptr | ||||
| ) { | ||||
|   // Build per-batch CSR-style row pointer arrays from row indices sorted by batch | ||||
|   // Given: | ||||
|   //   rows: 1-D array of length nnz with row ids in [0, I), sorted within each batch | ||||
|   //   batch_ptr: length B+1, where [batch_ptr[b], batch_ptr[b+1]) is the subrange for batch b | ||||
|   // Produces: | ||||
|   //   - row_ptr: shape [B, I+1] | ||||
|   // | ||||
|   // Example (B = 2, I = 4): | ||||
|   // rows       = [0,   0,   1,  3,  0,   2,    2]   // 7 non-zero elements | ||||
|   //               └─── batch 0 ──┘  └─ batch 1 ─┘ | ||||
|   // batch_ptr  = [0, 4, 7] | ||||
|   //               │  │  └─ end of batch 1 (total nnz) | ||||
|   //               │  └──── end of batch 0/start of batch 1 | ||||
|   //               └─────── start of batch 0 | ||||
|   // | ||||
|   // per-batch row pointers (I+1 entries each): | ||||
|   //   row_ptr[0] = [0, 2, 3, 3, 4] | ||||
|   //   row_ptr[1] = [0, 1, 1, 3, 3] | ||||
|   // laid out in memory: [0, 2, 3, 3, 4,  0, 1, 1, 3, 3] | ||||
|   TORCH_CHECK(rows.is_mps() && batch_ptr.is_mps() && row_ptr.is_mps(), "MPS device expected"); | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       auto pso = lib.getPipelineStateForFunc("build_row_ptr_from_sorted_rows_by_batch"); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|  | ||||
|       const uint32_t tew = pso.threadExecutionWidth; | ||||
|       const uint32_t Qx = static_cast<uint32_t>(I + 1); | ||||
|       const uint32_t Qy = static_cast<uint32_t>(B); | ||||
|       const uint32_t tgW = std::min<uint32_t>(Qx, tew); | ||||
|  | ||||
|       MTLSize grid = MTLSizeMake(Qx, Qy, 1); | ||||
|       MTLSize tgs = MTLSizeMake(tgW, 1, 1); | ||||
|  | ||||
|       mtl_setArgs(enc, | ||||
|                   rows, | ||||
|                   batch_ptr, | ||||
|                   row_ptr, | ||||
|                   std::array<uint32_t, 2>{static_cast<uint32_t>(I), | ||||
|                                            static_cast<uint32_t>(B)}); | ||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| Tensor& bmm_out_sparse_mps(const SparseTensor& self_, const Tensor& mat2_, Tensor& result_) { | ||||
|   TORCH_CHECK(result_.is_mps(), "bmm_sparse: expected 'out' to be MPS, got ", result_.device()); | ||||
|   TORCH_CHECK(self_.is_mps(),  "bmm_sparse: expected 'self' to be MPS, got ", self_.device()); | ||||
|   TORCH_CHECK(mat2_.is_mps(),  "bmm_sparse: expected 'mat2' to be MPS, got ", mat2_.device()); | ||||
|  | ||||
|   TORCH_CHECK(self_.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self_.dense_dim()); | ||||
|   TORCH_CHECK(self_.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self_.sparse_dim()); | ||||
|   TORCH_CHECK(mat2_.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2_.dim()); | ||||
|  | ||||
|   TORCH_CHECK(self_.size(0) == mat2_.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match"); | ||||
|   TORCH_CHECK(self_.size(2) == mat2_.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match"); | ||||
|  | ||||
|   const int64_t B = self_.size(0); | ||||
|   const int64_t I = self_.size(1); | ||||
|   const int64_t J = self_.size(2); | ||||
|   const int64_t K = mat2_.size(2); | ||||
|  | ||||
|   auto self = self_.coalesce(); | ||||
|   const int64_t nnz = self._nnz(); | ||||
|   if (nnz == 0) { | ||||
|     return result_.zero_(); | ||||
|   } | ||||
|  | ||||
|   const auto computeDtype = at::kFloat; | ||||
|  | ||||
|   auto indices = self._indices(); | ||||
|   auto values  = self._values(); | ||||
|  | ||||
|   auto values_c = values.scalar_type() == computeDtype ? values : values.to(computeDtype); | ||||
|   auto mat2_c = mat2_.scalar_type()   == computeDtype ? mat2_   : mat2_.to(computeDtype); | ||||
|   auto mat2_contig = mat2_c.contiguous(); | ||||
|  | ||||
|   auto idx_b = indices.select(0, 0).contiguous(); | ||||
|   auto idx_i = indices.select(0, 1).contiguous(); | ||||
|   auto idx_j = indices.select(0, 2).contiguous(); | ||||
|  | ||||
|   // builds an array of pointers of where the batch_idx's pointer starts and ends | ||||
|   // look in function for better explanation | ||||
|   auto batch_ptr = at::empty({B + 1}, at::device(result_.device()).dtype(kLong)); | ||||
|   build_batch_ptr_mps(idx_b, B, batch_ptr); | ||||
|   // build row_ptr per batch: for each (b, i) get [start, end) into rows/cols/vals | ||||
|   auto row_ptr = at::empty({B * (I + 1)}, at::device(result_.device()).dtype(kLong)); | ||||
|   build_row_ptr_per_batch_mps(idx_i, batch_ptr, B, I, row_ptr); | ||||
|  | ||||
|   const bool out_needs_cast = (result_.scalar_type() != computeDtype) || !result_.is_contiguous(); | ||||
|   Tensor out_buf = out_needs_cast | ||||
|       ? at::empty({B, I, K}, result_.options().dtype(computeDtype)) | ||||
|       : result_; | ||||
|   auto out_contig = out_buf.contiguous(); | ||||
|  | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       auto pso = lib.getPipelineStateForFunc("spmm_bmm_coo_rows_grouped_" + mps::scalarToMetalTypeString(values)); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|  | ||||
|       const uint32_t tew = pso.threadExecutionWidth; | ||||
|       const uint32_t tgW = std::min<uint32_t>((uint32_t)K, tew); | ||||
|  | ||||
|       // One threadgroup per (row i, batch b), lanes cover K | ||||
|       MTLSize grid = MTLSizeMake(tgW, (uint32_t)I, (uint32_t)B); | ||||
|       MTLSize tgs  = MTLSizeMake(tgW, 1, 1); | ||||
|  | ||||
|       mtl_setArgs(enc, | ||||
|                   idx_i, | ||||
|                   idx_j, | ||||
|                   values_c, | ||||
|                   mat2_contig, | ||||
|                   out_contig, | ||||
|                   row_ptr, | ||||
|                   std::array<uint32_t, 4>{(uint32_t)B, (uint32_t)I, (uint32_t)J, (uint32_t)K}); | ||||
|       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||
|     } | ||||
|   }); | ||||
|   if (out_needs_cast) { | ||||
|     result_.copy_(out_contig.to(result_.scalar_type())); | ||||
|   } | ||||
|   return result_; | ||||
| } | ||||
|  | ||||
| Tensor bmm_sparse_mps(const Tensor& self, const Tensor& mat2) { | ||||
|   Tensor result = at::zeros({self.size(0), self.size(1), mat2.size(2)}, mat2.options()); | ||||
|   return bmm_out_sparse_mps(self, mat2, result); | ||||
| } | ||||
|  | ||||
| Tensor& addmm_out_sparse_dense_mps( | ||||
|     const Tensor& self, | ||||
|     const SparseTensor& mat1, | ||||
|     const Tensor& mat2, | ||||
|     const Scalar& beta, | ||||
|     const Scalar& alpha, | ||||
|     Tensor& result) { | ||||
|   c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); | ||||
|   return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha); | ||||
| } | ||||
|  | ||||
| Tensor addmm_sparse_dense_mps( | ||||
|     const Tensor& self, | ||||
|     const SparseTensor& mat1, | ||||
|     const Tensor& mat2, | ||||
|     const Scalar& beta, | ||||
|     const Scalar& alpha | ||||
| ) { | ||||
|   c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); | ||||
|   Tensor result = at::empty({0}, self.options()); | ||||
|   return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha); | ||||
| } | ||||
|  | ||||
| static SparseTensor& mul_out_dense_sparse_mps( | ||||
|     const Tensor& dense, | ||||
|     const Tensor& sparse, | ||||
| @ -742,137 +436,4 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self, | ||||
|   return out; | ||||
| } | ||||
|  | ||||
| using OptTensor = std::optional<Tensor>; | ||||
|  | ||||
|  | ||||
| static void sparse_mask_apply_out_mps_kernel( | ||||
|     Tensor& result, | ||||
|     const Tensor& src_in, | ||||
|     const Tensor& mask_in, | ||||
|     bool accumulate_matches, | ||||
|     bool require_same_sizes, | ||||
|     bool coalesce_mask) { | ||||
|   TORCH_CHECK(src_in.is_sparse() && mask_in.is_sparse(), | ||||
|               "sparse_mask: expected both inputs to be sparse COO"); | ||||
|   TORCH_CHECK(src_in.is_mps() && mask_in.is_mps(), | ||||
|               "sparse_mask: expected tensors to be on MPS device"); | ||||
|   TORCH_CHECK(src_in.sparse_dim() == mask_in.sparse_dim(), | ||||
|               "sparse_mask: sparse_dim mismatch: ", src_in.sparse_dim(), " vs ", mask_in.sparse_dim()); | ||||
|   if (require_same_sizes) { | ||||
|     TORCH_CHECK(src_in.sizes().equals(mask_in.sizes()), | ||||
|                 "sparse_mask: sizes must match exactly (no broadcasting)"); | ||||
|   } | ||||
|   auto src  = src_in.coalesce(); | ||||
|   auto mask = coalesce_mask ? mask_in.coalesce() : mask_in; | ||||
|  | ||||
|   const int64_t src_nnz = src._nnz(); | ||||
|   const int64_t mask_nnz = mask._nnz(); | ||||
|   const int64_t sd = src.sparse_dim(); | ||||
|   result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim()); | ||||
|  | ||||
|   auto commonDtype = at::result_type(src, mask); | ||||
|   TORCH_CHECK(canCast(commonDtype, result.scalar_type()), | ||||
|               "Can't convert result type ", commonDtype, " to output ", result.scalar_type()); | ||||
|  | ||||
|   if (mask_nnz == 0) { | ||||
|     alias_into_sparse( | ||||
|         result, | ||||
|         mask._indices().narrow(1, 0, 0), | ||||
|         at::empty({0}, result.options().dtype(result.scalar_type()))); | ||||
|     result._coalesced_(mask.is_coalesced()); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(sd > 0 || (src_nnz <= 1 && mask_nnz <= 1), | ||||
|               "sparse_mask: invalid sparse_dim or nnz"); | ||||
|  | ||||
|   if (sd == 0) { | ||||
|     auto out_indices = mask._indices().narrow(1, 0, 1); | ||||
|     auto out_values = src_nnz | ||||
|       ? src._values().narrow(0, 0, 1).to(commonDtype) | ||||
|       : at::zeros({1}, at::device(result.device()).dtype(commonDtype)); | ||||
|     alias_into_sparse(result, out_indices, out_values); | ||||
|     result._coalesced_(mask.is_coalesced()); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   if (src_nnz == 0) { | ||||
|     auto out_indices = mask._indices().contiguous(); | ||||
|     auto src_values  = src._values().to(commonDtype); | ||||
|     auto out_val_sizes = src_values.sizes().vec(); | ||||
|     out_val_sizes[0] = mask_nnz; | ||||
|     auto out_values = at::zeros(out_val_sizes, src_values.options()); | ||||
|     alias_into_sparse(result, out_indices, out_values); | ||||
|     result._coalesced_(mask.is_coalesced()); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   auto mask_indices = mask._indices().contiguous(); | ||||
|   auto src_indices = src._indices().contiguous(); | ||||
|   auto src_values = src._values().to(commonDtype).contiguous(); | ||||
|  | ||||
|   auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous(); | ||||
|   auto src_keys  = flatten_indices(src_indices,  src.sizes().slice(0, sd)).contiguous(); | ||||
|  | ||||
|   const bool A_is_src = (src_nnz <= mask_nnz); | ||||
|   const int64_t lenA = A_is_src ? src_nnz  : mask_nnz; | ||||
|   const int64_t lenB = A_is_src ? mask_nnz : src_nnz; | ||||
|   auto A_keys = A_is_src ? src_keys  : mask_keys; | ||||
|   auto B_keys = A_is_src ? mask_keys : src_keys; | ||||
|  | ||||
|   const auto device = result.device(); | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|  | ||||
|   auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong)); | ||||
|   auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong)); | ||||
|   auto counter = at::zeros({1}, at::device(device).dtype(at::kInt)); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       auto pso = lib.getPipelineStateForFunc("intersect_binary_search"); | ||||
|       auto enc = stream->commandEncoder(); | ||||
|       [enc setComputePipelineState:pso]; | ||||
|       mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter, | ||||
|                   static_cast<uint32_t>(lenB), A_is_src); | ||||
|       mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA)); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   const int64_t M = static_cast<int64_t>(counter.item<int32_t>()); | ||||
|  | ||||
|   auto out_val_sizes = src_values.sizes().vec(); | ||||
|   out_val_sizes[0] = mask_nnz; | ||||
|   auto out_values = at::zeros(out_val_sizes, src_values.options()); | ||||
|  | ||||
|   if (M > 0) { | ||||
|     auto src_match = outA_idx.narrow(0, 0, M); | ||||
|     auto mask_match = outB_idx.narrow(0, 0, M); | ||||
|  | ||||
|     auto src_rows = src_values.index_select(0, src_match); | ||||
|     if (accumulate_matches) { | ||||
|       out_values.index_add_(0, mask_match, src_rows); | ||||
|     } else { | ||||
|       out_values.index_copy_(0, mask_match, src_rows); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   alias_into_sparse(result, mask_indices, out_values); | ||||
|   result._coalesced_(mask.is_coalesced()); | ||||
| } | ||||
|  | ||||
| static void sparse_mask_intersection_out_mps_kernel( | ||||
|     Tensor& result, | ||||
|     const Tensor& lhs, | ||||
|     const Tensor& rhs, | ||||
|     const OptTensor& = std::nullopt) { | ||||
|   sparse_mask_apply_out_mps_kernel( | ||||
|       result, | ||||
|       /*src_in=*/lhs, | ||||
|       /*mask_in=*/rhs, | ||||
|       /*accumulate_matches=*/false, | ||||
|       /*require_same_sizes=*/false, | ||||
|       /*coalesce_mask=*/false); | ||||
| } | ||||
|  | ||||
| REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel); | ||||
| } // namespace at::native | ||||
| @ -1,105 +1,7 @@ | ||||
| #include <metal_stdlib> | ||||
| #include <c10/metal/indexing.h> | ||||
| #include <c10/metal/utils.h> | ||||
| using namespace c10::metal; | ||||
| using namespace metal; | ||||
|  | ||||
| inline uint lower_bound_i64(device const long* arr, uint lo, uint hi, long key) { | ||||
|   uint l = lo, r = hi; | ||||
|   while (l < r) { | ||||
|     uint m = (l + r) >> 1; | ||||
|     long v = arr[m]; | ||||
|     if (v < key) { | ||||
|       l = m + 1; | ||||
|     } else { | ||||
|       r = m; | ||||
|     } | ||||
|   } | ||||
|   return l; | ||||
| } | ||||
|  | ||||
| inline uint upper_bound_i64(device const long* arr, uint lo, uint hi, long key) { | ||||
|   uint l = lo, r = hi; | ||||
|   while (l < r) { | ||||
|     uint m = (l + r) >> 1; | ||||
|     long v = arr[m]; | ||||
|     if (v <= key) { | ||||
|       l = m + 1; | ||||
|     } else { | ||||
|       r = m; | ||||
|     } | ||||
|   } | ||||
|   return l; | ||||
| } | ||||
|  | ||||
| kernel void build_row_ptr_from_sorted_rows_by_batch( | ||||
|     device const long* rows        [[buffer(0)]], | ||||
|     device const long* batch_ptr   [[buffer(1)]], | ||||
|     device long*       row_ptr     [[buffer(2)]], | ||||
|     constant uint2&    dims        [[buffer(3)]], | ||||
|     uint3              tid         [[thread_position_in_grid]]) | ||||
| { | ||||
|   const uint I = dims.x; | ||||
|   const uint B = dims.y; | ||||
|  | ||||
|   const uint i = tid.x; | ||||
|   const uint b = tid.y; | ||||
|  | ||||
|   if (b >= B || i > I) return; | ||||
|  | ||||
|   const uint base = (uint)batch_ptr[b]; | ||||
|   const uint lim  = (uint)batch_ptr[b + 1]; | ||||
|  | ||||
|   const ulong out_base = (ulong)b * (ulong)(I + 1); | ||||
|  | ||||
|   if (i == I) { | ||||
|     row_ptr[out_base + (ulong)I] = (long)lim; | ||||
|   } else { | ||||
|     const long key = (long)i; | ||||
|     const uint pos = lower_bound_i64(rows, base, lim, key); | ||||
|     row_ptr[out_base + (ulong)i] = (long)pos; | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| kernel void spmm_bmm_coo_rows_grouped( | ||||
|     device const long*   rows      [[buffer(0)]], | ||||
|     device const long*   cols      [[buffer(1)]], | ||||
|     device const T*      vals      [[buffer(2)]], | ||||
|     device const T*      dense     [[buffer(3)]], | ||||
|     device T*            out       [[buffer(4)]], | ||||
|     device const long*   row_ptr   [[buffer(5)]], | ||||
|     constant uint4&      dims      [[buffer(6)]], | ||||
|     uint3                tid       [[thread_position_in_grid]], | ||||
|     uint3                ltid      [[thread_position_in_threadgroup]], | ||||
|     uint3                tptg      [[threads_per_threadgroup]]) | ||||
| { | ||||
|   const uint B = dims.x; | ||||
|   const uint I = dims.y; | ||||
|   const uint J = dims.z; | ||||
|   const uint K = dims.w; | ||||
|  | ||||
|   const uint b = tid.z; | ||||
|   const uint i = tid.y; | ||||
|   const uint lane = ltid.x; | ||||
|   const uint tgW  = tptg.x; | ||||
|  | ||||
|   const ulong rp_base = (ulong)b * (ulong)(I + 1); | ||||
|   const uint start = (uint)row_ptr[rp_base + (ulong)i]; | ||||
|   const uint end   = (uint)row_ptr[rp_base + (ulong)i + 1]; | ||||
|  | ||||
|   for (uint k = lane; k < K; k += tgW) { | ||||
|     auto acc = static_cast<accum_t<T>>(T(0)); | ||||
|     for (uint p = start; p < end; ++p) { | ||||
|       const uint c = (uint)cols[p]; | ||||
|       const auto v = static_cast<accum_t<T>>(vals[p]); | ||||
|       const uint d_off = ((b * J) + c) * K + k; | ||||
|       const auto d = static_cast<accum_t<T>>(dense[d_off]); | ||||
|       acc += mul(v, d); | ||||
|     } | ||||
|     const uint y_off = ((b * I) + i) * K + k; | ||||
|     out[y_off] = static_cast<T>(acc); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| kernel void dense_sparse_mul_kernel( | ||||
| @ -127,9 +29,9 @@ kernel void dense_sparse_mul_kernel( | ||||
|   ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col; | ||||
|   ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col; | ||||
|  | ||||
|   const auto a = static_cast<accum_t<T>>(values[val_idx]); | ||||
|   const auto b = static_cast<accum_t<T>>(dense[dense_idx]); | ||||
|   out_values[val_idx] = static_cast<T>(mul(a, b)); | ||||
|   const auto a = static_cast<float>(values[val_idx]); | ||||
|   const auto b = static_cast<float>(dense[dense_idx]); | ||||
|   out_values[val_idx] = static_cast<T>(a * b); | ||||
| } | ||||
|  | ||||
| kernel void intersect_binary_search( | ||||
| @ -214,76 +116,6 @@ kernel void fused_gather_mul_kernel( | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| kernel void build_batch_ptr_from_sorted_batches( | ||||
|     device const long* batches       [[buffer(0)]], | ||||
|     device long*       batch_ptr     [[buffer(1)]], | ||||
|     constant uint2&    nnz_B         [[buffer(2)]], | ||||
|     uint3              tid           [[thread_position_in_grid]]) | ||||
| { | ||||
|   uint b = tid.x; | ||||
|   uint nnz = nnz_B.x; | ||||
|   uint batch = nnz_B.y; | ||||
|  | ||||
|   if (b == batch) { | ||||
|     batch_ptr[b] = (long)nnz; | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   uint lo = 0; | ||||
|   uint hi = nnz; | ||||
|   long key = (long)b; | ||||
|   while (lo < hi) { | ||||
|     uint mid = (lo + hi) >> 1; | ||||
|     long v = batches[mid]; | ||||
|     if (v < key) lo = mid + 1; | ||||
|     else         hi = mid; | ||||
|   } | ||||
|   batch_ptr[b] = (long)lo; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| kernel void spmm_addmm_coo( | ||||
|     device const long*   indices2d   [[buffer(0)]], | ||||
|     device const T*      vals        [[buffer(1)]], | ||||
|     device const T*      dense       [[buffer(2)]], | ||||
|     device const T*      t_in        [[buffer(3)]], | ||||
|     device T*            out         [[buffer(4)]], | ||||
|     constant uint3&      dims        [[buffer(5)]], | ||||
|     constant float2&     alpha_beta  [[buffer(6)]], | ||||
|     constant uint&       nnz         [[buffer(7)]], | ||||
|     uint3                tid         [[thread_position_in_grid]]) | ||||
| { | ||||
|   const uint K = dims.z; | ||||
|   const uint k = tid.x; | ||||
|   const uint i = tid.z; | ||||
|   const float alpha = alpha_beta.x; | ||||
|   const float beta = alpha_beta.y; | ||||
|  | ||||
|   device const long* rows = indices2d; | ||||
|   device const long* cols = indices2d + nnz; | ||||
|  | ||||
|   const uint start = lower_bound_i64(rows, 0u, nnz, (long)i); | ||||
|   const uint end = upper_bound_i64(rows, 0u, nnz, (long)i); | ||||
|  | ||||
|   // accumulator is float for scalar/half/bfloat and float2 for float2 | ||||
|   auto acc = static_cast<accum_t<T>>(T(0)); | ||||
|  | ||||
|   for (uint p = start; p < end; ++p) { | ||||
|     const uint c = (uint)cols[p]; | ||||
|     const auto v = static_cast<accum_t<T>>(vals[p]); | ||||
|     const uint dense_off = c * K + k; | ||||
|     const auto d = static_cast<accum_t<T>>(dense[dense_off]); | ||||
|     acc += mul(v, d); | ||||
|   } | ||||
|  | ||||
|   const uint off = i * K + k; | ||||
|   const auto base = (beta != 0.0f) ? (static_cast<accum_t<T>>(t_in[off]) * beta) : static_cast<accum_t<T>>(T(0)); | ||||
|   const auto y = base + alpha * acc; | ||||
|   out[off] = static_cast<T>(y); | ||||
| } | ||||
|  | ||||
|  | ||||
| #define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE)                                 \ | ||||
|   template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void     \ | ||||
|   dense_sparse_mul_kernel<DTYPE>(                                           \ | ||||
| @ -298,8 +130,6 @@ kernel void spmm_addmm_coo( | ||||
| INSTANTIATE_DENSE_SPARSE_MUL(float); | ||||
| INSTANTIATE_DENSE_SPARSE_MUL(half); | ||||
| INSTANTIATE_DENSE_SPARSE_MUL(bfloat); | ||||
| INSTANTIATE_DENSE_SPARSE_MUL(long); | ||||
| INSTANTIATE_DENSE_SPARSE_MUL(float2); | ||||
|  | ||||
| #define INSTANTIATE_FUSED_GATHER_MUL(DTYPE)                                  \ | ||||
|   template [[host_name("fused_gather_mul_kernel_" #DTYPE)]] kernel void      \ | ||||
| @ -315,36 +145,6 @@ INSTANTIATE_DENSE_SPARSE_MUL(float2); | ||||
|       constant uint2&     dims_output   [[buffer(8)]],                       \ | ||||
|       uint3               gid           [[thread_position_in_grid]]); | ||||
|  | ||||
| INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL); | ||||
|  | ||||
|  | ||||
| #define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE)                         \ | ||||
|   template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void    \ | ||||
|   spmm_bmm_coo_rows_grouped<DTYPE>(                                          \ | ||||
|       device const long*   rows      [[buffer(0)]],                          \ | ||||
|       device const long*   cols      [[buffer(1)]],                          \ | ||||
|       device const DTYPE*  vals      [[buffer(2)]],                          \ | ||||
|       device const DTYPE*  dense     [[buffer(3)]],                          \ | ||||
|       device DTYPE*        out       [[buffer(4)]],                          \ | ||||
|       device const long*   row_ptr   [[buffer(5)]],                          \ | ||||
|       constant uint4&      dims      [[buffer(6)]],                          \ | ||||
|       uint3                tid       [[thread_position_in_grid]],            \ | ||||
|       uint3                ltid      [[thread_position_in_threadgroup]],     \ | ||||
|       uint3                tptg      [[threads_per_threadgroup]]); | ||||
|  | ||||
| INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED); | ||||
|  | ||||
| #define INSTANTIATE_SPMM_ADDMM_COO(DTYPE) \ | ||||
|   template [[host_name("spmm_addmm_coo_" #DTYPE)]] kernel void  \ | ||||
|   spmm_addmm_coo<DTYPE>(                                        \ | ||||
|     device const long*   indices2d   [[buffer(0)]],             \ | ||||
|     device const DTYPE*  vals        [[buffer(1)]],             \ | ||||
|     device const DTYPE*  dense       [[buffer(2)]],             \ | ||||
|     device const DTYPE*  t_in        [[buffer(3)]],             \ | ||||
|     device DTYPE*        out         [[buffer(4)]],             \ | ||||
|     constant uint3&      dims        [[buffer(5)]],             \ | ||||
|     constant float2&     alpha_beta  [[buffer(6)]],             \ | ||||
|     constant uint&       nnz         [[buffer(7)]],             \ | ||||
|     uint3                tid         [[thread_position_in_grid]]); | ||||
|  | ||||
| INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_ADDMM_COO); | ||||
| INSTANTIATE_FUSED_GATHER_MUL(float); | ||||
| INSTANTIATE_FUSED_GATHER_MUL(half); | ||||
| INSTANTIATE_FUSED_GATHER_MUL(bfloat); | ||||
| @ -110,9 +110,9 @@ class ApplyLogSumExp { | ||||
|   using ElementCompute = ElementCompute_; | ||||
|   using ElementLSE = ElementLSE_; | ||||
|  | ||||
|   static int constexpr kElementsPerAccess = ElementsPerAccess; | ||||
|   static int constexpr kCount = kElementsPerAccess; | ||||
|   static constexpr ScaleType::Kind kScale = | ||||
|   static int const kElementsPerAccess = ElementsPerAccess; | ||||
|   static int const kCount = kElementsPerAccess; | ||||
|   static const ScaleType::Kind kScale = | ||||
|       cutlass::epilogue::thread::ScaleType::NoBetaScaling; | ||||
|  | ||||
|   using FragmentOutput = Array<ElementOutput, kCount>; | ||||
|  | ||||
| @ -14,16 +14,16 @@ using namespace at; | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| constexpr auto int_min = std::numeric_limits<int>::min(); | ||||
| constexpr auto int_max = std::numeric_limits<int>::max(); | ||||
| constexpr auto long_min = std::numeric_limits<int64_t>::min(); | ||||
| constexpr auto long_max = std::numeric_limits<int64_t>::max(); | ||||
| constexpr auto float_lowest = std::numeric_limits<float>::lowest(); | ||||
| constexpr auto float_min = std::numeric_limits<float>::min(); | ||||
| constexpr auto float_max = std::numeric_limits<float>::max(); | ||||
| constexpr auto double_lowest = std::numeric_limits<double>::lowest(); | ||||
| constexpr auto double_min = std::numeric_limits<double>::min(); | ||||
| constexpr auto double_max = std::numeric_limits<double>::max(); | ||||
| const auto int_min = std::numeric_limits<int>::min(); | ||||
| const auto int_max = std::numeric_limits<int>::max(); | ||||
| const auto long_min = std::numeric_limits<int64_t>::min(); | ||||
| const auto long_max = std::numeric_limits<int64_t>::max(); | ||||
| const auto float_lowest = std::numeric_limits<float>::lowest(); | ||||
| const auto float_min = std::numeric_limits<float>::min(); | ||||
| const auto float_max = std::numeric_limits<float>::max(); | ||||
| const auto double_lowest = std::numeric_limits<double>::lowest(); | ||||
| const auto double_min = std::numeric_limits<double>::min(); | ||||
| const auto double_max = std::numeric_limits<double>::max(); | ||||
|  | ||||
| const std::vector<int> ints { | ||||
|   int_min, | ||||
|  | ||||
| @ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() { | ||||
|  | ||||
| c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const { | ||||
|   // The RNG state comprises the seed, and an offset used for Philox. | ||||
|   constexpr size_t seed_size = sizeof(uint64_t); | ||||
|   constexpr size_t offset_size = sizeof(uint64_t); | ||||
|   constexpr size_t total_size = seed_size + offset_size; | ||||
|   static const size_t seed_size = sizeof(uint64_t); | ||||
|   static const size_t offset_size = sizeof(uint64_t); | ||||
|   static const size_t total_size = seed_size + offset_size; | ||||
|  | ||||
|   // The internal state is returned as a CPU byte tensor. | ||||
|   auto state_tensor = at::detail::empty_cpu( | ||||
| @ -170,9 +170,9 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const { | ||||
| void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { | ||||
|   at::xpu::assertNotCapturing( | ||||
|       "Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing."); | ||||
|   constexpr size_t seed_size = sizeof(uint64_t); | ||||
|   constexpr size_t offset_size = sizeof(uint64_t); | ||||
|   constexpr size_t total_size = seed_size + offset_size; | ||||
|   static const size_t seed_size = sizeof(uint64_t); | ||||
|   static const size_t offset_size = sizeof(uint64_t); | ||||
|   static const size_t total_size = seed_size + offset_size; | ||||
|  | ||||
|   at::detail::check_rng_state(new_state); | ||||
|  | ||||
|  | ||||
| @ -6,7 +6,7 @@ import os | ||||
| import subprocess | ||||
| import sys | ||||
| import tempfile | ||||
| from collections.abc import Callable | ||||
| from typing import Callable | ||||
|  | ||||
| from torch._inductor.utils import fresh_cache | ||||
|  | ||||
|  | ||||
| @ -2284,11 +2284,9 @@ class BenchmarkRunner: | ||||
|                     ) | ||||
|                 ): | ||||
|                     is_same = False | ||||
|             except Exception as e: | ||||
|             except Exception: | ||||
|                 # Sometimes torch.allclose may throw RuntimeError | ||||
|                 exception_string = str(e) | ||||
|                 accuracy_status = f"fail_exception: {exception_string}" | ||||
|                 return record_status(accuracy_status, dynamo_start_stats=start_stats) | ||||
|                 is_same = False | ||||
|  | ||||
|             if not is_same: | ||||
|                 accuracy_status = "eager_two_runs_differ" | ||||
| @ -2405,11 +2403,9 @@ class BenchmarkRunner: | ||||
|                     force_max_multiplier=force_max_multiplier, | ||||
|                 ): | ||||
|                     is_same = False | ||||
|             except Exception as e: | ||||
|             except Exception: | ||||
|                 # Sometimes torch.allclose may throw RuntimeError | ||||
|                 exception_string = str(e) | ||||
|                 accuracy_status = f"fail_exception: {exception_string}" | ||||
|                 return record_status(accuracy_status, dynamo_start_stats=start_stats) | ||||
|                 is_same = False | ||||
|  | ||||
|             if not is_same: | ||||
|                 if self.args.skip_accuracy_check: | ||||
| @ -4064,7 +4060,7 @@ def run(runner, args, original_dir=None): | ||||
|         else: | ||||
|             optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython) | ||||
|         experiment = ( | ||||
|             speedup_experiment if args.backend != "torchao" else latency_experiment | ||||
|             speedup_experiment if not args.backend == "torchao" else latency_experiment | ||||
|         ) | ||||
|         if args.accuracy: | ||||
|             output_filename = f"accuracy_{args.backend}.csv" | ||||
|  | ||||
| @ -1,8 +1,7 @@ | ||||
| import os | ||||
| from collections import defaultdict | ||||
| from collections.abc import Callable | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Optional | ||||
| from typing import Any, Callable, Optional | ||||
|  | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
|  | ||||
| @ -124,7 +124,7 @@ with open(MODELS_FILENAME) as fh: | ||||
|             continue | ||||
|         batch_size = int(batch_size) | ||||
|         BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size | ||||
| assert BATCH_SIZE_KNOWN_MODELS | ||||
| assert len(BATCH_SIZE_KNOWN_MODELS) | ||||
|  | ||||
|  | ||||
| try: | ||||
|  | ||||
| @ -1,5 +1,4 @@ | ||||
| from collections.abc import Callable | ||||
| from typing import Any | ||||
| from typing import Any, Callable | ||||
|  | ||||
| import torch | ||||
|  | ||||
|  | ||||
| @ -1,8 +1,7 @@ | ||||
| import time | ||||
| from argparse import ArgumentParser | ||||
| from collections import defaultdict | ||||
| from collections.abc import Callable | ||||
| from typing import Any, NamedTuple | ||||
| from typing import Any, Callable, NamedTuple | ||||
|  | ||||
| import torch | ||||
| from torch.autograd import functional | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| from collections import defaultdict | ||||
| from collections.abc import Callable | ||||
| from typing import Optional, Union | ||||
| from typing import Callable, Optional, Union | ||||
|  | ||||
| import torch | ||||
| from torch import nn, Tensor | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| import dataclasses | ||||
| from collections.abc import Callable | ||||
| from typing import Optional | ||||
| from typing import Callable, Optional | ||||
|  | ||||
|  | ||||
| all_experiments: dict[str, Callable] = {} | ||||
|  | ||||
| @ -85,7 +85,7 @@ class WeightOnlyInt8QuantHandler: | ||||
|                 cur_state_dict[f"{fqn}.weight"] = int8_weight | ||||
|                 cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) | ||||
|             elif isinstance(mod, ConditionalFeedForward): | ||||
|                 for weight_idx in range(3): | ||||
|                 for weight_idx in range(0, 3): | ||||
|                     weight_name = f"w{weight_idx + 1}" | ||||
|                     scales_name = f"scales{weight_idx + 1}" | ||||
|                     weight = getattr(mod, weight_name) | ||||
|  | ||||
| @ -9,9 +9,8 @@ import logging | ||||
| import time | ||||
| from abc import abstractmethod | ||||
| from collections import defaultdict | ||||
| from collections.abc import Callable | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from typing import Any, Optional | ||||
| from typing import Any, Callable, Optional | ||||
|  | ||||
| from tabulate import tabulate | ||||
| from tqdm import tqdm | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -7,7 +7,6 @@ from pt import (  # noqa: F401 | ||||
|     binary_inplace_test, | ||||
|     binary_test, | ||||
|     bmm_test, | ||||
|     boolean_test, | ||||
|     cat_test, | ||||
|     channel_shuffle_test, | ||||
|     chunk_test, | ||||
|  | ||||
| @ -56,9 +56,6 @@ binary_ops_list = op_bench.op_list( | ||||
|         ["sub", torch.sub], | ||||
|         ["div", torch.div], | ||||
|         ["mul", torch.mul], | ||||
|         ["asr", torch.bitwise_right_shift], | ||||
|         ["lsl", torch.bitwise_left_shift], | ||||
|         ["xor", torch.bitwise_xor], | ||||
|     ], | ||||
| ) | ||||
|  | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	