mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 16:14:55 +08:00 
			
		
		
		
	Compare commits
	
		
			260 Commits
		
	
	
		
			main-enabl
			...
			ciflow/tru
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| f6df9ab39e | |||
| 9f9ab881b2 | |||
| f2bb22ff84 | |||
| 03f3f7899c | |||
| 771170807b | |||
| ffa90d46e6 | |||
| 0e083942cc | |||
| ce1fcff03e | |||
| a238a9a100 | |||
| fe69a2bbbd | |||
| 0be0de4ffa | |||
| 7406d2e665 | |||
| 303c9cf048 | |||
| d7d4bb7c51 | |||
| 0b1c462979 | |||
| 4a6cf0a93e | |||
| 4c963a68d7 | |||
| b20deec3d1 | |||
| 51d0d8ee67 | |||
| 70592c6819 | |||
| 259cb945f5 | |||
| e20c9bf288 | |||
| 99c8640b5d | |||
| 96b0e7aaa6 | |||
| 850ba8c96d | |||
| 1bcd736f91 | |||
| df64c0c464 | |||
| 1891239a1d | |||
| cf280ca1e8 | |||
| efc277cac7 | |||
| 4f7f43253d | |||
| 779296a3fc | |||
| 8f06a1308f | |||
| 240c13394e | |||
| 150682ba7f | |||
| ca7360e996 | |||
| 0bf604320f | |||
| 9875e70da8 | |||
| 69a4bfe8bb | |||
| 62a263b8d4 | |||
| 0da1f911dc | |||
| 8700d68fef | |||
| ab82456c16 | |||
| b23f4687fd | |||
| 2705937080 | |||
| c1eda348be | |||
| ba93d5636e | |||
| 722b2b86c9 | |||
| e1e8491b31 | |||
| 767199fd9b | |||
| 602ace5eb4 | |||
| 47804ce467 | |||
| e8cb34dd52 | |||
| e9d8973427 | |||
| 61d9a5180e | |||
| 8a8329b51f | |||
| 6b80c94901 | |||
| 8951df03de | |||
| 8139f33fa5 | |||
| a88587348b | |||
| 633a3b7f67 | |||
| fa0db212e7 | |||
| 15ff1cd28b | |||
| c73f5080de | |||
| 22ae059d32 | |||
| 1b121d636e | |||
| 1ba808dd97 | |||
| b2f5c25b27 | |||
| a1114beed2 | |||
| 4888ed440e | |||
| 5d62b63a76 | |||
| 57ba575242 | |||
| ceb11a584d | |||
| 33adb276fe | |||
| e939651972 | |||
| 3255e7872b | |||
| c4f6619330 | |||
| f18041cca8 | |||
| 35e51893bd | |||
| 1f43d17ce6 | |||
| 032bed95cd | |||
| d14cbb4476 | |||
| f510d0dbc0 | |||
| beb6b62e8c | |||
| 4740ce7787 | |||
| ad67170c8b | |||
| fdab48a7c1 | |||
| a0948d4d23 | |||
| 0bbdd6b8db | |||
| 24520b8386 | |||
| c79dfdc655 | |||
| e595136187 | |||
| aaac8cb0f5 | |||
| 0f0b4bf029 | |||
| b8194268a6 | |||
| f02e3947f6 | |||
| 9095a9dfae | |||
| d9f94e0d7d | |||
| 23417ae50f | |||
| e4d6c56ffb | |||
| 017d2985f3 | |||
| c6a8db0b9a | |||
| de09bab4b6 | |||
| c137e222d4 | |||
| cf3a787bbc | |||
| de3da77cf7 | |||
| 543ddbf44c | |||
| e9f4999985 | |||
| 29b029648e | |||
| a25a649e70 | |||
| 69c33898fa | |||
| 1b397420f2 | |||
| fe80f03726 | |||
| e50dc40d28 | |||
| 2e22b1a61e | |||
| 616c6bdf8f | |||
| c18ddfc572 | |||
| 86ebce1766 | |||
| 8cb2fb44f2 | |||
| ab65498d71 | |||
| 06d324365c | |||
| 6c9c6e0936 | |||
| 2bcd892c86 | |||
| 75e2a9fae3 | |||
| a16fd6b488 | |||
| 382b0150de | |||
| a664b299ac | |||
| 9c12651417 | |||
| 08c97b4a1f | |||
| fae74cd52f | |||
| 7a65770013 | |||
| e4454947e2 | |||
| 3806e9767b | |||
| b08d8c2e50 | |||
| ca5b7f8ded | |||
| 9a71d96256 | |||
| 0d4c2b71e8 | |||
| d659bbde62 | |||
| 58879bfafa | |||
| a032510db3 | |||
| 39e0a832c9 | |||
| dd3b48e85d | |||
| cff1b20771 | |||
| da8517fa63 | |||
| 45afaf08a1 | |||
| 080365b7d8 | |||
| 2928c5c572 | |||
| 630520b346 | |||
| 1dc9a05d03 | |||
| bfcdbd0a97 | |||
| faff826a46 | |||
| 85c5433d38 | |||
| 935ccdbe75 | |||
| 3af2f0c12a | |||
| 6ece527fc5 | |||
| ce29d0d796 | |||
| 7231118db3 | |||
| 5d4da26ed0 | |||
| 574c9fc950 | |||
| 80d2ca7566 | |||
| 4a22139eea | |||
| cb6e4d7d82 | |||
| 202f83dc4e | |||
| 9fe3b2afbe | |||
| d0c24b392c | |||
| b44fb14906 | |||
| 51348c0219 | |||
| fdd560afd1 | |||
| e925dfcc6b | |||
| f1d882212a | |||
| 24879f0de9 | |||
| 9e94ec76b8 | |||
| 364624e209 | |||
| 7e150467f7 | |||
| 43d78423ac | |||
| fcbde24c1c | |||
| 861cdb887b | |||
| 3154482072 | |||
| 9fccbdd4f0 | |||
| 7dabfb07cb | |||
| d0add0be43 | |||
| 11e2084308 | |||
| 9726553653 | |||
| d82527b32a | |||
| 5d9b024276 | |||
| 5b2afe4c5d | |||
| b2953f5643 | |||
| 470e2f61c3 | |||
| e0fe37fa68 | |||
| d2c82bafb7 | |||
| 98a488c9aa | |||
| 5b3ea75895 | |||
| 556fc09a9f | |||
| ce109b3f79 | |||
| 4d833f859b | |||
| d7e275d4b4 | |||
| d5db3aee0d | |||
| 5641de7b6b | |||
| cbc08c8993 | |||
| 1a54d3333d | |||
| 4c1c341fa0 | |||
| 5f21cc786a | |||
| e86942f422 | |||
| 2cd5fd1588 | |||
| 7d0f872cb3 | |||
| fb06e49ce8 | |||
| 27a98e6ae9 | |||
| b10f463b1a | |||
| 431c13cf61 | |||
| aead9270f5 | |||
| 9bf5b38c14 | |||
| aba8c43594 | |||
| 37f3ba274a | |||
| 585b9dbb5e | |||
| d795fb225a | |||
| 7df9aca529 | |||
| d4a713cd9c | |||
| 5daef30b26 | |||
| 6dedd34c31 | |||
| a303d6dda9 | |||
| 7669ac9402 | |||
| 86fd4fc23e | |||
| 99097b6d89 | |||
| a214371008 | |||
| 7d87d7052e | |||
| 1a34ff4e04 | |||
| fe5ccb1a74 | |||
| 85586d7efc | |||
| e1d71a6b35 | |||
| d61a9b88cf | |||
| 99b32a6750 | |||
| 783da8b8e7 | |||
| ed74dc054d | |||
| f33c7e1a43 | |||
| 219fb6aafc | |||
| 515b5ff539 | |||
| 608a6d4a26 | |||
| 03e5dbb26e | |||
| 7ee45f7503 | |||
| e6d9d68598 | |||
| 1a5b7eca7b | |||
| 8573574b32 | |||
| e6033f6efb | |||
| 9272437cde | |||
| f06e669f6c | |||
| 69b05913fb | |||
| d73c283c3a | |||
| eaeaa08e3a | |||
| d0c32971b4 | |||
| d7ffa8b8a2 | |||
| 00afa06800 | |||
| 5d0b22008d | |||
| ab6014a903 | |||
| f6daffc54d | |||
| 66b75693ae | |||
| 21697feff2 | |||
| 12fa4192c5 | |||
| 23fb7e9f4b | |||
| 5e480b8ecf | |||
| 19ba506ca3 | 
| @ -113,6 +113,7 @@ case "$tag" in | |||||||
|     UCX_COMMIT=${_UCX_COMMIT} |     UCX_COMMIT=${_UCX_COMMIT} | ||||||
|     UCC_COMMIT=${_UCC_COMMIT} |     UCC_COMMIT=${_UCC_COMMIT} | ||||||
|     TRITON=yes |     TRITON=yes | ||||||
|  |     INSTALL_MINGW=yes | ||||||
|     ;; |     ;; | ||||||
|   pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11) |   pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11) | ||||||
|     CUDA_VERSION=13.0.0 |     CUDA_VERSION=13.0.0 | ||||||
| @ -361,6 +362,7 @@ docker build \ | |||||||
|        --build-arg "OPENBLAS=${OPENBLAS:-}" \ |        --build-arg "OPENBLAS=${OPENBLAS:-}" \ | ||||||
|        --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \ |        --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \ | ||||||
|        --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \ |        --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \ | ||||||
|  |        --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \ | ||||||
|        -f $(dirname ${DOCKERFILE})/Dockerfile \ |        -f $(dirname ${DOCKERFILE})/Dockerfile \ | ||||||
|        -t "$tmp_tag" \ |        -t "$tmp_tag" \ | ||||||
|        "$@" \ |        "$@" \ | ||||||
|  | |||||||
| @ -83,10 +83,6 @@ function build_cpython { | |||||||
|         py_suffix=${py_ver::-1} |         py_suffix=${py_ver::-1} | ||||||
|         py_folder=$py_suffix |         py_folder=$py_suffix | ||||||
|     fi |     fi | ||||||
|     # Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4 |  | ||||||
|     if [ "$py_suffix" == "3.14.0" ]; then |  | ||||||
|         py_suffix="3.14.0rc2" |  | ||||||
|     fi |  | ||||||
|     wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz |     wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz | ||||||
|     do_cpython_build $py_ver Python-$py_suffix |     do_cpython_build $py_ver Python-$py_suffix | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										10
									
								
								.ci/docker/common/install_mingw.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								.ci/docker/common/install_mingw.sh
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,10 @@ | |||||||
|  | #!/bin/bash | ||||||
|  |  | ||||||
|  | set -ex | ||||||
|  |  | ||||||
|  | # Install MinGW-w64 for Windows cross-compilation | ||||||
|  | apt-get update | ||||||
|  | apt-get install -y g++-mingw-w64-x86-64-posix | ||||||
|  |  | ||||||
|  | echo "MinGW-w64 installed successfully" | ||||||
|  | x86_64-w64-mingw32-g++ --version | ||||||
| @ -20,7 +20,7 @@ pip_install \ | |||||||
|  |  | ||||||
| pip_install coloredlogs packaging | pip_install coloredlogs packaging | ||||||
| pip_install onnxruntime==1.23.0 | pip_install onnxruntime==1.23.0 | ||||||
| pip_install onnxscript==0.5.3 | pip_install onnxscript==0.5.4 | ||||||
|  |  | ||||||
| # Cache the transformers model to be used later by ONNX tests. We need to run the transformers | # Cache the transformers model to be used later by ONNX tests. We need to run the transformers | ||||||
| # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ | # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ | ||||||
|  | |||||||
| @ -39,9 +39,13 @@ case ${DOCKER_TAG_PREFIX} in | |||||||
|         DOCKER_GPU_BUILD_ARG="" |         DOCKER_GPU_BUILD_ARG="" | ||||||
|         ;; |         ;; | ||||||
|     rocm*) |     rocm*) | ||||||
|  |         # we want the patch version of 7.0 instead | ||||||
|  |         if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then | ||||||
|  |             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||||
|  |         fi | ||||||
|         # we want the patch version of 6.4 instead |         # we want the patch version of 6.4 instead | ||||||
|         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then |         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then | ||||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" |             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4" | ||||||
|         fi |         fi | ||||||
|         BASE_TARGET=rocm |         BASE_TARGET=rocm | ||||||
|         GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete |         GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete | ||||||
|  | |||||||
| @ -75,9 +75,13 @@ case ${image} in | |||||||
|         DOCKERFILE_SUFFIX="_cuda_aarch64" |         DOCKERFILE_SUFFIX="_cuda_aarch64" | ||||||
|         ;; |         ;; | ||||||
|     manylinux2_28-builder:rocm*) |     manylinux2_28-builder:rocm*) | ||||||
|  |         # we want the patch version of 7.0 instead | ||||||
|  |         if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then | ||||||
|  |             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||||
|  |         fi | ||||||
|         # we want the patch version of 6.4 instead |         # we want the patch version of 6.4 instead | ||||||
|         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then |         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then | ||||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" |             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4" | ||||||
|         fi |         fi | ||||||
|         TARGET=rocm_final |         TARGET=rocm_final | ||||||
|         MANY_LINUX_VERSION="2_28" |         MANY_LINUX_VERSION="2_28" | ||||||
|  | |||||||
| @ -103,6 +103,11 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt | |||||||
| RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi | RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi | ||||||
| RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt | RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt | ||||||
|  |  | ||||||
|  | ARG INSTALL_MINGW | ||||||
|  | COPY ./common/install_mingw.sh install_mingw.sh | ||||||
|  | RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi | ||||||
|  | RUN rm install_mingw.sh | ||||||
|  |  | ||||||
| ARG TRITON | ARG TRITON | ||||||
| ARG TRITON_CPU | ARG TRITON_CPU | ||||||
|  |  | ||||||
|  | |||||||
| @ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules | |||||||
|         logger.info("Successfully cloned %s", target) |         logger.info("Successfully cloned %s", target) | ||||||
|         return r, commit |         return r, commit | ||||||
|  |  | ||||||
|     except GitCommandError as e: |     except GitCommandError: | ||||||
|         logger.error("Git operation failed: %s", e) |         logger.exception("Git operation failed") | ||||||
|         raise |         raise | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -485,6 +485,22 @@ test_inductor_aoti() { | |||||||
|   /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile |   /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile | ||||||
| } | } | ||||||
|  |  | ||||||
|  | test_inductor_aoti_cross_compile_for_windows() { | ||||||
|  |  | ||||||
|  |   TEST_REPORTS_DIR=$(pwd)/test/test-reports | ||||||
|  |   mkdir -p "$TEST_REPORTS_DIR" | ||||||
|  |  | ||||||
|  |   # Set WINDOWS_CUDA_HOME environment variable | ||||||
|  |   WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted" | ||||||
|  |   export WINDOWS_CUDA_HOME | ||||||
|  |  | ||||||
|  |   echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME" | ||||||
|  |   echo "Contents:" | ||||||
|  |   ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true | ||||||
|  |  | ||||||
|  |   python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib" | ||||||
|  | } | ||||||
|  |  | ||||||
| test_inductor_cpp_wrapper_shard() { | test_inductor_cpp_wrapper_shard() { | ||||||
|   if [[ -z "$NUM_TEST_SHARDS" ]]; then |   if [[ -z "$NUM_TEST_SHARDS" ]]; then | ||||||
|     echo "NUM_TEST_SHARDS must be defined to run a Python test shard" |     echo "NUM_TEST_SHARDS must be defined to run a Python test shard" | ||||||
| @ -900,7 +916,7 @@ test_inductor_set_cpu_affinity(){ | |||||||
|   export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD" |   export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD" | ||||||
|   export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1" |   export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1" | ||||||
|  |  | ||||||
|   if [[ "${TEST_CONFIG}" != *aarch64* ]]; then |   if [[ "$(uname -m)" != "aarch64" ]]; then | ||||||
|     # Use Intel OpenMP for x86 |     # Use Intel OpenMP for x86 | ||||||
|     IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so" |     IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so" | ||||||
|     export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD" |     export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD" | ||||||
| @ -914,7 +930,7 @@ test_inductor_set_cpu_affinity(){ | |||||||
|   cores=$((cpus / thread_per_core)) |   cores=$((cpus / thread_per_core)) | ||||||
|  |  | ||||||
|   # Set number of cores to 16 on aarch64 for performance runs |   # Set number of cores to 16 on aarch64 for performance runs | ||||||
|   if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then |   if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then | ||||||
|     cores=16 |     cores=16 | ||||||
|   fi |   fi | ||||||
|   export OMP_NUM_THREADS=$cores |   export OMP_NUM_THREADS=$cores | ||||||
| @ -1615,6 +1631,7 @@ test_operator_benchmark() { | |||||||
|   TEST_REPORTS_DIR=$(pwd)/test/test-reports |   TEST_REPORTS_DIR=$(pwd)/test/test-reports | ||||||
|   mkdir -p "$TEST_REPORTS_DIR" |   mkdir -p "$TEST_REPORTS_DIR" | ||||||
|   TEST_DIR=$(pwd) |   TEST_DIR=$(pwd) | ||||||
|  |   ARCH=$(uname -m) | ||||||
|  |  | ||||||
|   test_inductor_set_cpu_affinity |   test_inductor_set_cpu_affinity | ||||||
|  |  | ||||||
| @ -1629,7 +1646,7 @@ test_operator_benchmark() { | |||||||
|   pip_install pandas |   pip_install pandas | ||||||
|   python check_perf_csv.py \ |   python check_perf_csv.py \ | ||||||
|       --actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \ |       --actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \ | ||||||
|       --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv" |       --expected "${ARCH}_expected_ci_operator_benchmark_eager_float32_cpu.csv" | ||||||
| } | } | ||||||
|  |  | ||||||
| test_operator_microbenchmark() { | test_operator_microbenchmark() { | ||||||
| @ -1666,7 +1683,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then | |||||||
|     python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 |     python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 | ||||||
|   fi |   fi | ||||||
|   python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py |   python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py | ||||||
| elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then | elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then | ||||||
|   test_linux_aarch64 |   test_linux_aarch64 | ||||||
| elif [[ "${TEST_CONFIG}" == *backward* ]]; then | elif [[ "${TEST_CONFIG}" == *backward* ]]; then | ||||||
|   test_forward_backward_compatibility |   test_forward_backward_compatibility | ||||||
| @ -1717,6 +1734,8 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then | |||||||
|   test_inductor_triton_cpu |   test_inductor_triton_cpu | ||||||
| elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then | elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then | ||||||
|   test_inductor_micro_benchmark |   test_inductor_micro_benchmark | ||||||
|  | elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then | ||||||
|  |   test_inductor_aoti_cross_compile_for_windows | ||||||
| elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then | elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then | ||||||
|   install_torchvision |   install_torchvision | ||||||
|   id=$((SHARD_NUMBER-1)) |   id=$((SHARD_NUMBER-1)) | ||||||
|  | |||||||
							
								
								
									
										6
									
								
								.flake8
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								.flake8
									
									
									
									
									
								
							| @ -7,16 +7,12 @@ max-line-length = 120 | |||||||
| # C408 ignored because we like the dict keyword argument syntax | # C408 ignored because we like the dict keyword argument syntax | ||||||
| # E501 is not flexible enough, we're using B950 instead | # E501 is not flexible enough, we're using B950 instead | ||||||
| ignore = | ignore = | ||||||
|     E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824, |     E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824, | ||||||
|     # shebang has extra meaning in fbcode lints, so I think it's not worth trying |     # shebang has extra meaning in fbcode lints, so I think it's not worth trying | ||||||
|     # to line this up with executable bit |     # to line this up with executable bit | ||||||
|     EXE001, |     EXE001, | ||||||
|     # these ignores are from flake8-bugbear; please fix! |     # these ignores are from flake8-bugbear; please fix! | ||||||
|     B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910 |     B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910 | ||||||
|     # these ignores are from flake8-comprehensions; please fix! |  | ||||||
|     C407, |  | ||||||
|     # these ignores are from flake8-logging-format; please fix! |  | ||||||
|     G100,G101,G200 |  | ||||||
|     # these ignores are from flake8-simplify. please fix or ignore with commented reason |     # these ignores are from flake8-simplify. please fix or ignore with commented reason | ||||||
|     SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, |     SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, | ||||||
|     # SIM104 is already covered by pyupgrade ruff |     # SIM104 is already covered by pyupgrade ruff | ||||||
|  | |||||||
							
								
								
									
										11
									
								
								.github/actionlint.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/actionlint.yaml
									
									
									
									
										vendored
									
									
								
							| @ -54,12 +54,17 @@ self-hosted-runner: | |||||||
|     - windows-11-arm64 |     - windows-11-arm64 | ||||||
|     - windows-11-arm64-preview |     - windows-11-arm64-preview | ||||||
|     # Organization-wide AMD-hosted runners |     # Organization-wide AMD-hosted runners | ||||||
|     # MI2xx runners |     # MI2xx non-ARC runners | ||||||
|     - linux.rocm.gpu |     - linux.rocm.gpu | ||||||
|     - linux.rocm.gpu.mi250 |  | ||||||
|     - linux.rocm.gpu.2 |     - linux.rocm.gpu.2 | ||||||
|     - linux.rocm.gpu.4 |     - linux.rocm.gpu.4 | ||||||
|     # gfx942 runners |     - linux.rocm.gpu.mi250 | ||||||
|  |     - linux.rocm.gpu.gfx1100 | ||||||
|  |     # MI2xx ARC runners | ||||||
|  |     - linux.rocm.gpu.mi250.1 | ||||||
|  |     - linux.rocm.gpu.mi250.2 | ||||||
|  |     - linux.rocm.gpu.mi250.4 | ||||||
|  |     # gfx942 ARC runners | ||||||
|     - linux.rocm.gpu.gfx942.1 |     - linux.rocm.gpu.gfx942.1 | ||||||
|     - linux.rocm.gpu.gfx942.2 |     - linux.rocm.gpu.gfx942.2 | ||||||
|     - linux.rocm.gpu.gfx942.4 |     - linux.rocm.gpu.gfx942.4 | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/audio.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/audio.txt
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | |||||||
| 1b013f5b5a87a1882eb143c26d79d091150d6a37 | 69bbe7363897764f9e758d851cd0340147d27f94 | ||||||
|  | |||||||
							
								
								
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							| @ -133,3 +133,32 @@ | |||||||
|  |  | ||||||
| "ciflow/vllm": | "ciflow/vllm": | ||||||
| - .github/ci_commit_pins/vllm.txt | - .github/ci_commit_pins/vllm.txt | ||||||
|  |  | ||||||
|  | "ciflow/b200": | ||||||
|  | - test/test_matmul_cuda.py | ||||||
|  | - test/test_scaled_matmul_cuda.py | ||||||
|  | - test/inductor/test_fp8.py | ||||||
|  | - aten/src/ATen/native/cuda/Blas.cpp | ||||||
|  | - torch/**/*cublas* | ||||||
|  | - torch/_inductor/kernel/mm.py | ||||||
|  | - test/inductor/test_max_autotune.py | ||||||
|  | - third_party/fbgemm | ||||||
|  |  | ||||||
|  | "ciflow/h100": | ||||||
|  | - test/test_matmul_cuda.py | ||||||
|  | - test/test_scaled_matmul_cuda.py | ||||||
|  | - test/inductor/test_fp8.py | ||||||
|  | - aten/src/ATen/native/cuda/Blas.cpp | ||||||
|  | - torch/**/*cublas* | ||||||
|  | - torch/_inductor/kernel/mm.py | ||||||
|  | - test/inductor/test_max_autotune.py | ||||||
|  | - third_party/fbgemm | ||||||
|  |  | ||||||
|  | "ciflow/rocm": | ||||||
|  | - test/test_matmul_cuda.py | ||||||
|  | - test/test_scaled_matmul_cuda.py | ||||||
|  | - test/inductor/test_fp8.py | ||||||
|  | - aten/src/ATen/native/cuda/Blas.cpp | ||||||
|  | - torch/_inductor/kernel/mm.py | ||||||
|  | - test/inductor/test_max_autotune.py | ||||||
|  | - third_party/fbgemm | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							| @ -33,6 +33,7 @@ ciflow_push_tags: | |||||||
| - ciflow/rocm | - ciflow/rocm | ||||||
| - ciflow/rocm-mi300 | - ciflow/rocm-mi300 | ||||||
| - ciflow/rocm-mi355 | - ciflow/rocm-mi355 | ||||||
|  | - ciflow/rocm-navi31 | ||||||
| - ciflow/s390 | - ciflow/s390 | ||||||
| - ciflow/slow | - ciflow/slow | ||||||
| - ciflow/torchbench | - ciflow/torchbench | ||||||
|  | |||||||
							
								
								
									
										42
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										42
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							| @ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { | |||||||
|         "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'" |         "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'" | ||||||
|     ), |     ), | ||||||
|     "12.9": ( |     "12.9": ( | ||||||
|         "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | " | ||||||
|         "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | " | ||||||
|         "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | " | ||||||
|         "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | " | ||||||
|         "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | " | ||||||
|         "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | " | ||||||
|         "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | " | ||||||
|         "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | " | ||||||
|         "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | " | ||||||
|         "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | " | ||||||
|         "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | " | ||||||
|         "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | " | ||||||
|         "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | " | ||||||
|         "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " |         "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | " | ||||||
|         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'" |         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" | ||||||
|     ), |     ), | ||||||
|     "13.0": ( |     "13.0": ( | ||||||
|         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | " |         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | " | ||||||
| @ -241,7 +241,11 @@ def generate_libtorch_matrix( | |||||||
|             arches += CUDA_ARCHES |             arches += CUDA_ARCHES | ||||||
|             arches += ROCM_ARCHES |             arches += ROCM_ARCHES | ||||||
|         elif os == "windows": |         elif os == "windows": | ||||||
|             arches += CUDA_ARCHES |             # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up | ||||||
|  |             # in 2.10 | ||||||
|  |             windows_cuda_arches = CUDA_ARCHES.copy() | ||||||
|  |             windows_cuda_arches.remove("12.9") | ||||||
|  |             arches += windows_cuda_arches | ||||||
|     if libtorch_variants is None: |     if libtorch_variants is None: | ||||||
|         libtorch_variants = [ |         libtorch_variants = [ | ||||||
|             "shared-with-deps", |             "shared-with-deps", | ||||||
| @ -305,7 +309,11 @@ def generate_wheels_matrix( | |||||||
|         if os == "linux": |         if os == "linux": | ||||||
|             arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES |             arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES | ||||||
|         elif os == "windows": |         elif os == "windows": | ||||||
|             arches += CUDA_ARCHES + XPU_ARCHES |             # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up | ||||||
|  |             # in 2.10 | ||||||
|  |             windows_cuda_arches = CUDA_ARCHES.copy() | ||||||
|  |             windows_cuda_arches.remove("12.9") | ||||||
|  |             arches += windows_cuda_arches + XPU_ARCHES | ||||||
|         elif os == "linux-aarch64": |         elif os == "linux-aarch64": | ||||||
|             # Separate new if as the CPU type is different and |             # Separate new if as the CPU type is different and | ||||||
|             # uses different build/test scripts |             # uses different build/test scripts | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/scripts/trymerge.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/scripts/trymerge.py
									
									
									
									
										vendored
									
									
								
							| @ -1092,7 +1092,7 @@ class GitHubPR: | |||||||
|         editor = node["editor"] |         editor = node["editor"] | ||||||
|         return GitHubComment( |         return GitHubComment( | ||||||
|             body_text=node["bodyText"], |             body_text=node["bodyText"], | ||||||
|             created_at=node["createdAt"] if "createdAt" in node else "", |             created_at=node.get("createdAt", ""), | ||||||
|             author_login=node["author"]["login"], |             author_login=node["author"]["login"], | ||||||
|             author_url=node["author"].get("url", None), |             author_url=node["author"].get("url", None), | ||||||
|             author_association=node["authorAssociation"], |             author_association=node["authorAssociation"], | ||||||
|  | |||||||
| @ -26,9 +26,8 @@ name: !{{ build_environment }} | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}" |           python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}" | ||||||
|           freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }} |           freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }} | ||||||
| {%- endmacro %} | {%- endmacro %} | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/_linux-build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/_linux-build.yml
									
									
									
									
										vendored
									
									
								
							| @ -37,7 +37,7 @@ on: | |||||||
|       runner: |       runner: | ||||||
|         required: false |         required: false | ||||||
|         type: string |         type: string | ||||||
|         default: "linux.2xlarge" |         default: "linux.c7i.2xlarge" | ||||||
|         description: | |         description: | | ||||||
|           Label of the runner this job should run on. |           Label of the runner this job should run on. | ||||||
|       test-matrix: |       test-matrix: | ||||||
|  | |||||||
							
								
								
									
										40
									
								
								.github/workflows/_linux-test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										40
									
								
								.github/workflows/_linux-test.yml
									
									
									
									
										vendored
									
									
								
							| @ -224,6 +224,46 @@ jobs: | |||||||
|         continue-on-error: true |         continue-on-error: true | ||||||
|         uses: ./.github/actions/download-td-artifacts |         uses: ./.github/actions/download-td-artifacts | ||||||
|  |  | ||||||
|  |       - name: Download Windows torch wheel for cross-compilation | ||||||
|  |         if: matrix.win_torch_wheel_artifact != '' | ||||||
|  |         uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0 | ||||||
|  |         with: | ||||||
|  |           name: ${{ matrix.win_torch_wheel_artifact }} | ||||||
|  |           path: win-torch-wheel | ||||||
|  |  | ||||||
|  |       - name: Extract Windows wheel and setup CUDA libraries | ||||||
|  |         if: matrix.win_torch_wheel_artifact != '' | ||||||
|  |         shell: bash | ||||||
|  |         run: | | ||||||
|  |           set -x | ||||||
|  |  | ||||||
|  |           # Find the wheel file | ||||||
|  |           WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1) | ||||||
|  |           if [ -z "$WHEEL_FILE" ]; then | ||||||
|  |             echo "Error: No wheel file found in win-torch-wheel directory" | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|  |           echo "Found wheel file: $WHEEL_FILE" | ||||||
|  |  | ||||||
|  |           # Unzip the wheel file | ||||||
|  |           unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted | ||||||
|  |           echo "Extracted wheel contents" | ||||||
|  |  | ||||||
|  |           # Setup CUDA libraries (cuda.lib and cudart.lib) directory | ||||||
|  |           mkdir -p win-torch-wheel-extracted/lib/x64 | ||||||
|  |           if [ -f "win-torch-wheel/cuda.lib" ]; then | ||||||
|  |             mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/ | ||||||
|  |             echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/" | ||||||
|  |           fi | ||||||
|  |           if [ -f "win-torch-wheel/cudart.lib" ]; then | ||||||
|  |             mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/ | ||||||
|  |             echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/" | ||||||
|  |           fi | ||||||
|  |  | ||||||
|  |           # Verify CUDA libraries are present | ||||||
|  |           echo "CUDA libraries:" | ||||||
|  |           ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found" | ||||||
|  |  | ||||||
|       - name: Parse ref |       - name: Parse ref | ||||||
|         id: parse-ref |         id: parse-ref | ||||||
|         run: .github/scripts/parse_ref.py |         run: .github/scripts/parse_ref.py | ||||||
|  | |||||||
							
								
								
									
										25
									
								
								.github/workflows/_win-build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										25
									
								
								.github/workflows/_win-build.yml
									
									
									
									
										vendored
									
									
								
							| @ -168,6 +168,31 @@ jobs: | |||||||
|         run: | |         run: | | ||||||
|           .ci/pytorch/win-build.sh |           .ci/pytorch/win-build.sh | ||||||
|  |  | ||||||
|  |       # Collect Windows torch libs and CUDA libs for cross-compilation | ||||||
|  |       - name: Collect Windows CUDA libs for cross-compilation | ||||||
|  |         if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu' | ||||||
|  |         shell: bash | ||||||
|  |         run: | | ||||||
|  |           set -ex | ||||||
|  |  | ||||||
|  |           # Create directory structure if does not exist | ||||||
|  |           mkdir -p /c/${{ github.run_id }}/build-results | ||||||
|  |  | ||||||
|  |           # Copy CUDA libs | ||||||
|  |           CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}" | ||||||
|  |  | ||||||
|  |           if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then | ||||||
|  |             cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/ | ||||||
|  |           fi | ||||||
|  |  | ||||||
|  |           if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then | ||||||
|  |             cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/ | ||||||
|  |           fi | ||||||
|  |  | ||||||
|  |           # List collected files | ||||||
|  |           echo "Collected CUDA libs:" | ||||||
|  |           ls -lah /c/${{ github.run_id }}/build-results/*.lib | ||||||
|  |  | ||||||
|       # Upload to github so that people can click and download artifacts |       # Upload to github so that people can click and download artifacts | ||||||
|       - name: Upload artifacts to s3 |       - name: Upload artifacts to s3 | ||||||
|         if: steps.build.outcome != 'skipped' |         if: steps.build.outcome != 'skipped' | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -224,7 +224,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_10-cuda-aarch64-12_9 |       build_name: manywheel-py3_10-cuda-aarch64-12_9 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -473,7 +473,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_11-cuda-aarch64-12_9 |       build_name: manywheel-py3_11-cuda-aarch64-12_9 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -722,7 +722,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_12-cuda-aarch64-12_9 |       build_name: manywheel-py3_12-cuda-aarch64-12_9 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -971,7 +971,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_13-cuda-aarch64-12_9 |       build_name: manywheel-py3_13-cuda-aarch64-12_9 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -1220,7 +1220,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_13t-cuda-aarch64-12_9 |       build_name: manywheel-py3_13t-cuda-aarch64-12_9 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -1469,7 +1469,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_14-cuda-aarch64-12_9 |       build_name: manywheel-py3_14-cuda-aarch64-12_9 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
| @ -1718,7 +1718,7 @@ jobs: | |||||||
|       ALPINE_IMAGE: "arm64v8/alpine" |       ALPINE_IMAGE: "arm64v8/alpine" | ||||||
|       build_name: manywheel-py3_14t-cuda-aarch64-12_9 |       build_name: manywheel-py3_14t-cuda-aarch64-12_9 | ||||||
|       build_environment: linux-aarch64-binary-manywheel |       build_environment: linux-aarch64-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|       timeout-minutes: 420 |       timeout-minutes: 420 | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -259,7 +259,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_10-cuda12_9 |       build_name: manywheel-py3_10-cuda12_9 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_10-cuda12_9-test:  # Testing |   manywheel-py3_10-cuda12_9-test:  # Testing | ||||||
| @ -925,7 +925,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_11-cuda12_9 |       build_name: manywheel-py3_11-cuda12_9 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_11-cuda12_9-test:  # Testing |   manywheel-py3_11-cuda12_9-test:  # Testing | ||||||
| @ -1591,7 +1591,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_12-cuda12_9 |       build_name: manywheel-py3_12-cuda12_9 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_12-cuda12_9-test:  # Testing |   manywheel-py3_12-cuda12_9-test:  # Testing | ||||||
| @ -2257,7 +2257,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_13-cuda12_9 |       build_name: manywheel-py3_13-cuda12_9 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_13-cuda12_9-test:  # Testing |   manywheel-py3_13-cuda12_9-test:  # Testing | ||||||
| @ -2923,7 +2923,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_13t-cuda12_9 |       build_name: manywheel-py3_13t-cuda12_9 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_13t-cuda12_9-test:  # Testing |   manywheel-py3_13t-cuda12_9-test:  # Testing | ||||||
| @ -3589,7 +3589,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_14-cuda12_9 |       build_name: manywheel-py3_14-cuda12_9 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_14-cuda12_9-test:  # Testing |   manywheel-py3_14-cuda12_9-test:  # Testing | ||||||
| @ -4255,7 +4255,7 @@ jobs: | |||||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|       build_name: manywheel-py3_14t-cuda12_9 |       build_name: manywheel-py3_14t-cuda12_9 | ||||||
|       build_environment: linux-binary-manywheel |       build_environment: linux-binary-manywheel | ||||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' |       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|   manywheel-py3_14t-cuda12_9-test:  # Testing |   manywheel-py3_14t-cuda12_9-test:  # Testing | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -63,7 +63,6 @@ jobs: | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "3.10.4" |           python-version: "3.10.4" | ||||||
|           freethreaded: false |           freethreaded: false | ||||||
|  | |||||||
							
								
								
									
										11
									
								
								.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -59,7 +59,6 @@ jobs: | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "3.10.4" |           python-version: "3.10.4" | ||||||
|           freethreaded: false |           freethreaded: false | ||||||
| @ -169,7 +168,6 @@ jobs: | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "3.11.4" |           python-version: "3.11.4" | ||||||
|           freethreaded: false |           freethreaded: false | ||||||
| @ -279,7 +277,6 @@ jobs: | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "3.12.4" |           python-version: "3.12.4" | ||||||
|           freethreaded: false |           freethreaded: false | ||||||
| @ -389,7 +386,6 @@ jobs: | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "3.13.4" |           python-version: "3.13.4" | ||||||
|           freethreaded: false |           freethreaded: false | ||||||
| @ -499,7 +495,6 @@ jobs: | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "3.13.4" |           python-version: "3.13.4" | ||||||
|           freethreaded: true |           freethreaded: true | ||||||
| @ -609,9 +604,8 @@ jobs: | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "3.14.0-rc.2" |           python-version: "3.14.0" | ||||||
|           freethreaded: false |           freethreaded: false | ||||||
|       - name: Checkout PyTorch |       - name: Checkout PyTorch | ||||||
|         uses: actions/checkout@v4 |         uses: actions/checkout@v4 | ||||||
| @ -719,9 +713,8 @@ jobs: | |||||||
|       - name: Setup Python |       - name: Setup Python | ||||||
|         uses: actions/setup-python@v6 |         uses: actions/setup-python@v6 | ||||||
|         with: |         with: | ||||||
|           # TODO: Removeme once 3.14 is out |  | ||||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 |           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||||
|           python-version: "3.14.0-rc.2" |           python-version: "3.14.0" | ||||||
|           freethreaded: true |           freethreaded: true | ||||||
|       - name: Checkout PyTorch |       - name: Checkout PyTorch | ||||||
|         uses: actions/checkout@v4 |         uses: actions/checkout@v4 | ||||||
|  | |||||||
							
								
								
									
										250
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										250
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -788,256 +788,6 @@ jobs: | |||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|     uses: ./.github/workflows/_binary-upload.yml |     uses: ./.github/workflows/_binary-upload.yml | ||||||
|   libtorch-cuda12_9-shared-with-deps-debug-build: |  | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |  | ||||||
|     needs: get-label-type |  | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |  | ||||||
|     timeout-minutes: 360 |  | ||||||
|     env: |  | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |  | ||||||
|       PACKAGE_TYPE: libtorch |  | ||||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in |  | ||||||
|       #       favor of GPU_ARCH_VERSION |  | ||||||
|       DESIRED_CUDA: cu129 |  | ||||||
|       GPU_ARCH_VERSION: "12.9" |  | ||||||
|       GPU_ARCH_TYPE: cuda |  | ||||||
|       SKIP_ALL_TESTS: 1 |  | ||||||
|       LIBTORCH_CONFIG: debug |  | ||||||
|       LIBTORCH_VARIANT: shared-with-deps |  | ||||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts |  | ||||||
|       # without this value pip does not get installed for some reason |  | ||||||
|       DESIRED_PYTHON: "3.10" |  | ||||||
|     steps: |  | ||||||
|       # NOTE: These environment variables are put here so that they can be applied on every job equally |  | ||||||
|       #       They are also here because setting them at a workflow level doesn't give us access to the |  | ||||||
|       #       runner.temp variable, which we need. |  | ||||||
|       - name: Populate binary env |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" |  | ||||||
|           echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" |  | ||||||
|           echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" |  | ||||||
|       - name: Display EC2 information |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           set -euo pipefail |  | ||||||
|           function get_ec2_metadata() { |  | ||||||
|             # Pulled from instance metadata endpoint for EC2 |  | ||||||
|             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html |  | ||||||
|             category=$1 |  | ||||||
|             curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" |  | ||||||
|           } |  | ||||||
|           echo "ami-id: $(get_ec2_metadata ami-id)" |  | ||||||
|           echo "instance-id: $(get_ec2_metadata instance-id)" |  | ||||||
|           echo "instance-type: $(get_ec2_metadata instance-type)" |  | ||||||
|           echo "system info $(uname -a)" |  | ||||||
|       - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" |  | ||||||
|         uses: pytorch/test-infra/.github/actions/setup-ssh@main |  | ||||||
|         continue-on-error: true |  | ||||||
|         with: |  | ||||||
|           github-secret: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|       - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           git config --global core.longpaths true |  | ||||||
|           git config --global core.symlinks true |  | ||||||
|  |  | ||||||
|           # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock |  | ||||||
|           # the directory on Windows and prevent GHA from checking out as reported |  | ||||||
|           # in https://github.com/actions/checkout/issues/1018 |  | ||||||
|           git config --global core.fsmonitor false |  | ||||||
|       # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 |  | ||||||
|       - name: Enable long paths on Windows |  | ||||||
|         shell: powershell |  | ||||||
|         run: | |  | ||||||
|           Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 |  | ||||||
|       # Since it's just a defensive command, the workflow should continue even the command fails. This step can be |  | ||||||
|       # removed once Windows Defender is removed from the AMI |  | ||||||
|       - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch |  | ||||||
|         continue-on-error: true |  | ||||||
|         shell: powershell |  | ||||||
|         run: | |  | ||||||
|           Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore |  | ||||||
|           # Let's both exclude the path and disable Windows Defender completely just to be sure |  | ||||||
|           # that it doesn't interfere |  | ||||||
|           Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore |  | ||||||
|       - name: Checkout PyTorch |  | ||||||
|         uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} |  | ||||||
|           submodules: recursive |  | ||||||
|           path: pytorch |  | ||||||
|           show-progress: false |  | ||||||
|       - name: Clean PyTorch checkout |  | ||||||
|         run: | |  | ||||||
|           # Remove any artifacts from the previous checkouts |  | ||||||
|           git clean -fxd |  | ||||||
|         working-directory: pytorch |  | ||||||
|       - name: Populate binary env |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" |  | ||||||
|       - name: Build PyTorch binary |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" |  | ||||||
|       - uses: actions/upload-artifact@v4.4.0 |  | ||||||
|         if: always() |  | ||||||
|         with: |  | ||||||
|           name: libtorch-cuda12_9-shared-with-deps-debug |  | ||||||
|           retention-days: 14 |  | ||||||
|           if-no-files-found: error |  | ||||||
|           path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" |  | ||||||
|       - name: Wait until all sessions have drained |  | ||||||
|         shell: powershell |  | ||||||
|         working-directory: pytorch |  | ||||||
|         if: always() |  | ||||||
|         timeout-minutes: 120 |  | ||||||
|         run: | |  | ||||||
|           .github\scripts\wait_for_ssh_to_drain.ps1 |  | ||||||
|       - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) |  | ||||||
|         shell: powershell |  | ||||||
|         working-directory: pytorch |  | ||||||
|         if: always() |  | ||||||
|         run: | |  | ||||||
|           .github\scripts\kill_active_ssh_sessions.ps1 |  | ||||||
|  |  | ||||||
|   libtorch-cuda12_9-shared-with-deps-debug-test:  # Testing |  | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |  | ||||||
|     needs: |  | ||||||
|       - libtorch-cuda12_9-shared-with-deps-debug-build |  | ||||||
|       - get-label-type |  | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" |  | ||||||
|     timeout-minutes: 360 |  | ||||||
|     env: |  | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |  | ||||||
|       PACKAGE_TYPE: libtorch |  | ||||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in |  | ||||||
|       #       favor of GPU_ARCH_VERSION |  | ||||||
|       DESIRED_CUDA: cu129 |  | ||||||
|       GPU_ARCH_VERSION: "12.9" |  | ||||||
|       GPU_ARCH_TYPE: cuda |  | ||||||
|       SKIP_ALL_TESTS: 1 |  | ||||||
|       LIBTORCH_CONFIG: debug |  | ||||||
|       LIBTORCH_VARIANT: shared-with-deps |  | ||||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts |  | ||||||
|       # without this value pip does not get installed for some reason |  | ||||||
|       DESIRED_PYTHON: "3.10" |  | ||||||
|     steps: |  | ||||||
|       - name: Display EC2 information |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           set -euo pipefail |  | ||||||
|           function get_ec2_metadata() { |  | ||||||
|             # Pulled from instance metadata endpoint for EC2 |  | ||||||
|             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html |  | ||||||
|             category=$1 |  | ||||||
|             curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" |  | ||||||
|           } |  | ||||||
|           echo "ami-id: $(get_ec2_metadata ami-id)" |  | ||||||
|           echo "instance-id: $(get_ec2_metadata instance-id)" |  | ||||||
|           echo "instance-type: $(get_ec2_metadata instance-type)" |  | ||||||
|           echo "system info $(uname -a)" |  | ||||||
|       - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" |  | ||||||
|         uses: pytorch/test-infra/.github/actions/setup-ssh@main |  | ||||||
|         continue-on-error: true |  | ||||||
|         with: |  | ||||||
|           github-secret: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|       - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           git config --global core.longpaths true |  | ||||||
|           git config --global core.symlinks true |  | ||||||
|  |  | ||||||
|           # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock |  | ||||||
|           # the directory on Windows and prevent GHA from checking out as reported |  | ||||||
|           # in https://github.com/actions/checkout/issues/1018 |  | ||||||
|           git config --global core.fsmonitor false |  | ||||||
|       # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 |  | ||||||
|       - name: Enable long paths on Windows |  | ||||||
|         shell: powershell |  | ||||||
|         run: | |  | ||||||
|           Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 |  | ||||||
|       # Since it's just a defensive command, the workflow should continue even the command fails. This step can be |  | ||||||
|       # removed once Windows Defender is removed from the AMI |  | ||||||
|       - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch |  | ||||||
|         continue-on-error: true |  | ||||||
|         shell: powershell |  | ||||||
|         run: | |  | ||||||
|           Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore |  | ||||||
|           # Let's both exclude the path and disable Windows Defender completely just to be sure |  | ||||||
|           # that it doesn't interfere |  | ||||||
|           Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore |  | ||||||
|       - name: Checkout PyTorch |  | ||||||
|         uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} |  | ||||||
|           submodules: recursive |  | ||||||
|           path: pytorch |  | ||||||
|           show-progress: false |  | ||||||
|       - name: Clean PyTorch checkout |  | ||||||
|         run: | |  | ||||||
|           # Remove any artifacts from the previous checkouts |  | ||||||
|           git clean -fxd |  | ||||||
|         working-directory: pytorch |  | ||||||
|       # NOTE: These environment variables are put here so that they can be applied on every job equally |  | ||||||
|       #       They are also here because setting them at a workflow level doesn't give us access to the |  | ||||||
|       #       runner.temp variable, which we need. |  | ||||||
|       - name: Populate binary env |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" |  | ||||||
|           echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" |  | ||||||
|           echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" |  | ||||||
|       - uses: actions/download-artifact@v4.1.7 |  | ||||||
|         name: Download Build Artifacts |  | ||||||
|         with: |  | ||||||
|           name: libtorch-cuda12_9-shared-with-deps-debug |  | ||||||
|           path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" |  | ||||||
|       - name: Populate binary env |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" |  | ||||||
|       - name: Test PyTorch binary |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" |  | ||||||
|       - name: Wait until all sessions have drained |  | ||||||
|         shell: powershell |  | ||||||
|         working-directory: pytorch |  | ||||||
|         if: always() |  | ||||||
|         timeout-minutes: 120 |  | ||||||
|         run: | |  | ||||||
|           .github\scripts\wait_for_ssh_to_drain.ps1 |  | ||||||
|       - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) |  | ||||||
|         shell: powershell |  | ||||||
|         working-directory: pytorch |  | ||||||
|         if: always() |  | ||||||
|         run: | |  | ||||||
|           .github\scripts\kill_active_ssh_sessions.ps1 |  | ||||||
|   libtorch-cuda12_9-shared-with-deps-debug-upload:  # Uploading |  | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |  | ||||||
|     permissions: |  | ||||||
|       id-token: write |  | ||||||
|       contents: read |  | ||||||
|     needs: libtorch-cuda12_9-shared-with-deps-debug-test |  | ||||||
|     with: |  | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |  | ||||||
|       PACKAGE_TYPE: libtorch |  | ||||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in |  | ||||||
|       #       favor of GPU_ARCH_VERSION |  | ||||||
|       DESIRED_CUDA: cu129 |  | ||||||
|       GPU_ARCH_VERSION: "12.9" |  | ||||||
|       GPU_ARCH_TYPE: cuda |  | ||||||
|       LIBTORCH_CONFIG: debug |  | ||||||
|       LIBTORCH_VARIANT: shared-with-deps |  | ||||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts |  | ||||||
|       # without this value pip does not get installed for some reason |  | ||||||
|       DESIRED_PYTHON: "3.10" |  | ||||||
|       build_name: libtorch-cuda12_9-shared-with-deps-debug |  | ||||||
|     secrets: |  | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|     uses: ./.github/workflows/_binary-upload.yml |  | ||||||
|   libtorch-cuda13_0-shared-with-deps-debug-build: |   libtorch-cuda13_0-shared-with-deps-debug-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|  | |||||||
							
								
								
									
										250
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										250
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -788,256 +788,6 @@ jobs: | |||||||
|     secrets: |     secrets: | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|     uses: ./.github/workflows/_binary-upload.yml |     uses: ./.github/workflows/_binary-upload.yml | ||||||
|   libtorch-cuda12_9-shared-with-deps-release-build: |  | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |  | ||||||
|     needs: get-label-type |  | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" |  | ||||||
|     timeout-minutes: 360 |  | ||||||
|     env: |  | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |  | ||||||
|       PACKAGE_TYPE: libtorch |  | ||||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in |  | ||||||
|       #       favor of GPU_ARCH_VERSION |  | ||||||
|       DESIRED_CUDA: cu129 |  | ||||||
|       GPU_ARCH_VERSION: "12.9" |  | ||||||
|       GPU_ARCH_TYPE: cuda |  | ||||||
|       SKIP_ALL_TESTS: 1 |  | ||||||
|       LIBTORCH_CONFIG: release |  | ||||||
|       LIBTORCH_VARIANT: shared-with-deps |  | ||||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts |  | ||||||
|       # without this value pip does not get installed for some reason |  | ||||||
|       DESIRED_PYTHON: "3.10" |  | ||||||
|     steps: |  | ||||||
|       # NOTE: These environment variables are put here so that they can be applied on every job equally |  | ||||||
|       #       They are also here because setting them at a workflow level doesn't give us access to the |  | ||||||
|       #       runner.temp variable, which we need. |  | ||||||
|       - name: Populate binary env |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" |  | ||||||
|           echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" |  | ||||||
|           echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" |  | ||||||
|       - name: Display EC2 information |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           set -euo pipefail |  | ||||||
|           function get_ec2_metadata() { |  | ||||||
|             # Pulled from instance metadata endpoint for EC2 |  | ||||||
|             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html |  | ||||||
|             category=$1 |  | ||||||
|             curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" |  | ||||||
|           } |  | ||||||
|           echo "ami-id: $(get_ec2_metadata ami-id)" |  | ||||||
|           echo "instance-id: $(get_ec2_metadata instance-id)" |  | ||||||
|           echo "instance-type: $(get_ec2_metadata instance-type)" |  | ||||||
|           echo "system info $(uname -a)" |  | ||||||
|       - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" |  | ||||||
|         uses: pytorch/test-infra/.github/actions/setup-ssh@main |  | ||||||
|         continue-on-error: true |  | ||||||
|         with: |  | ||||||
|           github-secret: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|       - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           git config --global core.longpaths true |  | ||||||
|           git config --global core.symlinks true |  | ||||||
|  |  | ||||||
|           # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock |  | ||||||
|           # the directory on Windows and prevent GHA from checking out as reported |  | ||||||
|           # in https://github.com/actions/checkout/issues/1018 |  | ||||||
|           git config --global core.fsmonitor false |  | ||||||
|       # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 |  | ||||||
|       - name: Enable long paths on Windows |  | ||||||
|         shell: powershell |  | ||||||
|         run: | |  | ||||||
|           Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 |  | ||||||
|       # Since it's just a defensive command, the workflow should continue even the command fails. This step can be |  | ||||||
|       # removed once Windows Defender is removed from the AMI |  | ||||||
|       - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch |  | ||||||
|         continue-on-error: true |  | ||||||
|         shell: powershell |  | ||||||
|         run: | |  | ||||||
|           Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore |  | ||||||
|           # Let's both exclude the path and disable Windows Defender completely just to be sure |  | ||||||
|           # that it doesn't interfere |  | ||||||
|           Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore |  | ||||||
|       - name: Checkout PyTorch |  | ||||||
|         uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} |  | ||||||
|           submodules: recursive |  | ||||||
|           path: pytorch |  | ||||||
|           show-progress: false |  | ||||||
|       - name: Clean PyTorch checkout |  | ||||||
|         run: | |  | ||||||
|           # Remove any artifacts from the previous checkouts |  | ||||||
|           git clean -fxd |  | ||||||
|         working-directory: pytorch |  | ||||||
|       - name: Populate binary env |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" |  | ||||||
|       - name: Build PyTorch binary |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" |  | ||||||
|       - uses: actions/upload-artifact@v4.4.0 |  | ||||||
|         if: always() |  | ||||||
|         with: |  | ||||||
|           name: libtorch-cuda12_9-shared-with-deps-release |  | ||||||
|           retention-days: 14 |  | ||||||
|           if-no-files-found: error |  | ||||||
|           path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" |  | ||||||
|       - name: Wait until all sessions have drained |  | ||||||
|         shell: powershell |  | ||||||
|         working-directory: pytorch |  | ||||||
|         if: always() |  | ||||||
|         timeout-minutes: 120 |  | ||||||
|         run: | |  | ||||||
|           .github\scripts\wait_for_ssh_to_drain.ps1 |  | ||||||
|       - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) |  | ||||||
|         shell: powershell |  | ||||||
|         working-directory: pytorch |  | ||||||
|         if: always() |  | ||||||
|         run: | |  | ||||||
|           .github\scripts\kill_active_ssh_sessions.ps1 |  | ||||||
|  |  | ||||||
|   libtorch-cuda12_9-shared-with-deps-release-test:  # Testing |  | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |  | ||||||
|     needs: |  | ||||||
|       - libtorch-cuda12_9-shared-with-deps-release-build |  | ||||||
|       - get-label-type |  | ||||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" |  | ||||||
|     timeout-minutes: 360 |  | ||||||
|     env: |  | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |  | ||||||
|       PACKAGE_TYPE: libtorch |  | ||||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in |  | ||||||
|       #       favor of GPU_ARCH_VERSION |  | ||||||
|       DESIRED_CUDA: cu129 |  | ||||||
|       GPU_ARCH_VERSION: "12.9" |  | ||||||
|       GPU_ARCH_TYPE: cuda |  | ||||||
|       SKIP_ALL_TESTS: 1 |  | ||||||
|       LIBTORCH_CONFIG: release |  | ||||||
|       LIBTORCH_VARIANT: shared-with-deps |  | ||||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts |  | ||||||
|       # without this value pip does not get installed for some reason |  | ||||||
|       DESIRED_PYTHON: "3.10" |  | ||||||
|     steps: |  | ||||||
|       - name: Display EC2 information |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           set -euo pipefail |  | ||||||
|           function get_ec2_metadata() { |  | ||||||
|             # Pulled from instance metadata endpoint for EC2 |  | ||||||
|             # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html |  | ||||||
|             category=$1 |  | ||||||
|             curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" |  | ||||||
|           } |  | ||||||
|           echo "ami-id: $(get_ec2_metadata ami-id)" |  | ||||||
|           echo "instance-id: $(get_ec2_metadata instance-id)" |  | ||||||
|           echo "instance-type: $(get_ec2_metadata instance-type)" |  | ||||||
|           echo "system info $(uname -a)" |  | ||||||
|       - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" |  | ||||||
|         uses: pytorch/test-infra/.github/actions/setup-ssh@main |  | ||||||
|         continue-on-error: true |  | ||||||
|         with: |  | ||||||
|           github-secret: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|       - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           git config --global core.longpaths true |  | ||||||
|           git config --global core.symlinks true |  | ||||||
|  |  | ||||||
|           # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock |  | ||||||
|           # the directory on Windows and prevent GHA from checking out as reported |  | ||||||
|           # in https://github.com/actions/checkout/issues/1018 |  | ||||||
|           git config --global core.fsmonitor false |  | ||||||
|       # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 |  | ||||||
|       - name: Enable long paths on Windows |  | ||||||
|         shell: powershell |  | ||||||
|         run: | |  | ||||||
|           Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 |  | ||||||
|       # Since it's just a defensive command, the workflow should continue even the command fails. This step can be |  | ||||||
|       # removed once Windows Defender is removed from the AMI |  | ||||||
|       - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch |  | ||||||
|         continue-on-error: true |  | ||||||
|         shell: powershell |  | ||||||
|         run: | |  | ||||||
|           Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore |  | ||||||
|           # Let's both exclude the path and disable Windows Defender completely just to be sure |  | ||||||
|           # that it doesn't interfere |  | ||||||
|           Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore |  | ||||||
|       - name: Checkout PyTorch |  | ||||||
|         uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} |  | ||||||
|           submodules: recursive |  | ||||||
|           path: pytorch |  | ||||||
|           show-progress: false |  | ||||||
|       - name: Clean PyTorch checkout |  | ||||||
|         run: | |  | ||||||
|           # Remove any artifacts from the previous checkouts |  | ||||||
|           git clean -fxd |  | ||||||
|         working-directory: pytorch |  | ||||||
|       # NOTE: These environment variables are put here so that they can be applied on every job equally |  | ||||||
|       #       They are also here because setting them at a workflow level doesn't give us access to the |  | ||||||
|       #       runner.temp variable, which we need. |  | ||||||
|       - name: Populate binary env |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" |  | ||||||
|           echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" |  | ||||||
|           echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" |  | ||||||
|       - uses: actions/download-artifact@v4.1.7 |  | ||||||
|         name: Download Build Artifacts |  | ||||||
|         with: |  | ||||||
|           name: libtorch-cuda12_9-shared-with-deps-release |  | ||||||
|           path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" |  | ||||||
|       - name: Populate binary env |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" |  | ||||||
|       - name: Test PyTorch binary |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" |  | ||||||
|       - name: Wait until all sessions have drained |  | ||||||
|         shell: powershell |  | ||||||
|         working-directory: pytorch |  | ||||||
|         if: always() |  | ||||||
|         timeout-minutes: 120 |  | ||||||
|         run: | |  | ||||||
|           .github\scripts\wait_for_ssh_to_drain.ps1 |  | ||||||
|       - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) |  | ||||||
|         shell: powershell |  | ||||||
|         working-directory: pytorch |  | ||||||
|         if: always() |  | ||||||
|         run: | |  | ||||||
|           .github\scripts\kill_active_ssh_sessions.ps1 |  | ||||||
|   libtorch-cuda12_9-shared-with-deps-release-upload:  # Uploading |  | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |  | ||||||
|     permissions: |  | ||||||
|       id-token: write |  | ||||||
|       contents: read |  | ||||||
|     needs: libtorch-cuda12_9-shared-with-deps-release-test |  | ||||||
|     with: |  | ||||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch |  | ||||||
|       PACKAGE_TYPE: libtorch |  | ||||||
|       # TODO: This is a legacy variable that we eventually want to get rid of in |  | ||||||
|       #       favor of GPU_ARCH_VERSION |  | ||||||
|       DESIRED_CUDA: cu129 |  | ||||||
|       GPU_ARCH_VERSION: "12.9" |  | ||||||
|       GPU_ARCH_TYPE: cuda |  | ||||||
|       LIBTORCH_CONFIG: release |  | ||||||
|       LIBTORCH_VARIANT: shared-with-deps |  | ||||||
|       # This is a dummy value for libtorch to work correctly with our batch scripts |  | ||||||
|       # without this value pip does not get installed for some reason |  | ||||||
|       DESIRED_PYTHON: "3.10" |  | ||||||
|       build_name: libtorch-cuda12_9-shared-with-deps-release |  | ||||||
|     secrets: |  | ||||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|     uses: ./.github/workflows/_binary-upload.yml |  | ||||||
|   libtorch-cuda13_0-shared-with-deps-release-build: |   libtorch-cuda13_0-shared-with-deps-release-build: | ||||||
|     if: ${{ github.repository_owner == 'pytorch' }} |     if: ${{ github.repository_owner == 'pytorch' }} | ||||||
|     needs: get-label-type |     needs: get-label-type | ||||||
|  | |||||||
							
								
								
									
										1666
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1666
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -88,27 +88,27 @@ jobs: | |||||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks |       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks | ||||||
|       test-matrix: | |       test-matrix: | | ||||||
|         { include: [ |         { include: [ | ||||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|         ]} |         ]} | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							| @ -118,9 +118,9 @@ jobs: | |||||||
|         CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" |         CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" | ||||||
|         echo "Running all other linters" |         echo "Running all other linters" | ||||||
|         if [ "$CHANGED_FILES" = '*' ]; then |         if [ "$CHANGED_FILES" = '*' ]; then | ||||||
|           ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh |           ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh | ||||||
|         else |         else | ||||||
|           ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh |           ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh | ||||||
|         fi |         fi | ||||||
|  |  | ||||||
|   quick-checks: |   quick-checks: | ||||||
|  | |||||||
							
								
								
									
										38
									
								
								.github/workflows/operator_benchmark.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										38
									
								
								.github/workflows/operator_benchmark.yml
									
									
									
									
										vendored
									
									
								
							| @ -30,9 +30,9 @@ permissions: | |||||||
|   contents: read |   contents: read | ||||||
|  |  | ||||||
| jobs: | jobs: | ||||||
|   opbenchmark-build: |   x86-opbenchmark-build: | ||||||
|     if: github.repository_owner == 'pytorch' |     if: github.repository_owner == 'pytorch' | ||||||
|     name: opbenchmark-build |     name: x86-opbenchmark-build | ||||||
|     uses: ./.github/workflows/_linux-build.yml |     uses: ./.github/workflows/_linux-build.yml | ||||||
|     with: |     with: | ||||||
|       build-environment: linux-jammy-py3.10-gcc11-build |       build-environment: linux-jammy-py3.10-gcc11-build | ||||||
| @ -43,12 +43,36 @@ jobs: | |||||||
|         ]} |         ]} | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  |  | ||||||
|   opbenchmark-test: |   x86-opbenchmark-test: | ||||||
|     name: opbenchmark-test |     name: x86-opbenchmark-test | ||||||
|     uses: ./.github/workflows/_linux-test.yml |     uses: ./.github/workflows/_linux-test.yml | ||||||
|     needs: opbenchmark-build |     needs: x86-opbenchmark-build | ||||||
|     with: |     with: | ||||||
|       build-environment: linux-jammy-py3.10-gcc11-build |       build-environment: linux-jammy-py3.10-gcc11-build | ||||||
|       docker-image: ${{ needs.opbenchmark-build.outputs.docker-image }} |       docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }} | ||||||
|       test-matrix: ${{ needs.opbenchmark-build.outputs.test-matrix }} |       test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }} | ||||||
|  |     secrets: inherit | ||||||
|  |  | ||||||
|  |   aarch64-opbenchmark-build: | ||||||
|  |     if: github.repository_owner == 'pytorch' | ||||||
|  |     name: aarch64-opbenchmark-build | ||||||
|  |     uses: ./.github/workflows/_linux-build.yml | ||||||
|  |     with: | ||||||
|  |       build-environment: linux-jammy-aarch64-py3.10 | ||||||
|  |       runner: linux.arm64.m7g.4xlarge | ||||||
|  |       docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 | ||||||
|  |       test-matrix: | | ||||||
|  |         { include: [ | ||||||
|  |           { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" }, | ||||||
|  |         ]} | ||||||
|  |     secrets: inherit | ||||||
|  |  | ||||||
|  |   aarch64-opbenchmark-test: | ||||||
|  |     name: aarch64-opbenchmark-test | ||||||
|  |     uses: ./.github/workflows/_linux-test.yml | ||||||
|  |     needs: aarch64-opbenchmark-build | ||||||
|  |     with: | ||||||
|  |       build-environment: linux-jammy-aarch64-py3.10 | ||||||
|  |       docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }} | ||||||
|  |       test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }} | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  | |||||||
							
								
								
									
										12
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							| @ -45,12 +45,12 @@ jobs: | |||||||
|       sync-tag: rocm-build |       sync-tag: rocm-build | ||||||
|       test-matrix: | |       test-matrix: | | ||||||
|         { include: [ |         { include: [ | ||||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|           { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, |           { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||||
|         ]} |         ]} | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										63
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,63 @@ | |||||||
|  | name: rocm-navi31 | ||||||
|  |  | ||||||
|  | on: | ||||||
|  |   push: | ||||||
|  |     tags: | ||||||
|  |       - ciflow/rocm-navi31/* | ||||||
|  |   workflow_dispatch: | ||||||
|  |   schedule: | ||||||
|  |     # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs. | ||||||
|  |     # Also run less frequently on weekends. | ||||||
|  |     - cron: 45 */2 * * 1-5 | ||||||
|  |     - cron: 45 4,12 * * 0,6 | ||||||
|  |  | ||||||
|  | concurrency: | ||||||
|  |   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} | ||||||
|  |   cancel-in-progress: true | ||||||
|  |  | ||||||
|  | permissions: read-all | ||||||
|  |  | ||||||
|  | jobs: | ||||||
|  |   target-determination: | ||||||
|  |     if: github.repository_owner == 'pytorch' | ||||||
|  |     name: before-test | ||||||
|  |     uses: ./.github/workflows/target_determination.yml | ||||||
|  |     permissions: | ||||||
|  |       id-token: write | ||||||
|  |       contents: read | ||||||
|  |  | ||||||
|  |   linux-jammy-rocm-py3_10-build: | ||||||
|  |     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||||
|  |     name: linux-jammy-rocm-py3.10 | ||||||
|  |     uses: ./.github/workflows/_linux-build.yml | ||||||
|  |     with: | ||||||
|  |       build-environment: linux-jammy-rocm-py3.10 | ||||||
|  |       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||||
|  |       sync-tag: rocm-build | ||||||
|  |       test-matrix: | | ||||||
|  |         { include: [ | ||||||
|  |           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||||
|  |           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||||
|  |         ]} | ||||||
|  |     secrets: inherit | ||||||
|  |  | ||||||
|  |   linux-jammy-rocm-py3_10-test: | ||||||
|  |     permissions: | ||||||
|  |       id-token: write | ||||||
|  |       contents: read | ||||||
|  |     name: linux-jammy-rocm-py3_10 | ||||||
|  |     uses: ./.github/workflows/_rocm-test.yml | ||||||
|  |     needs: | ||||||
|  |       - linux-jammy-rocm-py3_10-build | ||||||
|  |       - target-determination | ||||||
|  |     with: | ||||||
|  |       build-environment: linux-jammy-rocm-py3.10 | ||||||
|  |       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||||
|  |       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||||
|  |       tests-to-include: >- | ||||||
|  |          ${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs | ||||||
|  |          test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark | ||||||
|  |          inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor | ||||||
|  |          inductor/test_torchinductor inductor/test_decompose_mem_bound_mm | ||||||
|  |          inductor/test_flex_attention inductor/test_max_autotune' || '' }} | ||||||
|  |     secrets: inherit | ||||||
							
								
								
									
										38
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										38
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							| @ -36,12 +36,12 @@ jobs: | |||||||
|       sync-tag: rocm-build |       sync-tag: rocm-build | ||||||
|       test-matrix: | |       test-matrix: | | ||||||
|         { include: [ |         { include: [ | ||||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" }, |           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" }, | ||||||
|           { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" }, |           { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" }, | ||||||
|           { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" }, |           { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" }, | ||||||
|           { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" }, |           { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" }, | ||||||
|           { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" }, |           { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" }, | ||||||
|           { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" }, |           { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" }, | ||||||
|         ]} |         ]} | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  |  | ||||||
| @ -59,29 +59,3 @@ jobs: | |||||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} |       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||||
|       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} |       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  |  | ||||||
|   linux-jammy-rocm-py3_10-gfx1100-test: |  | ||||||
|     if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} |  | ||||||
|     permissions: |  | ||||||
|       id-token: write |  | ||||||
|       contents: read |  | ||||||
|     name: linux-jammy-rocm-py3_10-gfx1100 |  | ||||||
|     uses: ./.github/workflows/_rocm-test.yml |  | ||||||
|     needs: |  | ||||||
|       - linux-jammy-rocm-py3_10-build |  | ||||||
|       - target-determination |  | ||||||
|     with: |  | ||||||
|       build-environment: linux-jammy-rocm-py3.10 |  | ||||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} |  | ||||||
|       test-matrix: | |  | ||||||
|         { include: [ |  | ||||||
|           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, |  | ||||||
|           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, |  | ||||||
|         ]} |  | ||||||
|       tests-to-include: > |  | ||||||
|          test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs |  | ||||||
|          test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark |  | ||||||
|          inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor |  | ||||||
|          inductor/test_torchinductor inductor/test_decompose_mem_bound_mm |  | ||||||
|          inductor/test_flex_attention inductor/test_max_autotune |  | ||||||
|     secrets: inherit |  | ||||||
|  | |||||||
							
								
								
									
										51
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										51
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							| @ -190,6 +190,40 @@ jobs: | |||||||
|       runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" |       runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  |  | ||||||
|  |   linux-jammy-rocm-py3_10-build: | ||||||
|  |     if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} | ||||||
|  |     name: linux-jammy-rocm-py3.10 | ||||||
|  |     uses: ./.github/workflows/_linux-build.yml | ||||||
|  |     needs: get-label-type | ||||||
|  |     with: | ||||||
|  |       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||||
|  |       build-environment: linux-jammy-rocm-py3.10 | ||||||
|  |       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||||
|  |       sync-tag: rocm-build | ||||||
|  |       test-matrix: | | ||||||
|  |         { include: [ | ||||||
|  |           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||||
|  |           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||||
|  |         ]} | ||||||
|  |     secrets: inherit | ||||||
|  |  | ||||||
|  |   linux-jammy-rocm-py3_10-test: | ||||||
|  |     if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} | ||||||
|  |     permissions: | ||||||
|  |       id-token: write | ||||||
|  |       contents: read | ||||||
|  |     name: linux-jammy-rocm-py3.10 | ||||||
|  |     uses: ./.github/workflows/_rocm-test.yml | ||||||
|  |     needs: | ||||||
|  |       - linux-jammy-rocm-py3_10-build | ||||||
|  |       - target-determination | ||||||
|  |     with: | ||||||
|  |       build-environment: linux-jammy-rocm-py3.10 | ||||||
|  |       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||||
|  |       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||||
|  |       tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor" | ||||||
|  |     secrets: inherit | ||||||
|  |  | ||||||
|   inductor-build: |   inductor-build: | ||||||
|     name: inductor-build |     name: inductor-build | ||||||
|     uses: ./.github/workflows/_linux-build.yml |     uses: ./.github/workflows/_linux-build.yml | ||||||
| @ -200,6 +234,23 @@ jobs: | |||||||
|       cuda-arch-list: '8.0' |       cuda-arch-list: '8.0' | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|  |  | ||||||
|  |   # Test cross-compiled models with Windows libs extracted from wheel | ||||||
|  |   cross-compile-linux-test: | ||||||
|  |     name: cross-compile-linux-test | ||||||
|  |     uses: ./.github/workflows/_linux-test.yml | ||||||
|  |     needs: | ||||||
|  |       - linux-jammy-cuda12_8-py3_10-gcc11-build | ||||||
|  |       - get-label-type | ||||||
|  |       - win-vs2022-cuda12_8-py3-build | ||||||
|  |     with: | ||||||
|  |       build-environment: linux-jammy-cuda12.8-py3.10-gcc11 | ||||||
|  |       docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} | ||||||
|  |       test-matrix: | | ||||||
|  |         { include: [ | ||||||
|  |           { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" }, | ||||||
|  |         ]} | ||||||
|  |     secrets: inherit | ||||||
|  |  | ||||||
|   verify-cachebench-cpu-build: |   verify-cachebench-cpu-build: | ||||||
|     name: verify-cachebench-cpu-build |     name: verify-cachebench-cpu-build | ||||||
|     uses: ./.github/workflows/_linux-build.yml |     uses: ./.github/workflows/_linux-build.yml | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -374,6 +374,7 @@ third_party/ruy/ | |||||||
| third_party/glog/ | third_party/glog/ | ||||||
|  |  | ||||||
| # Virtualenv | # Virtualenv | ||||||
|  | .venv/ | ||||||
| venv/ | venv/ | ||||||
|  |  | ||||||
| # Log files | # Log files | ||||||
|  | |||||||
| @ -209,6 +209,46 @@ command = [ | |||||||
|     '@{{PATHSFILE}}' |     '@{{PATHSFILE}}' | ||||||
| ] | ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | [[linter]] | ||||||
|  | code = 'PYREFLY' | ||||||
|  | include_patterns = [ | ||||||
|  |     'torch/**/*.py', | ||||||
|  |     'torch/**/*.pyi', | ||||||
|  |     'torchgen/**/*.py', | ||||||
|  |     'torchgen/**/*.pyi', | ||||||
|  |     'functorch/**/*.py', | ||||||
|  |     'functorch/**/*.pyi', | ||||||
|  | ] | ||||||
|  | exclude_patterns = [] | ||||||
|  | command = [ | ||||||
|  |     'python3', | ||||||
|  |     'tools/linter/adapters/pyrefly_linter.py', | ||||||
|  |     '--config=pyrefly.toml', | ||||||
|  | ] | ||||||
|  | init_command = [ | ||||||
|  |     'python3', | ||||||
|  |     'tools/linter/adapters/pip_init.py', | ||||||
|  |     '--dry-run={{DRYRUN}}', | ||||||
|  |     'numpy==2.1.0 ; python_version >= "3.12"', | ||||||
|  |     'expecttest==0.3.0', | ||||||
|  |     'pyrefly==0.36.2', | ||||||
|  |     'sympy==1.13.3', | ||||||
|  |     'types-requests==2.27.25', | ||||||
|  |     'types-pyyaml==6.0.2', | ||||||
|  |     'types-tabulate==0.8.8', | ||||||
|  |     'types-protobuf==5.29.1.20250403', | ||||||
|  |     'types-setuptools==79.0.0.20250422', | ||||||
|  |     'types-jinja2==2.11.9', | ||||||
|  |     'types-colorama==0.4.6', | ||||||
|  |     'filelock==3.18.0', | ||||||
|  |     'junitparser==2.1.1', | ||||||
|  |     'rich==14.1.0', | ||||||
|  |     'optree==0.17.0', | ||||||
|  |     'types-openpyxl==3.1.5.20250919', | ||||||
|  |     'types-python-dateutil==2.9.0.20251008' | ||||||
|  | ] | ||||||
|  |  | ||||||
| [[linter]] | [[linter]] | ||||||
| code = 'CLANGTIDY' | code = 'CLANGTIDY' | ||||||
| include_patterns = [ | include_patterns = [ | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							| @ -201,3 +201,17 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A | |||||||
| /torch/csrc/stable/ @janeyx99 @mikaylagawarecki | /torch/csrc/stable/ @janeyx99 @mikaylagawarecki | ||||||
| /torch/headeronly/ @janeyx99 | /torch/headeronly/ @janeyx99 | ||||||
| /torch/header_only_apis.txt @janeyx99 | /torch/header_only_apis.txt @janeyx99 | ||||||
|  |  | ||||||
|  | # FlexAttention | ||||||
|  | /torch/nn/attention/flex_attention.py @drisspg | ||||||
|  | /torch/_higher_order_ops/flex_attention.py @drisspg | ||||||
|  | /torch/_inductor/kernel/flex/ @drisspg | ||||||
|  | /torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg | ||||||
|  | /test/inductor/test_flex_attention.py @drisspg | ||||||
|  | /test/inductor/test_flex_decoding.py @drisspg | ||||||
|  |  | ||||||
|  | # Low Precision GEMMs | ||||||
|  | /aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58 | ||||||
|  | /aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58 | ||||||
|  | /aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58 | ||||||
|  | /test/test_scaled_matmul_cuda.py @drisspg @slayton58 | ||||||
|  | |||||||
| @ -313,13 +313,14 @@ IF(USE_FBGEMM_GENAI) | |||||||
|  |  | ||||||
|     # Add additional HIPCC compiler flags for performance |     # Add additional HIPCC compiler flags for performance | ||||||
|     set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS |     set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS | ||||||
|       -mllvm |  | ||||||
|       -amdgpu-coerce-illegal-types=1 |  | ||||||
|       -mllvm |       -mllvm | ||||||
|       -enable-post-misched=0 |       -enable-post-misched=0 | ||||||
|       -mllvm |       -mllvm | ||||||
|       -greedy-reverse-local-assignment=1 |       -greedy-reverse-local-assignment=1 | ||||||
|       -fhip-new-launch-api) |       -fhip-new-launch-api) | ||||||
|  |     if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0") | ||||||
|  |         list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) | ||||||
|  |       endif() | ||||||
|  |  | ||||||
|     # Only compile for gfx942 for now. |     # Only compile for gfx942 for now. | ||||||
|     # This is rather hacky, I could not figure out a clean solution :( |     # This is rather hacky, I could not figure out a clean solution :( | ||||||
|  | |||||||
| @ -229,10 +229,10 @@ private: | |||||||
|   } |   } | ||||||
|  |  | ||||||
|  |  | ||||||
|   static const uint32_t kPhilox10A = 0x9E3779B9; |   static constexpr uint32_t kPhilox10A = 0x9E3779B9; | ||||||
|   static const uint32_t kPhilox10B = 0xBB67AE85; |   static constexpr uint32_t kPhilox10B = 0xBB67AE85; | ||||||
|   static const uint32_t kPhiloxSA = 0xD2511F53; |   static constexpr uint32_t kPhiloxSA = 0xD2511F53; | ||||||
|   static const uint32_t kPhiloxSB = 0xCD9E8D57; |   static constexpr uint32_t kPhiloxSB = 0xCD9E8D57; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| typedef philox_engine Philox4_32; | typedef philox_engine Philox4_32; | ||||||
|  | |||||||
| @ -8,6 +8,7 @@ | |||||||
| #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h> | #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h> | ||||||
| #include <ATen/cpu/vec/vec128/vec128_float_neon.h> | #include <ATen/cpu/vec/vec128/vec128_float_neon.h> | ||||||
| #include <ATen/cpu/vec/vec128/vec128_half_neon.h> | #include <ATen/cpu/vec/vec128/vec128_half_neon.h> | ||||||
|  | #include <ATen/cpu/vec/vec128/vec128_int_aarch64.h> | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #include <ATen/cpu/vec/vec128/vec128_convert.h> | #include <ATen/cpu/vec/vec128/vec128_convert.h> | ||||||
|  | |||||||
							
								
								
									
										794
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_int_aarch64.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										794
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_int_aarch64.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,794 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/cpu/vec/intrinsics.h> | ||||||
|  | #include <ATen/cpu/vec/vec_base.h> | ||||||
|  | #include <c10/macros/Macros.h> | ||||||
|  | #include <c10/util/irange.h> | ||||||
|  |  | ||||||
|  | namespace at::vec { | ||||||
|  | // Note [CPU_CAPABILITY namespace] | ||||||
|  | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  | // This header, and all of its subheaders, will be compiled with | ||||||
|  | // different architecture flags for each supported set of vector | ||||||
|  | // intrinsics. So we need to make sure they aren't inadvertently | ||||||
|  | // linked together. We do this by declaring objects in an `inline | ||||||
|  | // namespace` which changes the name mangling, but can still be | ||||||
|  | // accessed as `at::vec`. | ||||||
|  | inline namespace CPU_CAPABILITY { | ||||||
|  |  | ||||||
|  | #define VEC_INT_NEON_TEMPLATE(vl, bit)                                        \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {};  \ | ||||||
|  |                                                                               \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   class Vectorized<int##bit##_t> {                                            \ | ||||||
|  |     using neon_type = int##bit##x##vl##_t;                                    \ | ||||||
|  |                                                                               \ | ||||||
|  |    private:                                                                   \ | ||||||
|  |     neon_type values;                                                         \ | ||||||
|  |                                                                               \ | ||||||
|  |    public:                                                                    \ | ||||||
|  |     using value_type = int##bit##_t;                                          \ | ||||||
|  |     using size_type = int;                                                    \ | ||||||
|  |     static constexpr size_type size() {                                       \ | ||||||
|  |       return vl;                                                              \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized() {                                                            \ | ||||||
|  |       values = vdupq_n_s##bit(0);                                             \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized(neon_type v) : values(v) {}                                    \ | ||||||
|  |     Vectorized(int##bit##_t val);                                             \ | ||||||
|  |     template <                                                                \ | ||||||
|  |         typename... Args,                                                     \ | ||||||
|  |         typename = std::enable_if_t<(sizeof...(Args) == size())>>             \ | ||||||
|  |     Vectorized(Args... vals) {                                                \ | ||||||
|  |       __at_align__ int##bit##_t buffer[size()] = {vals...};                   \ | ||||||
|  |       values = vld1q_s##bit(buffer);                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     operator neon_type() const {                                              \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     static Vectorized<int##bit##_t> loadu(                                    \ | ||||||
|  |         const void* ptr,                                                      \ | ||||||
|  |         int64_t count = size());                                              \ | ||||||
|  |     void store(void* ptr, int64_t count = size()) const;                      \ | ||||||
|  |     template <int64_t mask>                                                   \ | ||||||
|  |     static Vectorized<int##bit##_t> blend(                                    \ | ||||||
|  |         const Vectorized<int##bit##_t>& a,                                    \ | ||||||
|  |         const Vectorized<int##bit##_t>& b);                                   \ | ||||||
|  |     static Vectorized<int##bit##_t> blendv(                                   \ | ||||||
|  |         const Vectorized<int##bit##_t>& a,                                    \ | ||||||
|  |         const Vectorized<int##bit##_t>& b,                                    \ | ||||||
|  |         const Vectorized<int##bit##_t>& mask_) {                              \ | ||||||
|  |       return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     template <typename step_t>                                                \ | ||||||
|  |     static Vectorized<int##bit##_t> arange(                                   \ | ||||||
|  |         value_type base = 0,                                                  \ | ||||||
|  |         step_t step = static_cast<step_t>(1));                                \ | ||||||
|  |     static Vectorized<int##bit##_t> set(                                      \ | ||||||
|  |         const Vectorized<int##bit##_t>& a,                                    \ | ||||||
|  |         const Vectorized<int##bit##_t>& b,                                    \ | ||||||
|  |         int64_t count = size());                                              \ | ||||||
|  |     const int##bit##_t& operator[](int idx) const = delete;                   \ | ||||||
|  |     int##bit##_t& operator[](int idx) = delete;                               \ | ||||||
|  |     Vectorized<int##bit##_t> abs() const {                                    \ | ||||||
|  |       return vabsq_s##bit(values);                                            \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> real() const {                                   \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> imag() const {                                   \ | ||||||
|  |       return vdupq_n_s##bit(0);                                               \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> conj() const {                                   \ | ||||||
|  |       return values;                                                          \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> neg() const {                                    \ | ||||||
|  |       return vnegq_s##bit(values);                                            \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     int##bit##_t reduce_add() const {                                         \ | ||||||
|  |       return vaddvq_s##bit(values);                                           \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     int##bit##_t reduce_max() const;                                          \ | ||||||
|  |     Vectorized<int##bit##_t> operator==(                                      \ | ||||||
|  |         const Vectorized<int##bit##_t>& other) const {                        \ | ||||||
|  |       return Vectorized<value_type>(                                          \ | ||||||
|  |           vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> operator!=(                                      \ | ||||||
|  |         const Vectorized<int##bit##_t>& other) const;                         \ | ||||||
|  |     Vectorized<int##bit##_t> operator<(                                       \ | ||||||
|  |         const Vectorized<int##bit##_t>& other) const {                        \ | ||||||
|  |       return Vectorized<value_type>(                                          \ | ||||||
|  |           vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> operator<=(                                      \ | ||||||
|  |         const Vectorized<int##bit##_t>& other) const {                        \ | ||||||
|  |       return Vectorized<value_type>(                                          \ | ||||||
|  |           vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> operator>(                                       \ | ||||||
|  |         const Vectorized<int##bit##_t>& other) const {                        \ | ||||||
|  |       return Vectorized<value_type>(                                          \ | ||||||
|  |           vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> operator>=(                                      \ | ||||||
|  |         const Vectorized<int##bit##_t>& other) const {                        \ | ||||||
|  |       return Vectorized<value_type>(                                          \ | ||||||
|  |           vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \ | ||||||
|  |     }                                                                         \ | ||||||
|  |     Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \ | ||||||
|  |     Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \ | ||||||
|  |     Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \ | ||||||
|  |     Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \ | ||||||
|  |     Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \ | ||||||
|  |     Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \ | ||||||
|  |   };                                                                          \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<int##bit##_t> inline operator+(                                  \ | ||||||
|  |       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||||
|  |     return vaddq_s##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<int##bit##_t> inline operator-(                                  \ | ||||||
|  |       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||||
|  |     return vsubq_s##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<int##bit##_t> inline operator&(                                  \ | ||||||
|  |       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||||
|  |     return vandq_s##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<int##bit##_t> inline operator|(                                  \ | ||||||
|  |       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||||
|  |     return vorrq_s##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   template <>                                                                 \ | ||||||
|  |   Vectorized<int##bit##_t> inline operator^(                                  \ | ||||||
|  |       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||||
|  |     return veorq_s##bit(a, b);                                                \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq(               \ | ||||||
|  |       const Vectorized<int##bit##_t>& other) const {                          \ | ||||||
|  |     return (*this == other) & Vectorized<int##bit##_t>(1);                    \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne(               \ | ||||||
|  |       const Vectorized<int##bit##_t>& other) const {                          \ | ||||||
|  |     return (*this != other) & Vectorized<int##bit##_t>(1);                    \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt(               \ | ||||||
|  |       const Vectorized<int##bit##_t>& other) const {                          \ | ||||||
|  |     return (*this > other) & Vectorized<int##bit##_t>(1);                     \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge(               \ | ||||||
|  |       const Vectorized<int##bit##_t>& other) const {                          \ | ||||||
|  |     return (*this >= other) & Vectorized<int##bit##_t>(1);                    \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt(               \ | ||||||
|  |       const Vectorized<int##bit##_t>& other) const {                          \ | ||||||
|  |     return (*this < other) & Vectorized<int##bit##_t>(1);                     \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le(               \ | ||||||
|  |       const Vectorized<int##bit##_t>& other) const {                          \ | ||||||
|  |     return (*this <= other) & Vectorized<int##bit##_t>(1);                    \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | VEC_INT_NEON_TEMPLATE(2, 64) | ||||||
|  | VEC_INT_NEON_TEMPLATE(4, 32) | ||||||
|  | VEC_INT_NEON_TEMPLATE(8, 16) | ||||||
|  | VEC_INT_NEON_TEMPLATE(16, 8) | ||||||
|  |  | ||||||
|  | inline int32_t Vectorized<int32_t>::reduce_max() const { | ||||||
|  |   return vmaxvq_s32(values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline int16_t Vectorized<int16_t>::reduce_max() const { | ||||||
|  |   return vmaxvq_s16(values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline int8_t Vectorized<int8_t>::reduce_max() const { | ||||||
|  |   return vmaxvq_s8(values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline operator*( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& b) { | ||||||
|  |   return vmulq_s32(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline operator*( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& b) { | ||||||
|  |   return vmulq_s16(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline operator*( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& b) { | ||||||
|  |   return vmulq_s8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) { | ||||||
|  |   int64x2_t val = a; | ||||||
|  |   return ~val; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) { | ||||||
|  |   return vmvnq_s32(a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) { | ||||||
|  |   return vmvnq_s16(a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) { | ||||||
|  |   return vmvnq_s8(a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<int64_t> Vectorized<int64_t>::operator!=( | ||||||
|  |     const Vectorized<int64_t>& other) const { | ||||||
|  |   return ~(*this == other); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<int32_t> Vectorized<int32_t>::operator!=( | ||||||
|  |     const Vectorized<int32_t>& other) const { | ||||||
|  |   return ~(*this == other); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<int16_t> Vectorized<int16_t>::operator!=( | ||||||
|  |     const Vectorized<int16_t>& other) const { | ||||||
|  |   return ~(*this == other); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<int8_t> Vectorized<int8_t>::operator!=( | ||||||
|  |     const Vectorized<int8_t>& other) const { | ||||||
|  |   return ~(*this == other); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline minimum( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& b) { | ||||||
|  |   return vminq_s32(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline minimum( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& b) { | ||||||
|  |   return vminq_s16(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline minimum( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& b) { | ||||||
|  |   return vminq_s8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline maximum( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& b) { | ||||||
|  |   return vmaxq_s32(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline maximum( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& b) { | ||||||
|  |   return vmaxq_s16(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline maximum( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& b) { | ||||||
|  |   return vmaxq_s8(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <int64_t mask> | ||||||
|  | Vectorized<int64_t> Vectorized<int64_t>::blend( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& b) { | ||||||
|  |   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||||
|  |   // in 'mask' is set, 0 otherwise. | ||||||
|  |   uint64x2_t maskArray = { | ||||||
|  |       (mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0, | ||||||
|  |       (mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0}; | ||||||
|  |   // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |   return vbslq_s64(maskArray, b.values, a.values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <int64_t mask> | ||||||
|  | Vectorized<int32_t> Vectorized<int32_t>::blend( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& b) { | ||||||
|  |   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||||
|  |   // in 'mask' is set, 0 otherwise. | ||||||
|  |   uint32x4_t maskArray = { | ||||||
|  |       (mask & 1LL) ? 0xFFFFFFFF : 0, | ||||||
|  |       (mask & 2LL) ? 0xFFFFFFFF : 0, | ||||||
|  |       (mask & 4LL) ? 0xFFFFFFFF : 0, | ||||||
|  |       (mask & 8LL) ? 0xFFFFFFFF : 0}; | ||||||
|  |   // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |   return vbslq_s32(maskArray, b.values, a.values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <int64_t mask> | ||||||
|  | Vectorized<int16_t> Vectorized<int16_t>::blend( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& b) { | ||||||
|  |   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||||
|  |   // in 'mask' is set, 0 otherwise. | ||||||
|  |   uint16x8_t maskArray = { | ||||||
|  |       (mask & 1LL) ? 0xFFFF : 0, | ||||||
|  |       (mask & 2LL) ? 0xFFFF : 0, | ||||||
|  |       (mask & 4LL) ? 0xFFFF : 0, | ||||||
|  |       (mask & 8LL) ? 0xFFFF : 0, | ||||||
|  |       (mask & 16LL) ? 0xFFFF : 0, | ||||||
|  |       (mask & 32LL) ? 0xFFFF : 0, | ||||||
|  |       (mask & 64LL) ? 0xFFFF : 0, | ||||||
|  |       (mask & 128LL) ? 0xFFFF : 0}; | ||||||
|  |   // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |   return vbslq_s16(maskArray, b.values, a.values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <int64_t mask> | ||||||
|  | Vectorized<int8_t> Vectorized<int8_t>::blend( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& b) { | ||||||
|  |   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||||
|  |   // in 'mask' is set, 0 otherwise. | ||||||
|  |   uint8x16_t maskArray = { | ||||||
|  |       (mask & 1LL) ? 0xFF : 0, | ||||||
|  |       (mask & 2LL) ? 0xFF : 0, | ||||||
|  |       (mask & 4LL) ? 0xFF : 0, | ||||||
|  |       (mask & 8LL) ? 0xFF : 0, | ||||||
|  |       (mask & 16LL) ? 0xFF : 0, | ||||||
|  |       (mask & 32LL) ? 0xFF : 0, | ||||||
|  |       (mask & 64LL) ? 0xFF : 0, | ||||||
|  |       (mask & 128LL) ? 0xFF : 0, | ||||||
|  |       (mask & 256LL) ? 0xFF : 0, | ||||||
|  |       (mask & 512LL) ? 0xFF : 0, | ||||||
|  |       (mask & 1024LL) ? 0xFF : 0, | ||||||
|  |       (mask & 2048LL) ? 0xFF : 0, | ||||||
|  |       (mask & 4096LL) ? 0xFF : 0, | ||||||
|  |       (mask & 8192LL) ? 0xFF : 0, | ||||||
|  |       (mask & 16384LL) ? 0xFF : 0, | ||||||
|  |       (mask & 32768LL) ? 0xFF : 0}; | ||||||
|  |   // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |   return vbslq_s8(maskArray, b.values, a.values); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define VEC_INT_NEON_OPS(vl, bit)                                             \ | ||||||
|  |   inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) {             \ | ||||||
|  |     values = vdupq_n_s##bit(val);                                             \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu(            \ | ||||||
|  |       const void* ptr, int64_t count) {                                       \ | ||||||
|  |     if (count == size()) {                                                    \ | ||||||
|  |       return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr));        \ | ||||||
|  |     } else {                                                                  \ | ||||||
|  |       __at_align__ int##bit##_t tmp_values[size()];                           \ | ||||||
|  |       for (const auto i : c10::irange(size())) {                              \ | ||||||
|  |         tmp_values[i] = 0;                                                    \ | ||||||
|  |       }                                                                       \ | ||||||
|  |       std::memcpy(                                                            \ | ||||||
|  |           tmp_values,                                                         \ | ||||||
|  |           reinterpret_cast<const int##bit##_t*>(ptr),                         \ | ||||||
|  |           count * sizeof(int##bit##_t));                                      \ | ||||||
|  |       return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \ | ||||||
|  |     }                                                                         \ | ||||||
|  |   }                                                                           \ | ||||||
|  |   inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count)       \ | ||||||
|  |       const {                                                                 \ | ||||||
|  |     if (count == size()) {                                                    \ | ||||||
|  |       vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values);             \ | ||||||
|  |     } else {                                                                  \ | ||||||
|  |       int##bit##_t tmp_values[size()];                                        \ | ||||||
|  |       vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values);      \ | ||||||
|  |       std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t));             \ | ||||||
|  |     }                                                                         \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | VEC_INT_NEON_OPS(2, 64) | ||||||
|  | VEC_INT_NEON_OPS(4, 32) | ||||||
|  | VEC_INT_NEON_OPS(8, 16) | ||||||
|  | VEC_INT_NEON_OPS(16, 8) | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline operator*( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& b) { | ||||||
|  |   int64x2_t x = a; | ||||||
|  |   int64x2_t y = b; | ||||||
|  |   return x * y; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline operator/( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& b) { | ||||||
|  |   int64x2_t x = a; | ||||||
|  |   int64x2_t y = b; | ||||||
|  |   return x / y; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline operator/( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& b) { | ||||||
|  |   int32x4_t x = a; | ||||||
|  |   int32x4_t y = b; | ||||||
|  |   return x / y; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline int64_t Vectorized<int64_t>::reduce_max() const { | ||||||
|  |   return std::max(values[0], values[1]); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline minimum( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& b) { | ||||||
|  |   int64x2_t x = a; | ||||||
|  |   int64x2_t y = b; | ||||||
|  |   return {std::min(x[0], y[0]), std::min(x[1], y[1])}; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline maximum( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& b) { | ||||||
|  |   int64x2_t x = a; | ||||||
|  |   int64x2_t y = b; | ||||||
|  |   return {std::max(x[0], y[0]), std::max(x[1], y[1])}; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename step_t> | ||||||
|  | inline Vectorized<int64_t> Vectorized<int64_t>::arange( | ||||||
|  |     int64_t base, | ||||||
|  |     step_t step) { | ||||||
|  |   const Vectorized<int64_t> base_vec(base); | ||||||
|  |   const Vectorized<int64_t> step_vec(step); | ||||||
|  |   const int64x2_t step_sizes = {0, 1}; | ||||||
|  |   return base_vec.values + step_sizes * step_vec.values; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename step_t> | ||||||
|  | inline Vectorized<int32_t> Vectorized<int32_t>::arange( | ||||||
|  |     int32_t base, | ||||||
|  |     step_t step) { | ||||||
|  |   const Vectorized<int32_t> base_vec(base); | ||||||
|  |   const Vectorized<int32_t> step_vec(step); | ||||||
|  |   const int32x4_t step_sizes = {0, 1, 2, 3}; | ||||||
|  |   return vmlaq_s32(base_vec, step_sizes, step_vec); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename step_t> | ||||||
|  | inline Vectorized<int16_t> Vectorized<int16_t>::arange( | ||||||
|  |     int16_t base, | ||||||
|  |     step_t step) { | ||||||
|  |   const Vectorized<int16_t> base_vec(base); | ||||||
|  |   const Vectorized<int16_t> step_vec(step); | ||||||
|  |   const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7}; | ||||||
|  |   return vmlaq_s16(base_vec, step_sizes, step_vec); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename step_t> | ||||||
|  | inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) { | ||||||
|  |   const Vectorized<int8_t> base_vec(base); | ||||||
|  |   const Vectorized<int8_t> step_vec(step); | ||||||
|  |   const int8x16_t step_sizes = { | ||||||
|  |       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; | ||||||
|  |   return vmlaq_s8(base_vec, step_sizes, step_vec); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline operator>>( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& b) { | ||||||
|  |   int64x2_t x = a; | ||||||
|  |   int64x2_t y = b; | ||||||
|  |   uint64x2_t u = vreinterpretq_u64_s64(y); | ||||||
|  |   uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)}; | ||||||
|  |   return x >> vreinterpretq_s64_u64(z); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline operator>>( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& b) { | ||||||
|  |   int32x4_t x = a; | ||||||
|  |   int32x4_t y = b; | ||||||
|  |   uint32x4_t bound = vdupq_n_u32(31); | ||||||
|  |   uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound); | ||||||
|  |   return x >> vreinterpretq_s32_u32(z); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline operator>>( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& b) { | ||||||
|  |   int16x8_t x = a; | ||||||
|  |   int16x8_t y = b; | ||||||
|  |   uint16x8_t bound = vdupq_n_u16(15); | ||||||
|  |   uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound); | ||||||
|  |   return x >> vreinterpretq_s16_u16(z); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline operator>>( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& b) { | ||||||
|  |   int8x16_t x = a; | ||||||
|  |   int8x16_t y = b; | ||||||
|  |   uint8x16_t bound = vdupq_n_u8(7); | ||||||
|  |   int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound)); | ||||||
|  |   return x >> z; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline operator<<( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& b) { | ||||||
|  |   int64x2_t y = b; | ||||||
|  |   uint64x2_t u = vreinterpretq_u64_s64(y); | ||||||
|  |   uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)}; | ||||||
|  |   return vshlq_s64(a, vreinterpretq_s64_u64(z)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline operator<<( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& b) { | ||||||
|  |   int32x4_t y = b; | ||||||
|  |   uint32x4_t bound = vdupq_n_u32(32); | ||||||
|  |   uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound); | ||||||
|  |   return vshlq_s32(a, vreinterpretq_s32_u32(z)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline operator<<( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& b) { | ||||||
|  |   int16x8_t y = b; | ||||||
|  |   uint16x8_t bound = vdupq_n_u16(16); | ||||||
|  |   uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound); | ||||||
|  |   return vshlq_s16(a, vreinterpretq_s16_u16(z)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline operator<<( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& b) { | ||||||
|  |   int8x16_t y = b; | ||||||
|  |   uint8x16_t bound = vdupq_n_u8(8); | ||||||
|  |   int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound)); | ||||||
|  |   return vshlq_s8(a, z); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<int64_t> Vectorized<int64_t>::set( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& b, | ||||||
|  |     int64_t count) { | ||||||
|  |   if (count == 0) { | ||||||
|  |     return a; | ||||||
|  |   } else if (count >= 2) { | ||||||
|  |     return b; | ||||||
|  |   } else { | ||||||
|  |     int64x2_t c = {b.values[0], a.values[1]}; | ||||||
|  |     return c; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<int32_t> Vectorized<int32_t>::set( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& b, | ||||||
|  |     int64_t count) { | ||||||
|  |   if (count == 0) { | ||||||
|  |     return a; | ||||||
|  |   } else if (count >= 4) { | ||||||
|  |     return b; | ||||||
|  |   } else { | ||||||
|  |     // Build an array of flags: each bit of element is 1 if the corresponding | ||||||
|  |     // bit in 'mask' is set, 0 otherwise. | ||||||
|  |     uint32x4_t maskArray = { | ||||||
|  |         (count >= 1LL) ? 0xFFFFFFFF : 0, | ||||||
|  |         (count >= 2LL) ? 0xFFFFFFFF : 0, | ||||||
|  |         (count >= 3LL) ? 0xFFFFFFFF : 0, | ||||||
|  |         0}; | ||||||
|  |     // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |     return vbslq_s32(maskArray, b.values, a.values); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<int16_t> Vectorized<int16_t>::set( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& b, | ||||||
|  |     int64_t count) { | ||||||
|  |   if (count == 0) { | ||||||
|  |     return a; | ||||||
|  |   } else if (count >= 8) { | ||||||
|  |     return b; | ||||||
|  |   } else { | ||||||
|  |     // Build an array of flags: each bit of element is 1 if the corresponding | ||||||
|  |     // bit in 'mask' is set, 0 otherwise. | ||||||
|  |     uint16x8_t maskArray = { | ||||||
|  |         static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0), | ||||||
|  |         static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0), | ||||||
|  |         static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0), | ||||||
|  |         static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0), | ||||||
|  |         static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0), | ||||||
|  |         static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0), | ||||||
|  |         static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0), | ||||||
|  |         0}; | ||||||
|  |     // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |     return vbslq_s16(maskArray, b.values, a.values); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline Vectorized<int8_t> Vectorized<int8_t>::set( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& b, | ||||||
|  |     int64_t count) { | ||||||
|  |   if (count == 0) { | ||||||
|  |     return a; | ||||||
|  |   } else if (count >= 16) { | ||||||
|  |     return b; | ||||||
|  |   } else { | ||||||
|  |     // Build an array of flags: each bit of element is 1 if the corresponding | ||||||
|  |     // bit in 'mask' is set, 0 otherwise. | ||||||
|  |     uint8x16_t maskArray = { | ||||||
|  |         static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0), | ||||||
|  |         static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0), | ||||||
|  |         0}; | ||||||
|  |  | ||||||
|  |     // Use BSL to select elements from b where the mask is 1, else from a | ||||||
|  |     return vbslq_s8(maskArray, b.values, a.values); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline operator/( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& b) { | ||||||
|  |   Vectorized<int32_t> highBitsA = vmovl_high_s16(a); | ||||||
|  |   Vectorized<int32_t> highBitsB = vmovl_high_s16(b); | ||||||
|  |   Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a)); | ||||||
|  |   Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b)); | ||||||
|  |   int32x4_t highBitsResult = highBitsA / highBitsB; | ||||||
|  |   int32x4_t lowBitsResult = lowBitsA / lowBitsB; | ||||||
|  |   return vuzp1q_s16( | ||||||
|  |       vreinterpretq_s16_s32(lowBitsResult), | ||||||
|  |       vreinterpretq_s16_s32(highBitsResult)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline operator/( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& b) { | ||||||
|  |   Vectorized<int16_t> highBitsA = vmovl_high_s8(a); | ||||||
|  |   Vectorized<int16_t> highBitsB = vmovl_high_s8(b); | ||||||
|  |   Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a)); | ||||||
|  |   Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b)); | ||||||
|  |   int16x8_t highBitsResult = highBitsA / highBitsB; | ||||||
|  |   int16x8_t lowBitsResult = lowBitsA / lowBitsB; | ||||||
|  |   return vuzp1q_s8( | ||||||
|  |       vreinterpretq_s8_s16(lowBitsResult), | ||||||
|  |       vreinterpretq_s8_s16(highBitsResult)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline clamp( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& min, | ||||||
|  |     const Vectorized<int64_t>& max) { | ||||||
|  |   return minimum(max, maximum(min, a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline clamp( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& min, | ||||||
|  |     const Vectorized<int32_t>& max) { | ||||||
|  |   return minimum(max, maximum(min, a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline clamp( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& min, | ||||||
|  |     const Vectorized<int16_t>& max) { | ||||||
|  |   return minimum(max, maximum(min, a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline clamp( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& min, | ||||||
|  |     const Vectorized<int8_t>& max) { | ||||||
|  |   return minimum(max, maximum(min, a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline clamp_max( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& max) { | ||||||
|  |   return minimum(max, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline clamp_max( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& max) { | ||||||
|  |   return minimum(max, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline clamp_max( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& max) { | ||||||
|  |   return minimum(max, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline clamp_max( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& max) { | ||||||
|  |   return minimum(max, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int64_t> inline clamp_min( | ||||||
|  |     const Vectorized<int64_t>& a, | ||||||
|  |     const Vectorized<int64_t>& min) { | ||||||
|  |   return maximum(min, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int32_t> inline clamp_min( | ||||||
|  |     const Vectorized<int32_t>& a, | ||||||
|  |     const Vectorized<int32_t>& min) { | ||||||
|  |   return maximum(min, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int16_t> inline clamp_min( | ||||||
|  |     const Vectorized<int16_t>& a, | ||||||
|  |     const Vectorized<int16_t>& min) { | ||||||
|  |   return maximum(min, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | Vectorized<int8_t> inline clamp_min( | ||||||
|  |     const Vectorized<int8_t>& a, | ||||||
|  |     const Vectorized<int8_t>& min) { | ||||||
|  |   return maximum(min, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace CPU_CAPABILITY | ||||||
|  | } // namespace at::vec | ||||||
| @ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum( | |||||||
| #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) | #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) | ||||||
| std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||||
|     at::vec::Vectorized<int8_t> src) { |     at::vec::Vectorized<int8_t> src) { | ||||||
|   auto s8x8 = vld1_s8(src.operator const int8_t*()); |   auto s8x8 = vget_low_s8(src); | ||||||
|   auto s16x8 = vmovl_s8(s8x8); |   auto s16x8 = vmovl_s8(s8x8); | ||||||
|  |  | ||||||
|   auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); |   auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); | ||||||
| @ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | |||||||
|  |  | ||||||
| Vectorized<float> inline convert_int8_half_register_to_float( | Vectorized<float> inline convert_int8_half_register_to_float( | ||||||
|     at::vec::Vectorized<int8_t> src) { |     at::vec::Vectorized<int8_t> src) { | ||||||
|   auto s8x8 = vld1_s8(src.operator const int8_t*()); |   auto s8x8 = vget_low_s8(src); | ||||||
|   auto s16x8 = vmovl_s8(s8x8); |   auto s16x8 = vmovl_s8(s8x8); | ||||||
|  |  | ||||||
|   auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); |   auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); | ||||||
|  | |||||||
| @ -16,6 +16,8 @@ | |||||||
| #include <c10/util/irange.h> | #include <c10/util/irange.h> | ||||||
| #include <c10/core/ScalarType.h> | #include <c10/core/ScalarType.h> | ||||||
|  |  | ||||||
|  | #include <ATen/cuda/detail/BLASConstants.h> | ||||||
|  |  | ||||||
| #ifdef USE_ROCM | #ifdef USE_ROCM | ||||||
| #include <c10/cuda/CUDAStream.h> | #include <c10/cuda/CUDAStream.h> | ||||||
| #include <hipblaslt/hipblaslt-ext.hpp> | #include <hipblaslt/hipblaslt-ext.hpp> | ||||||
| @ -1954,13 +1956,15 @@ void scaled_gemm( | |||||||
|     const void *result_scale_ptr, |     const void *result_scale_ptr, | ||||||
|     int64_t result_ld, |     int64_t result_ld, | ||||||
|     ScalarType result_dtype, |     ScalarType result_dtype, | ||||||
|     bool use_fast_accum) { |     bool use_fast_accum, | ||||||
|  |     const std::optional<Tensor>& alpha) { | ||||||
|   // Note: see `cublasCommonArgs` for various non-intuitive manupulations |   // Note: see `cublasCommonArgs` for various non-intuitive manupulations | ||||||
|   // of input arguments to this function. |   // of input arguments to this function. | ||||||
|   const auto computeType = CUBLAS_COMPUTE_32F; |   const auto computeType = CUBLAS_COMPUTE_32F; | ||||||
|   const auto scaleType = CUDA_R_32F; |   const auto scaleType = CUDA_R_32F; | ||||||
|   const float alpha_val = 1.0; |   // Note: alpha_val may change later depending on user-passed argument | ||||||
|   const float beta_val = 0.0; |   float alpha_val = 1.0; | ||||||
|  |   float beta_val = 0.0; | ||||||
|   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); |   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); | ||||||
|   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); |   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); | ||||||
|   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); |   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); | ||||||
| @ -2031,6 +2035,33 @@ void scaled_gemm( | |||||||
|     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS); |     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS); | ||||||
|     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); |     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   // Handle user-passed alpha | ||||||
|  |   float *alpha_ptr = &alpha_val; | ||||||
|  |   float *beta_ptr = &beta_val; | ||||||
|  |  | ||||||
|  |   if (alpha.has_value()) { | ||||||
|  |     auto& a = alpha.value(); | ||||||
|  |  | ||||||
|  |     // if device-tensor | ||||||
|  |     if (a.is_cuda()) { | ||||||
|  |       // NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be | ||||||
|  |       //       valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus | ||||||
|  |       //       we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically | ||||||
|  |       //       managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory | ||||||
|  |       //       for beta respectively. | ||||||
|  |       float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr(); | ||||||
|  |       at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat)); | ||||||
|  |       user_alpha.copy_(a); | ||||||
|  |       // Tell cublasLt we're using device-side pointers for alpha/beta | ||||||
|  |       auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; | ||||||
|  |       computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode); | ||||||
|  |       alpha_ptr = user_alpha.data_ptr<float>(); | ||||||
|  |       beta_ptr = at::cuda::detail::get_cublas_device_zero(); | ||||||
|  |     } else { | ||||||
|  |       alpha_val = a.item<float>(); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|     // For other data types, use the get_scale_mode function based on scaling type |     // For other data types, use the get_scale_mode function based on scaling type | ||||||
|     // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt, |     // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt, | ||||||
|     // but we must invoke get_scale_mode anyways to trigger the version checks. |     // but we must invoke get_scale_mode anyways to trigger the version checks. | ||||||
| @ -2048,6 +2079,7 @@ void scaled_gemm( | |||||||
|   cublasLtMatmulHeuristicResult_t heuristicResult = {}; |   cublasLtMatmulHeuristicResult_t heuristicResult = {}; | ||||||
|   int returnedResult = 0; |   int returnedResult = 0; | ||||||
|   cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); |   cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); | ||||||
|  |  | ||||||
|   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( |   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( | ||||||
|       ltHandle, |       ltHandle, | ||||||
|       computeDesc.descriptor(), |       computeDesc.descriptor(), | ||||||
| @ -2088,10 +2120,10 @@ void scaled_gemm( | |||||||
|         auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported( |         auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported( | ||||||
|                 ltHandle, |                 ltHandle, | ||||||
|                 computeDesc.descriptor(), |                 computeDesc.descriptor(), | ||||||
|                 &alpha_val, |                 alpha_ptr, | ||||||
|                 Adesc.descriptor(), |                 Adesc.descriptor(), | ||||||
|                 Bdesc.descriptor(), |                 Bdesc.descriptor(), | ||||||
|                 &beta_val, |                 beta_ptr, | ||||||
|                 Cdesc.descriptor(), |                 Cdesc.descriptor(), | ||||||
|                 Ddesc.descriptor(), |                 Ddesc.descriptor(), | ||||||
|                 all_algos[i].algo, |                 all_algos[i].algo, | ||||||
| @ -2110,17 +2142,14 @@ void scaled_gemm( | |||||||
|   cublasStatus_t cublasStatus = cublasLtMatmul( |   cublasStatus_t cublasStatus = cublasLtMatmul( | ||||||
|       ltHandle, |       ltHandle, | ||||||
|       computeDesc.descriptor(), |       computeDesc.descriptor(), | ||||||
|       &alpha_val, |       alpha_ptr, | ||||||
|       mat1_ptr, |       mat1_ptr, | ||||||
|       Adesc.descriptor(), |       Adesc.descriptor(), | ||||||
|       mat2_ptr, |       mat2_ptr, | ||||||
|       Bdesc.descriptor(), |       Bdesc.descriptor(), | ||||||
|       &beta_val, |       beta_ptr, | ||||||
| #ifdef USE_ROCM |       // NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either | ||||||
|       result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr |       result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr | ||||||
| #else |  | ||||||
|       nullptr, |  | ||||||
| #endif // ifdef USE_ROCM |  | ||||||
|       Cdesc.descriptor(), |       Cdesc.descriptor(), | ||||||
|       result_ptr, |       result_ptr, | ||||||
|       Ddesc.descriptor(), |       Ddesc.descriptor(), | ||||||
|  | |||||||
| @ -161,7 +161,8 @@ void scaled_gemm( | |||||||
|     const void* result_scale_ptr, |     const void* result_scale_ptr, | ||||||
|     int64_t result_ld, |     int64_t result_ld, | ||||||
|     ScalarType result_dtype, |     ScalarType result_dtype, | ||||||
|     bool use_fast_accum); |     bool use_fast_accum, | ||||||
|  |     const std::optional<Tensor>& alpha); | ||||||
|  |  | ||||||
| #define CUDABLAS_BGEMM_ARGTYPES(Dtype)  CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) | #define CUDABLAS_BGEMM_ARGTYPES(Dtype)  CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) | ||||||
|  |  | ||||||
|  | |||||||
| @ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() { | |||||||
|  */ |  */ | ||||||
| c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { | c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { | ||||||
|   // The RNG state comprises the seed, and an offset used for Philox. |   // The RNG state comprises the seed, and an offset used for Philox. | ||||||
|   static const size_t seed_size = sizeof(uint64_t); |   constexpr size_t seed_size = sizeof(uint64_t); | ||||||
|   static const size_t offset_size = sizeof(int64_t); |   constexpr size_t offset_size = sizeof(int64_t); | ||||||
|   static const size_t total_size = seed_size + offset_size; |   constexpr size_t total_size = seed_size + offset_size; | ||||||
|  |  | ||||||
|   auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); |   auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); | ||||||
|   auto rng_state = state_tensor.data_ptr<uint8_t>(); |   auto rng_state = state_tensor.data_ptr<uint8_t>(); | ||||||
| @ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { | |||||||
|  * and size of the internal state. |  * and size of the internal state. | ||||||
|  */ |  */ | ||||||
| void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { | void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { | ||||||
|   static const size_t seed_size = sizeof(uint64_t); |   constexpr size_t seed_size = sizeof(uint64_t); | ||||||
|   static const size_t offset_size = sizeof(int64_t); |   constexpr size_t offset_size = sizeof(int64_t); | ||||||
|   static const size_t total_size = seed_size + offset_size; |   constexpr size_t total_size = seed_size + offset_size; | ||||||
|  |  | ||||||
|   detail::check_rng_state(new_state); |   detail::check_rng_state(new_state); | ||||||
|  |  | ||||||
|  | |||||||
| @ -183,11 +183,6 @@ struct CUDACachingHostAllocatorImpl | |||||||
|     return true; |     return true; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   bool pinned_use_background_threads() override { |  | ||||||
|     return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: |  | ||||||
|         pinned_use_background_threads(); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   EventPool::Event create_event_internal(DeviceIndex idx) { |   EventPool::Event create_event_internal(DeviceIndex idx) { | ||||||
|     // Leak the event pool to avoid shutdown issue. |     // Leak the event pool to avoid shutdown issue. | ||||||
|     static auto* event_pool = new EventPool(); |     static auto* event_pool = new EventPool(); | ||||||
|  | |||||||
| @ -177,7 +177,6 @@ inline void segmented_sort_pairs( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| #if CUB_SUPPORTS_UNIQUE_BY_KEY() |  | ||||||
| template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT> | template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT> | ||||||
| inline void unique_by_key( | inline void unique_by_key( | ||||||
|   KeysInputIteratorT keys_in, ValuesInputIteratorT values_in, |   KeysInputIteratorT keys_in, ValuesInputIteratorT values_in, | ||||||
| @ -193,7 +192,6 @@ inline void unique_by_key( | |||||||
|   CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey, |   CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey, | ||||||
|     keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream()); |     keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream()); | ||||||
| } | } | ||||||
| #endif |  | ||||||
|  |  | ||||||
| namespace impl { | namespace impl { | ||||||
|  |  | ||||||
| @ -579,7 +577,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | |||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
|  |  | ||||||
| template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT> | template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT> | ||||||
| inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) { | inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) { | ||||||
| @ -607,7 +604,6 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT | |||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT> | template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT> | ||||||
| void unique(InputIteratorT input, OutputIteratorT output, | void unique(InputIteratorT input, OutputIteratorT output, | ||||||
|  | |||||||
| @ -28,22 +28,6 @@ | |||||||
| #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| // cub support for UniqueByKey is added to cub 1.16 in: |  | ||||||
| // https://github.com/NVIDIA/cub/pull/405 |  | ||||||
| #if CUB_VERSION >= 101600 |  | ||||||
| #define CUB_SUPPORTS_UNIQUE_BY_KEY() true |  | ||||||
| #else |  | ||||||
| #define CUB_SUPPORTS_UNIQUE_BY_KEY() false |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // cub support for scan by key is added to cub 1.15 |  | ||||||
| // in https://github.com/NVIDIA/cub/pull/376 |  | ||||||
| #if CUB_VERSION >= 101500 |  | ||||||
| #define CUB_SUPPORTS_SCAN_BY_KEY() 1 |  | ||||||
| #else |  | ||||||
| #define CUB_SUPPORTS_SCAN_BY_KEY() 0 |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // cub support for cub::FutureValue is added to cub 1.15 in: | // cub support for cub::FutureValue is added to cub 1.15 in: | ||||||
| // https://github.com/NVIDIA/cub/pull/305 | // https://github.com/NVIDIA/cub/pull/305 | ||||||
| #if CUB_VERSION >= 101500 | #if CUB_VERSION >= 101500 | ||||||
|  | |||||||
							
								
								
									
										54
									
								
								aten/src/ATen/cuda/detail/BLASConstants.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								aten/src/ATen/cuda/detail/BLASConstants.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,54 @@ | |||||||
|  | #include <ATen/Functions.h> | ||||||
|  | #include <ATen/Tensor.h> | ||||||
|  | #include <ATen/cuda/Exceptions.h> | ||||||
|  |  | ||||||
|  | #include <mutex> | ||||||
|  |  | ||||||
|  | namespace at { | ||||||
|  | namespace cuda { | ||||||
|  | namespace detail { | ||||||
|  |  | ||||||
|  | __device__ __constant__ float cublas_one_device; | ||||||
|  | __device__ __constant__ float cublas_zero_device; | ||||||
|  |  | ||||||
|  | float *get_cublas_device_one() { | ||||||
|  |   static c10::once_flag init_flag; | ||||||
|  |  | ||||||
|  |   c10::call_once(init_flag, []() { | ||||||
|  |     const float one = 1.f; | ||||||
|  |     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float))); | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   float *ptr; | ||||||
|  |   AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device)); | ||||||
|  |   return ptr; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | float *get_cublas_device_zero() { | ||||||
|  |   static c10::once_flag init_flag; | ||||||
|  |  | ||||||
|  |   c10::call_once(init_flag, []() { | ||||||
|  |     const float zero = 0.f; | ||||||
|  |     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float))); | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   float *ptr; | ||||||
|  |   AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device)); | ||||||
|  |   return ptr; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | float *get_user_alpha_ptr() { | ||||||
|  |   static float *alpha_ptr; | ||||||
|  |  | ||||||
|  |   static c10::once_flag init_flag; | ||||||
|  |  | ||||||
|  |   c10::call_once(init_flag, []() { | ||||||
|  |     AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float))); | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   return alpha_ptr; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace detail | ||||||
|  | } // namespace cuda | ||||||
|  | } // namespace at | ||||||
							
								
								
									
										11
									
								
								aten/src/ATen/cuda/detail/BLASConstants.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								aten/src/ATen/cuda/detail/BLASConstants.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,11 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/core/TensorBase.h> | ||||||
|  |  | ||||||
|  | namespace at::cuda::detail { | ||||||
|  |  | ||||||
|  | float *get_cublas_device_one(); | ||||||
|  | float *get_cublas_device_zero(); | ||||||
|  | float *get_user_alpha_ptr(); | ||||||
|  |  | ||||||
|  | } // namespace at::cuda::detail | ||||||
| @ -109,7 +109,8 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> { | |||||||
|           params->c_scale_ptr, |           params->c_scale_ptr, | ||||||
|           params->ldc, |           params->ldc, | ||||||
|           params->c_dtype, |           params->c_dtype, | ||||||
|           params->use_fast_accum); |           params->use_fast_accum, | ||||||
|  |           std::nullopt /* alpha */); | ||||||
|       return OK; |       return OK; | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  | |||||||
| @ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({ | |||||||
|   DispatchKey::CUDA, |   DispatchKey::CUDA, | ||||||
|   DispatchKey::CPU, |   DispatchKey::CPU, | ||||||
|   DispatchKey::PrivateUse1, |   DispatchKey::PrivateUse1, | ||||||
|  |   DispatchKey::SparseCPU, | ||||||
|  |   DispatchKey::SparseCUDA, | ||||||
|  |   DispatchKey::SparseCsrCPU, | ||||||
|  |   DispatchKey::SparseCsrCUDA, | ||||||
| }); | }); | ||||||
|  |  | ||||||
| inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { | inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { | ||||||
|  | |||||||
| @ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) ( | |||||||
|  |  | ||||||
| namespace at::native { | namespace at::native { | ||||||
|  |  | ||||||
| static const double SELU_ALPHA = 1.6732632423543772848170429916717; | static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717; | ||||||
| static const double SELU_SCALE = 1.0507009873554804934193349852946; | static constexpr double SELU_SCALE = 1.0507009873554804934193349852946; | ||||||
|  |  | ||||||
| DEFINE_DISPATCH(elu_stub); | DEFINE_DISPATCH(elu_stub); | ||||||
| DEFINE_DISPATCH(elu_backward_stub); | DEFINE_DISPATCH(elu_backward_stub); | ||||||
|  | |||||||
| @ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in | |||||||
| #if AT_BUILD_WITH_BLAS() | #if AT_BUILD_WITH_BLAS() | ||||||
| template <> | template <> | ||||||
| bool scal_use_fast_path<double>(int64_t n, int64_t incx) { | bool scal_use_fast_path<double>(int64_t n, int64_t incx) { | ||||||
|   auto intmax = std::numeric_limits<int>::max(); |   auto constexpr intmax = std::numeric_limits<int>::max(); | ||||||
|   return n <= intmax && incx <= intmax; |   return n <= intmax && incx <= intmax; | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -315,7 +315,7 @@ bool gemv_use_fast_path<float>( | |||||||
|     int64_t incx, |     int64_t incx, | ||||||
|     [[maybe_unused]] float beta, |     [[maybe_unused]] float beta, | ||||||
|     int64_t incy) { |     int64_t incy) { | ||||||
|   auto intmax = std::numeric_limits<int>::max(); |   auto constexpr intmax = std::numeric_limits<int>::max(); | ||||||
|   return (m <= intmax) && (n <= intmax) && (lda <= intmax) && |   return (m <= intmax) && (n <= intmax) && (lda <= intmax) && | ||||||
|          (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax); |          (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax); | ||||||
| } | } | ||||||
|  | |||||||
| @ -658,6 +658,7 @@ static void check_shape_forward(const at::Tensor& input, | |||||||
|   TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported"); |   TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported"); | ||||||
|   TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported"); |   TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported"); | ||||||
|   TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero"); |   TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero"); | ||||||
|  |   TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups); | ||||||
|  |  | ||||||
|   TORCH_CHECK(weight_dim == k, |   TORCH_CHECK(weight_dim == k, | ||||||
|            "Expected ", weight_dim, "-dimensional input for ", weight_dim, |            "Expected ", weight_dim, "-dimensional input for ", weight_dim, | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| #pragma once | #pragma once | ||||||
|  |  | ||||||
|  | #include <array> | ||||||
| #include <ATen/native/Math.h> | #include <ATen/native/Math.h> | ||||||
| #include <c10/macros/Macros.h> | #include <c10/macros/Macros.h> | ||||||
| #include <c10/util/MathConstants.h> | #include <c10/util/MathConstants.h> | ||||||
| @ -127,7 +128,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor | |||||||
|  |  | ||||||
| template<typename scalar_t> | template<typename scalar_t> | ||||||
| C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) { | C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) { | ||||||
|   const static scalar_t kTailValues[] = { |   constexpr static scalar_t kTailValues[] = { | ||||||
|     0.0810614667953272, |     0.0810614667953272, | ||||||
|     0.0413406959554092, |     0.0413406959554092, | ||||||
|     0.0276779256849983, |     0.0276779256849983, | ||||||
| @ -139,7 +140,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) { | |||||||
|     0.00925546218271273, |     0.00925546218271273, | ||||||
|     0.00833056343336287 |     0.00833056343336287 | ||||||
|   }; |   }; | ||||||
|   if (k <= 9) { |   if (k < std::size(kTailValues)) { | ||||||
|     return kTailValues[static_cast<size_t>(k)]; |     return kTailValues[static_cast<size_t>(k)]; | ||||||
|   } |   } | ||||||
|   scalar_t kp1sq = (k + 1) * (k + 1); |   scalar_t kp1sq = (k + 1) * (k + 1); | ||||||
|  | |||||||
| @ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, | |||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| static scalar_t lanczos_sum_expg_scaled(scalar_t x) { | static scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||||
|   // lanczos approximation |   // lanczos approximation | ||||||
|   static const scalar_t lanczos_sum_expg_scaled_num[13] = { |   static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = { | ||||||
|     0.006061842346248906525783753964555936883222, |     0.006061842346248906525783753964555936883222, | ||||||
|     0.5098416655656676188125178644804694509993, |     0.5098416655656676188125178644804694509993, | ||||||
|     19.51992788247617482847860966235652136208, |     19.51992788247617482847860966235652136208, | ||||||
| @ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) { | |||||||
|     103794043.1163445451906271053616070238554, |     103794043.1163445451906271053616070238554, | ||||||
|     56906521.91347156388090791033559122686859 |     56906521.91347156388090791033559122686859 | ||||||
|   }; |   }; | ||||||
|   static const scalar_t lanczos_sum_expg_scaled_denom[13] = { |   static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||||
|     1., |     1., | ||||||
|     66., |     66., | ||||||
|     1925., |     1925., | ||||||
| @ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { | |||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { | static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { | ||||||
|   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] |   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] | ||||||
|   static const scalar_t d[25][25] = |   static constexpr scalar_t d[25][25] = | ||||||
|     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, |     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, | ||||||
|       1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, |       1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, | ||||||
|       3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, |       3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, | ||||||
|  | |||||||
| @ -62,7 +62,7 @@ | |||||||
| #include <utility> | #include <utility> | ||||||
| #include <vector> | #include <vector> | ||||||
|  |  | ||||||
| static const int MIOPEN_DIM_MAX = 5; | static constexpr int MIOPEN_DIM_MAX = 5; | ||||||
|  |  | ||||||
| namespace at::meta { | namespace at::meta { | ||||||
|  |  | ||||||
|  | |||||||
| @ -1906,11 +1906,9 @@ Tensor& index_fill_( | |||||||
|         "This also applies to advanced indexing e.g. tensor[mask] = scalar"); |         "This also applies to advanced indexing e.g. tensor[mask] = scalar"); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (!self.is_complex() && source.isComplex()) { |   TORCH_CHECK( | ||||||
|     TORCH_CHECK( |       self.is_complex() || !source.isComplex(), | ||||||
|         false, |       "index_fill_(): Converting complex Scalar to non-complex type is not supported"); | ||||||
|         "index_fill_(): Converting complex Scalar to non-complex type is not supported"); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Handle the case when `self` is 0-dim |   // Handle the case when `self` is 0-dim | ||||||
|   Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self; |   Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self; | ||||||
|  | |||||||
| @ -77,7 +77,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { | |||||||
|   // next broadcast all index tensors together |   // next broadcast all index tensors together | ||||||
|   try { |   try { | ||||||
|     indices = expand_outplace(indices); |     indices = expand_outplace(indices); | ||||||
|   } catch (std::exception& e) { |   } catch (std::exception&) { | ||||||
|     TORCH_CHECK_INDEX( |     TORCH_CHECK_INDEX( | ||||||
|         false, |         false, | ||||||
|         "shape mismatch: indexing tensors could not be broadcast together" |         "shape mismatch: indexing tensors could not be broadcast together" | ||||||
|  | |||||||
| @ -3079,7 +3079,9 @@ Tensor slice( | |||||||
|   } |   } | ||||||
|   auto storage_offset = self.storage_offset() + start_val * strides[dim]; |   auto storage_offset = self.storage_offset() + start_val * strides[dim]; | ||||||
|   auto len = end_val - start_val; |   auto len = end_val - start_val; | ||||||
|   sizes[dim] = (len + step - 1) / step; // round-up |   // NB: (len + step - 1) / step is equivalent implementation, | ||||||
|  |   // but len + step could overflow if step or len is very large | ||||||
|  |   sizes[dim] = (len == 0) ? 0 : (1 + (len - 1) / step); // safely round up | ||||||
|   strides[dim] *= step; |   strides[dim] *= step; | ||||||
|  |  | ||||||
|   Tensor result; |   Tensor result; | ||||||
|  | |||||||
| @ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel( | |||||||
|   } else if (dtype == ScalarType::Half) { |   } else if (dtype == ScalarType::Half) { | ||||||
|     [&]() { |     [&]() { | ||||||
|       using scalar_t = |       using scalar_t = | ||||||
|           c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>; |           decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t); | ||||||
|       const auto exp = exp_scalar.to<scalar_t>(); |       const auto exp = exp_scalar.to<scalar_t>(); | ||||||
|       using Vec = Vectorized<scalar_t>; |       using Vec = Vectorized<scalar_t>; | ||||||
|       cpu_kernel_vec(iter, |       cpu_kernel_vec(iter, | ||||||
|  | |||||||
| @ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase { | |||||||
|   // We keep this structure for BC and consider as deprecated. |   // We keep this structure for BC and consider as deprecated. | ||||||
|   // See HelperInterpNearestExact as replacement |   // See HelperInterpNearestExact as replacement | ||||||
|  |  | ||||||
|   static const int interp_size = 1; |   static constexpr int interp_size = 1; | ||||||
|  |  | ||||||
|   static inline void init_indices_weights( |   static inline void init_indices_weights( | ||||||
|     at::ScalarType output_type, |     at::ScalarType output_type, | ||||||
| @ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest { | |||||||
|  |  | ||||||
| struct HelperInterpLinear : public HelperInterpBase { | struct HelperInterpLinear : public HelperInterpBase { | ||||||
|  |  | ||||||
|   static const int interp_size = 2; |   static constexpr int interp_size = 2; | ||||||
|  |  | ||||||
|   // Compute indices and weights for each interpolated dimension |   // Compute indices and weights for each interpolated dimension | ||||||
|   // indices_weights = { |   // indices_weights = { | ||||||
| @ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase { | |||||||
|  |  | ||||||
| struct HelperInterpCubic : public HelperInterpBase { | struct HelperInterpCubic : public HelperInterpBase { | ||||||
|  |  | ||||||
|   static const int interp_size = 4; |   static constexpr int interp_size = 4; | ||||||
|  |  | ||||||
|   // Compute indices and weights for each interpolated dimension |   // Compute indices and weights for each interpolated dimension | ||||||
|   // indices_weights = { |   // indices_weights = { | ||||||
|  | |||||||
| @ -1359,7 +1359,8 @@ _scaled_gemm( | |||||||
|           const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, |           const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, | ||||||
|           const std::optional<Tensor>& bias, |           const std::optional<Tensor>& bias, | ||||||
|           const bool use_fast_accum, |           const bool use_fast_accum, | ||||||
|           Tensor& out) { |           Tensor& out, | ||||||
|  |           const std::optional<Tensor>& alpha = std::nullopt) { | ||||||
|   cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b); |   cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b); | ||||||
|   const auto out_dtype_ = args.result->scalar_type(); |   const auto out_dtype_ = args.result->scalar_type(); | ||||||
|   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); |   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); | ||||||
| @ -1410,7 +1411,8 @@ _scaled_gemm( | |||||||
|           args.scale_result_ptr, |           args.scale_result_ptr, | ||||||
|           args.result_ld, |           args.result_ld, | ||||||
|           out_dtype_, |           out_dtype_, | ||||||
|           use_fast_accum); |           use_fast_accum, | ||||||
|  |           alpha); | ||||||
|       return out; |       return out; | ||||||
|   } |   } | ||||||
| } | } | ||||||
| @ -1759,6 +1761,7 @@ enum class ScaledGemmImplementation { | |||||||
|   MXFP8_MXFP8 = 6, |   MXFP8_MXFP8 = 6, | ||||||
|   NVFP4_NVFP4 = 7, |   NVFP4_NVFP4 = 7, | ||||||
|   NVFP4_NVFP4_SINGLE_SCALE = 8, |   NVFP4_NVFP4_SINGLE_SCALE = 8, | ||||||
|  |   MXFP4_MXFP4 = 9, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| /** | /** | ||||||
| @ -1955,10 +1958,39 @@ bool check_mxfp8_recipe(c10::ScalarType type_a, | |||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Both inputs must be fp4 | ||||||
|  |  * A, B must have 1 scale each, {Blockwise_1x32, e8m0} | ||||||
|  |  */ | ||||||
|  | bool check_mxfp4_recipe(c10::ScalarType type_a, | ||||||
|  |                         std::vector<ScalingType>& recipe_a, | ||||||
|  |                         ArrayRef<Tensor>& scales_a, | ||||||
|  |                         c10::ScalarType type_b, | ||||||
|  |                         std::vector<ScalingType>& recipe_b, | ||||||
|  |                         ArrayRef<Tensor>& scales_b) { | ||||||
|  |   // both types must be fp4 | ||||||
|  |   if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // 1 scales, 1 recipes for each input | ||||||
|  |   if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Need {Blockwise_1x32, e8m0} for A & B | ||||||
|  |   if (recipe_a[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |   if (recipe_b[0] != ScalingType::BlockWise1x32) return false; | ||||||
|  |   if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
| using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>; | using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>; | ||||||
| using namespace std::placeholders; | using namespace std::placeholders; | ||||||
|  |  | ||||||
| std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> scale_kernel_dispatch = {{ | std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 9> scale_kernel_dispatch = {{ | ||||||
|   { "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE }, |   { "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE }, | ||||||
|   { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE}, |   { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE}, | ||||||
|   { "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6), |   { "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6), | ||||||
| @ -1969,7 +2001,8 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> | |||||||
|     ScaledGemmImplementation::BLOCK_1x128_1x128}, |     ScaledGemmImplementation::BLOCK_1x128_1x128}, | ||||||
|   { "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}, |   { "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}, | ||||||
|   { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE }, |   { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE }, | ||||||
|   { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}}; |   { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}, | ||||||
|  |   { "mxfp4_mxfp4", check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}}}; | ||||||
|  |  | ||||||
| Tensor& | Tensor& | ||||||
| _scaled_tensorwise_tensorwise( | _scaled_tensorwise_tensorwise( | ||||||
| @ -2187,15 +2220,22 @@ _scaled_mxfp8_mxfp8( | |||||||
|   TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ", |   TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ", | ||||||
|       mat_a.scalar_type(), mat_b.scalar_type()); |       mat_a.scalar_type(), mat_b.scalar_type()); | ||||||
|  |  | ||||||
|  | #ifdef USE_ROCM | ||||||
|  |   auto scale_a_elems = ceil_div<int64_t>(mat_a.size(0), 32) * mat_a.size(1); | ||||||
|  |   auto scale_b_elems = ceil_div<int64_t>(mat_b.size(1), 32) * mat_b.size(0); | ||||||
|  | #else | ||||||
|   auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1), 32), 4); |   auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1), 32), 4); | ||||||
|   auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0), 32), 4); |   auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0), 32), 4); | ||||||
|  | #endif | ||||||
|   TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(), |   TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(), | ||||||
|          "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel()); |          "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel()); | ||||||
|   TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(), |   TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(), | ||||||
|          "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel()); |          "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel()); | ||||||
|  |  | ||||||
|  | #ifndef USE_ROCM | ||||||
|   TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format"); |   TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format"); | ||||||
|   TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format"); |   TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format"); | ||||||
|  | #endif | ||||||
|  |  | ||||||
|   TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), |   TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), | ||||||
|         "For Blockwise scaling both scales should be contiguous"); |         "For Blockwise scaling both scales should be contiguous"); | ||||||
| @ -2225,6 +2265,56 @@ _scaled_mxfp8_mxfp8( | |||||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); |   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | Tensor& | ||||||
|  | _scaled_mxfp4_mxfp4( | ||||||
|  |           const Tensor& mat_a, const Tensor& mat_b, | ||||||
|  |           const Tensor& scale_a, const SwizzleType swizzle_a, | ||||||
|  |           const Tensor& scale_b, const SwizzleType swizzle_b, | ||||||
|  |           const std::optional<Tensor>& bias, | ||||||
|  |           const c10::ScalarType out_dtype, | ||||||
|  |           Tensor& out) { | ||||||
|  | #ifndef USE_ROCM | ||||||
|  |   TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only"); | ||||||
|  | #endif | ||||||
|  |   // Restrictions: | ||||||
|  |   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 | ||||||
|  |   TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ", | ||||||
|  |       mat_a.scalar_type(), mat_b.scalar_type()); | ||||||
|  |  | ||||||
|  |   auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1); | ||||||
|  |   auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0); | ||||||
|  |   TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(), | ||||||
|  |          "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel()); | ||||||
|  |   TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(), | ||||||
|  |          "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel()); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), | ||||||
|  |         "For Blockwise scaling both scales should be contiguous"); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype); | ||||||
|  |  | ||||||
|  |   auto scaling_choice_a = ScalingType::BlockWise1x32; | ||||||
|  |   auto scaling_choice_b = ScalingType::BlockWise1x32; | ||||||
|  |  | ||||||
|  | #if ROCM_VERSION >= 70000 | ||||||
|  |   TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), | ||||||
|  |               "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 && | ||||||
|  |               mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0, | ||||||
|  |               "Matrix dimensions must be multiples of 32 for block-wise scaling"); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 || | ||||||
|  |               out.scalar_type() == ScalarType::Half, | ||||||
|  |               "Block-wise scaling only supports BFloat16 or Half output types"); | ||||||
|  | #else | ||||||
|  |     TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||||
|  | } | ||||||
|  |  | ||||||
| Tensor& | Tensor& | ||||||
| _scaled_nvfp4_nvfp4( | _scaled_nvfp4_nvfp4( | ||||||
|           const Tensor& mat_a, const Tensor& mat_b, |           const Tensor& mat_a, const Tensor& mat_b, | ||||||
| @ -2232,12 +2322,23 @@ _scaled_nvfp4_nvfp4( | |||||||
|           const Tensor& scale_b, const SwizzleType swizzle_b, |           const Tensor& scale_b, const SwizzleType swizzle_b, | ||||||
|           const std::optional<Tensor>& bias, |           const std::optional<Tensor>& bias, | ||||||
|           const c10::ScalarType out_dtype, |           const c10::ScalarType out_dtype, | ||||||
|           const bool single_scale, |           Tensor& out, | ||||||
|           Tensor& out) { |           const std::optional<Tensor>& global_scale_a = std::nullopt, | ||||||
|  |           const std::optional<Tensor>& global_scale_b = std::nullopt) { | ||||||
| #ifdef USE_ROCM | #ifdef USE_ROCM | ||||||
|   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM"); |   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM"); | ||||||
| #endif | #endif | ||||||
|   TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported"); |   std::optional<Tensor> alpha = std::nullopt; | ||||||
|  |   // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise, | ||||||
|  |   //       if this is "And" we would silently do nothing in the case where one global scale is | ||||||
|  |   //       passed and not the other. | ||||||
|  |   if (global_scale_a.has_value() || global_scale_b.has_value()) { | ||||||
|  |     TORCH_CHECK_VALUE(global_scale_a.has_value(), | ||||||
|  |         "For two-level-scaled NVFP4, global_scale_a must have a value"); | ||||||
|  |     TORCH_CHECK_VALUE(global_scale_b.has_value(), | ||||||
|  |         "For two-level-scaled NVFP4, global_scale_b must have a value"); | ||||||
|  |     alpha = global_scale_a.value().mul(global_scale_b.value()); | ||||||
|  |   } | ||||||
|   // Restrictions: |   // Restrictions: | ||||||
|   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 |   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 | ||||||
|   // Scales must be swizzled |   // Scales must be swizzled | ||||||
| @ -2259,7 +2360,7 @@ _scaled_nvfp4_nvfp4( | |||||||
|  |  | ||||||
|   auto scaling_choice_a = ScalingType::BlockWise1x16; |   auto scaling_choice_a = ScalingType::BlockWise1x16; | ||||||
|   auto scaling_choice_b = ScalingType::BlockWise1x16; |   auto scaling_choice_b = ScalingType::BlockWise1x16; | ||||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); |   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha); | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -2465,9 +2566,12 @@ _scaled_mm_cuda_v2_out( | |||||||
|   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) { |   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) { | ||||||
|     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); |     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||||
|   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) { |   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) { | ||||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported"); |     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out, | ||||||
|  |                                scale_a[1], scale_b[1]); | ||||||
|   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) { |   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) { | ||||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out); |     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||||
|  |   } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) { | ||||||
|  |     return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||||
|   } else { |   } else { | ||||||
|     TORCH_CHECK_VALUE(false, "Invalid state - found an implementation, but not really"); |     TORCH_CHECK_VALUE(false, "Invalid state - found an implementation, but not really"); | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -856,13 +856,9 @@ struct type_specialized_kernel_launcher { | |||||||
|       out_calc_t output_offset_calculator, |       out_calc_t output_offset_calculator, | ||||||
|       loader_t loader, |       loader_t loader, | ||||||
|       storer_t storer) { |       storer_t storer) { | ||||||
|     constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0]; |     if (ret_t == rt_binary_specializations[arg_index][0] && | ||||||
|     constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1]; |         arg0_t == rt_binary_specializations[arg_index][1] && | ||||||
|     constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2]; |         arg1_t == rt_binary_specializations[arg_index][2]) | ||||||
|     if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) { |  | ||||||
|       using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>; |  | ||||||
|       using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>; |  | ||||||
|       using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>; |  | ||||||
|       launch_vectorized_templated_kernel< |       launch_vectorized_templated_kernel< | ||||||
|           func_t, |           func_t, | ||||||
|           array_t, |           array_t, | ||||||
| @ -870,9 +866,12 @@ struct type_specialized_kernel_launcher { | |||||||
|           out_calc_t, |           out_calc_t, | ||||||
|           loader_t, |           loader_t, | ||||||
|           storer_t, |           storer_t, | ||||||
|           cret_t, |           decltype(c10::impl::ScalarTypeToCPPType< | ||||||
|           carg0_t, |                    rt_binary_specializations[arg_index][0]>::t), | ||||||
|           carg1_t>( |           decltype(c10::impl::ScalarTypeToCPPType< | ||||||
|  |                    rt_binary_specializations[arg_index][1]>::t), | ||||||
|  |           decltype(c10::impl::ScalarTypeToCPPType< | ||||||
|  |                    rt_binary_specializations[arg_index][2]>::t)>( | ||||||
|           numel, |           numel, | ||||||
|           f, |           f, | ||||||
|           data, |           data, | ||||||
| @ -880,7 +879,6 @@ struct type_specialized_kernel_launcher { | |||||||
|           output_offset_calculator, |           output_offset_calculator, | ||||||
|           loader, |           loader, | ||||||
|           storer); |           storer); | ||||||
|     } |  | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | |||||||
| @ -38,12 +38,41 @@ __device__ inline int min(int a, int b) { | |||||||
| #define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched | #define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) { | template <typename index_t> | ||||||
|   return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1; | static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) { | ||||||
|  |   const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1); | ||||||
|  |   return (size + pad < kernel_extent) ? index_t(0) : (size + pad - kernel_extent) / stride + 1; | ||||||
| } | } | ||||||
|  |  | ||||||
| static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) { | template <typename index_t> | ||||||
|   return min((size + pad) / stride + 1, pooled_size); | static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) { | ||||||
|  |   return std::min((size + pad) / stride + 1, pooled_size); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static inline bool can_use_int32_nhwc( | ||||||
|  |     int64_t nbatch, int64_t channels, | ||||||
|  |     int64_t height, int64_t width, | ||||||
|  |     int64_t pooled_height, int64_t pooled_width, | ||||||
|  |     int64_t in_stride_n, int64_t in_stride_c, | ||||||
|  |     int64_t in_stride_h, int64_t in_stride_w) | ||||||
|  | { | ||||||
|  |   constexpr int64_t int_max = std::numeric_limits<int>::max(); | ||||||
|  |  | ||||||
|  |   int64_t max_intra_batch = | ||||||
|  |       (height ? (height - 1) * in_stride_h : 0) + | ||||||
|  |       (width ? (width - 1) * in_stride_w : 0) + | ||||||
|  |       (channels? (channels - 1) * in_stride_c : 0); | ||||||
|  |  | ||||||
|  |   int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch; | ||||||
|  |  | ||||||
|  |   if (max_input_offset > int_max) return false; | ||||||
|  |  | ||||||
|  |   int64_t out_batch_stride = pooled_height * pooled_width * channels; | ||||||
|  |   if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false; | ||||||
|  |  | ||||||
|  |   if (height * width > int_max) return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
| } | } | ||||||
|  |  | ||||||
| // kernels borrowed from Caffe | // kernels borrowed from Caffe | ||||||
| @ -85,21 +114,25 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename scalar_t> | template <typename scalar_t, typename index_t> | ||||||
| C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS) | C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS) | ||||||
| __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch, | __global__ void max_pool_forward_nhwc( | ||||||
|                                    const int64_t channels, const int64_t height, |     const scalar_t* bottom_data, | ||||||
|                                    const int64_t width, const int pooled_height, const int pooled_width, |     const int nbatch, | ||||||
|                                    const int kernel_h, const int kernel_w, const int stride_h, |     const index_t channels, const index_t height, const index_t width, | ||||||
|                                    const int stride_w, const int pad_h, const int pad_w, |     const index_t pooled_height, const index_t pooled_width, | ||||||
|                                    const int dilation_h, const int dilation_w, |     const int kernel_h, const int kernel_w, const int stride_h, | ||||||
|                                    const int in_stride_n, const int in_stride_c, |     const int stride_w, const int pad_h, const int pad_w, | ||||||
|                                    const int in_stride_h, const int in_stride_w, |     const int dilation_h, const int dilation_w, | ||||||
|                                    const int kernel_stride_C, const int kernel_size_C, |     const index_t in_stride_n, const index_t in_stride_c, | ||||||
|                                    scalar_t* top_data, int64_t* top_mask) { |     const index_t in_stride_h, const index_t in_stride_w, | ||||||
|   extern __shared__ int smem[]; |     const int kernel_stride_C, const int kernel_size_C, | ||||||
|   int *out_mask_cached = smem; |     scalar_t* top_data, int64_t* top_mask) { | ||||||
|   scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]); |  | ||||||
|  |   extern __shared__ unsigned char smem_raw[]; | ||||||
|  |   index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw); | ||||||
|  |   scalar_t *out_cached = reinterpret_cast<scalar_t*>( | ||||||
|  |       out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z); | ||||||
|  |  | ||||||
|   // flattening cta for pre-computation & smem initialization; |   // flattening cta for pre-computation & smem initialization; | ||||||
|   int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); |   int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); | ||||||
| @ -118,26 +151,26 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba | |||||||
|   int channel_id = blockIdx.x / nbatch; |   int channel_id = blockIdx.x / nbatch; | ||||||
|   int channel_offset = threadIdx.x + channel_id * blockDim.x; |   int channel_offset = threadIdx.x + channel_id * blockDim.x; | ||||||
|  |  | ||||||
|   top_data = top_data + batch_id * pooled_height * pooled_width * channels; |   top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels); | ||||||
|   top_mask = top_mask + batch_id * pooled_height * pooled_width * channels; |   top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels); | ||||||
|   bottom_data = bottom_data + batch_id * in_stride_n; |   bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n; | ||||||
|  |  | ||||||
|   out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x]; |   out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x; | ||||||
|   out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x]; |   out_mask_cached  += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x; | ||||||
|  |  | ||||||
|   int oH = (pooled_height + gridDim.z-1) / gridDim.z; |   int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z; | ||||||
|   int oW = (pooled_width + gridDim.y-1) / gridDim.y; |   int oW = (static_cast<int>(pooled_width)  + gridDim.y - 1) / gridDim.y; | ||||||
|   int ostartH = threadIdx.z + blockIdx.z*oH; |   int ostartH = threadIdx.z + blockIdx.z*oH; | ||||||
|   int oendH = ::min(ostartH+oH, pooled_height); |   int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height)); | ||||||
|   int ostartW = threadIdx.y + blockIdx.y*oW; |   int ostartW = threadIdx.y + blockIdx.y*oW; | ||||||
|   int oendW = ::min(ostartW+oW, pooled_width); |   int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width)); | ||||||
|  |  | ||||||
|   for (int oh = ostartH; oh < oendH; oh+=blockDim.z) { |   for (int oh = ostartH; oh < oendH; oh+=blockDim.z) { | ||||||
|     int hstart = oh * stride_h - pad_h; |     index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h; | ||||||
|     int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height); |     index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height); | ||||||
|     for (int ow = ostartW; ow < oendW; ow+=blockDim.y) { |     for (int ow = ostartW; ow < oendW; ow+=blockDim.y) { | ||||||
|       int wstart = ow * stride_w - pad_w; |       index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w; | ||||||
|       int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width); |       index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width); | ||||||
|       while(hstart < 0) |       while(hstart < 0) | ||||||
|         hstart += dilation_h; |         hstart += dilation_h; | ||||||
|       while(wstart < 0) |       while(wstart < 0) | ||||||
| @ -185,12 +218,12 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba | |||||||
|       // Else do it Non-Prefetch... |       // Else do it Non-Prefetch... | ||||||
|       else |       else | ||||||
| #endif | #endif | ||||||
|       for (int ih = hstart; ih < hend; ih += dilation_h) { |       for (index_t ih = hstart; ih < hend; ih += dilation_h) { | ||||||
|         for (int iw = wstart; iw < wend; iw += dilation_w) { |         for (index_t iw = wstart; iw < wend; iw += dilation_w) { | ||||||
|           int cached_index = threadIdx.x; |           int cached_index = threadIdx.x; | ||||||
|           const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w; |           const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w; | ||||||
|           for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { |           for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) { | ||||||
|             scalar_t val = ptr_input[c*in_stride_c]; |             scalar_t val = ptr_input[c * in_stride_c]; | ||||||
|             if ((val > out_cached[cached_index]) || at::_isnan(val)) { |             if ((val > out_cached[cached_index]) || at::_isnan(val)) { | ||||||
|               out_cached[cached_index] = val; |               out_cached[cached_index] = val; | ||||||
|               out_mask_cached[cached_index] = ih * width + iw; |               out_mask_cached[cached_index] = ih * width + iw; | ||||||
| @ -200,15 +233,15 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba | |||||||
|         } |         } | ||||||
|       } |       } | ||||||
|  |  | ||||||
|       scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels; |       scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels; | ||||||
|       int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels; |       int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels; | ||||||
|  |  | ||||||
|       int cached_index = threadIdx.x; |       int cached_index = threadIdx.x; | ||||||
|       for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { |       for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) { | ||||||
|         ptr_output_data[c] = out_cached[cached_index]; |         ptr_output_data[c] = out_cached[cached_index]; | ||||||
|         ptr_output_mask[c] = out_mask_cached[cached_index]; |         ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]); | ||||||
|         out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound(); |         out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound(); | ||||||
|         out_mask_cached[cached_index] = 0; |         out_mask_cached[cached_index] = index_t(0); | ||||||
|         cached_index += blockDim.x; |         cached_index += blockDim.x; | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
| @ -216,7 +249,7 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba | |||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| static const int BLOCK_THREADS = 256; | static constexpr int BLOCK_THREADS = 256; | ||||||
|  |  | ||||||
| template <typename scalar_t, typename accscalar_t> | template <typename scalar_t, typename accscalar_t> | ||||||
| #if defined (USE_ROCM) | #if defined (USE_ROCM) | ||||||
| @ -462,6 +495,11 @@ const Tensor& indices) { | |||||||
|               maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z)); |               maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z)); | ||||||
|           const dim3 block(block_x, block_y, block_z); |           const dim3 block(block_x, block_y, block_z); | ||||||
|  |  | ||||||
|  |           bool use_int32 = can_use_int32_nhwc( | ||||||
|  |               nbatch, nInputPlane, inputHeight, inputWidth, | ||||||
|  |               outputHeight, outputWidth, | ||||||
|  |               in_stride_n, in_stride_c, in_stride_h, in_stride_w); | ||||||
|  |  | ||||||
|           int kernel_stride_C = ceil_div( |           int kernel_stride_C = ceil_div( | ||||||
|               safe_downcast<int, int64_t>(nInputPlane), block_x * 4); |               safe_downcast<int, int64_t>(nInputPlane), block_x * 4); | ||||||
|           int kernel_size_C = ceil_div( |           int kernel_size_C = ceil_div( | ||||||
| @ -476,18 +514,41 @@ const Tensor& indices) { | |||||||
|               ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD)); |               ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD)); | ||||||
|           const dim3 grid(grid_x, grid_y, grid_z); |           const dim3 grid(grid_x, grid_y, grid_z); | ||||||
|  |  | ||||||
|           size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t)); |           size_t shmem_size; | ||||||
|           AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); |           size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z; | ||||||
|  |  | ||||||
|           max_pool_forward_nhwc<scalar_t> |           if (use_int32) { | ||||||
|           <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>( |             shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t)); | ||||||
|               input_data, nbatch, |             TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock, | ||||||
|                   nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, |                         "shared memory too small"); | ||||||
|                   kH, kW, dH, dW, padH, padW, dilationH, dilationW, |             max_pool_forward_nhwc<scalar_t, int32_t> | ||||||
|                   in_stride_n, in_stride_c, |               <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>( | ||||||
|                   in_stride_h, in_stride_w, |                 input_data, static_cast<int>(nbatch), | ||||||
|                   kernel_stride_C, kernel_size_C, |                 static_cast<int32_t>(nInputPlane), | ||||||
|                   output_data, indices_data); |                 static_cast<int32_t>(inputHeight), | ||||||
|  |                 static_cast<int32_t>(inputWidth), | ||||||
|  |                 static_cast<int32_t>(outputHeight), | ||||||
|  |                 static_cast<int32_t>(outputWidth), | ||||||
|  |                 kH, kW, dH, dW, padH, padW, dilationH, dilationW, | ||||||
|  |                 static_cast<int32_t>(in_stride_n), | ||||||
|  |                 static_cast<int32_t>(in_stride_c), | ||||||
|  |                 static_cast<int32_t>(in_stride_h), | ||||||
|  |                 static_cast<int32_t>(in_stride_w), | ||||||
|  |                 kernel_stride_C, kernel_size_C, | ||||||
|  |                 output_data, indices_data); | ||||||
|  |           } else { | ||||||
|  |             shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t)); | ||||||
|  |             TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock, | ||||||
|  |                         "shared memory too small"); | ||||||
|  |             max_pool_forward_nhwc<scalar_t, int64_t> | ||||||
|  |               <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>( | ||||||
|  |                 input_data, static_cast<int>(nbatch), | ||||||
|  |                 nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, | ||||||
|  |                 kH, kW, dH, dW, padH, padW, dilationH, dilationW, | ||||||
|  |                 in_stride_n, in_stride_c, in_stride_h, in_stride_w, | ||||||
|  |                 kernel_stride_C, kernel_size_C, | ||||||
|  |                 output_data, indices_data); | ||||||
|  |           } | ||||||
|           C10_CUDA_KERNEL_LAUNCH_CHECK(); |           C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
|           break; |           break; | ||||||
|         } |         } | ||||||
|  | |||||||
| @ -15,9 +15,7 @@ | |||||||
| #include <ATen/native/cuda/block_reduce.cuh> | #include <ATen/native/cuda/block_reduce.cuh> | ||||||
| #include <ATen/native/cuda/thread_constants.h> | #include <ATen/native/cuda/thread_constants.h> | ||||||
|  |  | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
| #include <thrust/iterator/reverse_iterator.h> | #include <thrust/iterator/reverse_iterator.h> | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #ifndef AT_PER_OPERATOR_HEADERS | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
| #include <ATen/Functions.h> | #include <ATen/Functions.h> | ||||||
| @ -36,9 +34,9 @@ namespace at::native { | |||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
| #if defined(USE_ROCM) | #if defined(USE_ROCM) | ||||||
| static const int BLOCKDIMY = 16; | static constexpr int BLOCKDIMY = 16; | ||||||
| #else | #else | ||||||
| static const int BLOCKDIMY = 32; | static constexpr int BLOCKDIMY = 32; | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| template | template | ||||||
| @ -240,10 +238,6 @@ __global__ void renorm_kernel( | |||||||
|  |  | ||||||
| } // anonymous namespace | } // anonymous namespace | ||||||
|  |  | ||||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
| template<typename index_t> |  | ||||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_, | Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_, | ||||||
|                                int64_t num_weights, int64_t padding_idx, |                                int64_t num_weights, int64_t padding_idx, | ||||||
| @ -306,7 +300,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice | |||||||
|  |  | ||||||
|   if (scale_grad_by_freq) { |   if (scale_grad_by_freq) { | ||||||
|     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { |     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { | ||||||
|       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |  | ||||||
| @ -333,11 +326,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice | |||||||
|         num_indices |         num_indices | ||||||
|       ); |       ); | ||||||
|     }); |     }); | ||||||
| #else |  | ||||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { |  | ||||||
|       embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); |  | ||||||
|     }); |  | ||||||
| #endif |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   return embedding_backward_cuda_kernel(grad, orig_indices, |   return embedding_backward_cuda_kernel(grad, orig_indices, | ||||||
|  | |||||||
| @ -10,9 +10,7 @@ | |||||||
|  |  | ||||||
| #include <c10/macros/Macros.h> | #include <c10/macros/Macros.h> | ||||||
|  |  | ||||||
| #if CUB_SUPPORTS_UNIQUE_BY_KEY() |  | ||||||
| #include <thrust/iterator/counting_iterator.h> | #include <thrust/iterator/counting_iterator.h> | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #ifndef AT_PER_OPERATOR_HEADERS | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
| #include <ATen/Functions.h> | #include <ATen/Functions.h> | ||||||
| @ -196,18 +194,9 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm | |||||||
|             partials_per_segment_offset[num_of_segments-1]; |             partials_per_segment_offset[num_of_segments-1]; | ||||||
| } | } | ||||||
|  |  | ||||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() |  | ||||||
| __global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) { |  | ||||||
|   *num_of_segments_ptr = num_of_segments; |  | ||||||
| } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| } // anon namespace | } // anon namespace | ||||||
|  |  | ||||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() |  | ||||||
| template<typename index_t> |  | ||||||
| int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets); |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| Tensor embedding_backward_cuda_kernel( | Tensor embedding_backward_cuda_kernel( | ||||||
|         const Tensor &grad, |         const Tensor &grad, | ||||||
| @ -234,20 +223,12 @@ Tensor embedding_backward_cuda_kernel( | |||||||
|   auto segment_offsets = at::empty({numel}, orig_indices.options()); |   auto segment_offsets = at::empty({numel}, orig_indices.options()); | ||||||
|   auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong)); |   auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong)); | ||||||
|   int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>(); |   int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>(); | ||||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() |  | ||||||
|   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { |  | ||||||
|     int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets); |  | ||||||
|     write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments); |  | ||||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); |  | ||||||
|   }); |  | ||||||
| #else |  | ||||||
|   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { |   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { | ||||||
|     cuda::cub::unique_by_key( |     cuda::cub::unique_by_key( | ||||||
|       sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0), |       sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0), | ||||||
|       segment_offsets.mutable_data_ptr<index_t>(), |       segment_offsets.mutable_data_ptr<index_t>(), | ||||||
|       num_of_segments_ptr, sorted_indices.numel()); |       num_of_segments_ptr, sorted_indices.numel()); | ||||||
|   }); |   }); | ||||||
| #endif |  | ||||||
|  |  | ||||||
|   int64_t max_segments = std::min<int64_t>(numel, num_weights); |   int64_t max_segments = std::min<int64_t>(numel, num_weights); | ||||||
|  |  | ||||||
|  | |||||||
| @ -31,16 +31,10 @@ | |||||||
|  |  | ||||||
| #include <c10/macros/Macros.h> | #include <c10/macros/Macros.h> | ||||||
|  |  | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
| #include <thrust/iterator/reverse_iterator.h> | #include <thrust/iterator/reverse_iterator.h> | ||||||
| #endif |  | ||||||
|  |  | ||||||
| namespace at::native { | namespace at::native { | ||||||
|  |  | ||||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
| template<typename index_t> |  | ||||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
| @ -199,7 +193,6 @@ Tensor embedding_bag_backward_cuda_sum_avg( | |||||||
|  |  | ||||||
|   if (scale_grad_by_freq) { |   if (scale_grad_by_freq) { | ||||||
|     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { |     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { | ||||||
|       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |  | ||||||
| @ -226,11 +219,6 @@ Tensor embedding_bag_backward_cuda_sum_avg( | |||||||
|         num_indices |         num_indices | ||||||
|       ); |       ); | ||||||
|     }); |     }); | ||||||
| #else |  | ||||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { |  | ||||||
|       embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); |  | ||||||
|     }); |  | ||||||
| #endif |  | ||||||
|   } |   } | ||||||
|   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, |   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, | ||||||
|       count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag, |       count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag, | ||||||
|  | |||||||
| @ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { | |||||||
|   // lanczos approximation |   // lanczos approximation | ||||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; |   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||||
|  |  | ||||||
|   static const accscalar_t lanczos_sum_expg_scaled_num[13] = { |   constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = { | ||||||
|     0.006061842346248906525783753964555936883222, |     0.006061842346248906525783753964555936883222, | ||||||
|     0.5098416655656676188125178644804694509993, |     0.5098416655656676188125178644804694509993, | ||||||
|     19.51992788247617482847860966235652136208, |     19.51992788247617482847860966235652136208, | ||||||
| @ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { | |||||||
|     103794043.1163445451906271053616070238554, |     103794043.1163445451906271053616070238554, | ||||||
|     56906521.91347156388090791033559122686859 |     56906521.91347156388090791033559122686859 | ||||||
|   }; |   }; | ||||||
|   static const accscalar_t lanczos_sum_expg_scaled_denom[13] = { |   constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||||
|     1., |     1., | ||||||
|     66., |     66., | ||||||
|     1925., |     1925., | ||||||
| @ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { | |||||||
|  |  | ||||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; |   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||||
|   accscalar_t ax, fac, res, num, numfac; |   accscalar_t ax, fac, res, num, numfac; | ||||||
|   static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ? |   constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ? | ||||||
|     7.09782712893383996843E2 : 88.72283905206835; |     7.09782712893383996843E2 : 88.72283905206835; | ||||||
|   static const accscalar_t EXP1 = 2.718281828459045; |   constexpr accscalar_t EXP1 = 2.718281828459045; | ||||||
|   static const accscalar_t lanczos_g = 6.024680040776729583740234375; |   constexpr accscalar_t lanczos_g = 6.024680040776729583740234375; | ||||||
|  |  | ||||||
|   if (::fabs(a - x) > 0.4 * ::fabs(a)) { |   if (::fabs(a - x) > 0.4 * ::fabs(a)) { | ||||||
|     ax = a * ::log(x) - x - ::lgamma(a); |     ax = a * ::log(x) - x - ::lgamma(a); | ||||||
| @ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) { | |||||||
|   // Compute igam using DLMF 8.11.4. [igam1] |   // Compute igam using DLMF 8.11.4. [igam1] | ||||||
|  |  | ||||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; |   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? |   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; |     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||||
|   static const int MAXITER = 2000; |   constexpr int MAXITER = 2000; | ||||||
|  |  | ||||||
|   int i; |   int i; | ||||||
|   accscalar_t ans, ax, c, r; |   accscalar_t ans, ax, c, r; | ||||||
| @ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { | |||||||
|   accscalar_t fac = 1; |   accscalar_t fac = 1; | ||||||
|   accscalar_t sum = 0; |   accscalar_t sum = 0; | ||||||
|   accscalar_t term, logx; |   accscalar_t term, logx; | ||||||
|   static const int MAXITER = 2000; |   constexpr int MAXITER = 2000; | ||||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? |   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; |     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||||
|  |  | ||||||
|   for (n = 1; n < MAXITER; n++) { |   for (n = 1; n < MAXITER; n++) { | ||||||
| @ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t | |||||||
|   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] |   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] | ||||||
|  |  | ||||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; |   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||||
|   static const accscalar_t d[25][25] = |   constexpr accscalar_t d[25][25] = | ||||||
|     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15}, |     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15}, | ||||||
|     {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15}, |     {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15}, | ||||||
|     {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15}, |     {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15}, | ||||||
| @ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t | |||||||
|  |  | ||||||
|   int k, n, sgn; |   int k, n, sgn; | ||||||
|   int maxpow = 0; |   int maxpow = 0; | ||||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? |   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; |     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||||
|   accscalar_t lambda = x / a; |   accscalar_t lambda = x / a; | ||||||
|   accscalar_t sigma = (x - a) / a; |   accscalar_t sigma = (x - a) / a; | ||||||
| @ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar | |||||||
|   int i; |   int i; | ||||||
|   accscalar_t ans, ax, c, yc, r, t, y, z; |   accscalar_t ans, ax, c, yc, r, t, y, z; | ||||||
|   accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; |   accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; | ||||||
|   static const int MAXITER = 2000; |   constexpr int MAXITER = 2000; | ||||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? |   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; |     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||||
|   static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ? |   constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ? | ||||||
|     4.503599627370496e15 : 16777216.; |     4.503599627370496e15 : 16777216.; | ||||||
|   static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ? |   constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ? | ||||||
|     2.22044604925031308085e-16 : 5.9604644775390625E-8; |     2.22044604925031308085e-16 : 5.9604644775390625E-8; | ||||||
|  |  | ||||||
|   ax = _igam_helper_fac(a, x); |   ax = _igam_helper_fac(a, x); | ||||||
| @ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { | |||||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; |   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||||
|   accscalar_t absxma_a; |   accscalar_t absxma_a; | ||||||
|  |  | ||||||
|   static const accscalar_t SMALL = 20.0; |   constexpr accscalar_t SMALL = 20.0; | ||||||
|   static const accscalar_t LARGE = 200.0; |   constexpr accscalar_t LARGE = 200.0; | ||||||
|   static const accscalar_t SMALLRATIO = 0.3; |   constexpr accscalar_t SMALLRATIO = 0.3; | ||||||
|   static const accscalar_t LARGERATIO = 4.5; |   constexpr accscalar_t LARGERATIO = 4.5; | ||||||
|  |  | ||||||
|   if ((x < 0) || (a < 0)) { |   if ((x < 0) || (a < 0)) { | ||||||
|     // out of defined-region of the function |     // out of defined-region of the function | ||||||
| @ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) { | |||||||
|  |  | ||||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; |   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||||
|   accscalar_t absxma_a; |   accscalar_t absxma_a; | ||||||
|   static const accscalar_t SMALL = 20.0; |   constexpr accscalar_t SMALL = 20.0; | ||||||
|   static const accscalar_t LARGE = 200.0; |   constexpr accscalar_t LARGE = 200.0; | ||||||
|   static const accscalar_t SMALLRATIO = 0.3; |   constexpr accscalar_t SMALLRATIO = 0.3; | ||||||
|   static const accscalar_t LARGERATIO = 4.5; |   constexpr accscalar_t LARGERATIO = 4.5; | ||||||
|  |  | ||||||
|   // boundary values following SciPy |   // boundary values following SciPy | ||||||
|   if ((x < 0) || (a < 0)) { |   if ((x < 0) || (a < 0)) { | ||||||
|  | |||||||
| @ -1,90 +0,0 @@ | |||||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |  | ||||||
| #include <ATen/core/Tensor.h> |  | ||||||
| #include <ATen/native/cuda/SortingCommon.cuh> |  | ||||||
| #include <ATen/cuda/cub_definitions.cuh> |  | ||||||
|  |  | ||||||
| #ifndef AT_PER_OPERATOR_HEADERS |  | ||||||
| #include <ATen/Functions.h> |  | ||||||
| #else |  | ||||||
| #include <ATen/ops/empty_like.h> |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #include <ATen/cuda/ThrustAllocator.h> |  | ||||||
| #include <thrust/device_ptr.h> |  | ||||||
| #include <thrust/execution_policy.h> |  | ||||||
| #include <thrust/sort.h> |  | ||||||
| #include <thrust/unique.h> |  | ||||||
| #include <thrust/device_ptr.h> |  | ||||||
| #include <thrust/iterator/constant_iterator.h> |  | ||||||
|  |  | ||||||
| namespace at::native { |  | ||||||
|  |  | ||||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
|  |  | ||||||
| template<typename index_t> |  | ||||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) { |  | ||||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |  | ||||||
|   at::cuda::ThrustAllocator allocator; |  | ||||||
|   auto policy = thrust::cuda::par(allocator).on(stream); |  | ||||||
|  |  | ||||||
|   auto num_indices = count.numel(); |  | ||||||
|  |  | ||||||
|   // Compute an increasing sequence per unique item in sortedIndices: |  | ||||||
|   // sorted: 2 5 5 5 7 7 8 9 9 |  | ||||||
|   //  count: 1 1 2 3 1 2 1 1 2 |  | ||||||
|   auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>()); |  | ||||||
|   auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>()); |  | ||||||
|   thrust::inclusive_scan_by_key( |  | ||||||
|     policy, |  | ||||||
|     sorted_data, |  | ||||||
|     sorted_data + num_indices, |  | ||||||
|     thrust::make_constant_iterator(1), |  | ||||||
|     count_data |  | ||||||
|   ); |  | ||||||
|  |  | ||||||
|   // Take the maximum of each count per unique key in reverse: |  | ||||||
|   // sorted: 2 5 5 5 7 7 8 9 9 |  | ||||||
|   //  count: 1 3 3 3 2 2 1 2 2 |  | ||||||
|   thrust::inclusive_scan_by_key( |  | ||||||
|     policy, |  | ||||||
|     thrust::make_reverse_iterator(sorted_data + num_indices), |  | ||||||
|     thrust::make_reverse_iterator(sorted_data), |  | ||||||
|     thrust::make_reverse_iterator(count_data + num_indices), |  | ||||||
|     thrust::make_reverse_iterator(count_data + num_indices), |  | ||||||
|     thrust::equal_to<index_t>(), |  | ||||||
|     thrust::maximum<index_t>() |  | ||||||
|   ); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template |  | ||||||
| void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count); |  | ||||||
| template |  | ||||||
| void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count); |  | ||||||
|  |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| template<typename index_t> |  | ||||||
| int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) { |  | ||||||
|   auto stream = at::cuda::getCurrentCUDAStream(); |  | ||||||
|   at::cuda::ThrustAllocator allocator; |  | ||||||
|   auto policy = thrust::cuda::par(allocator).on(stream); |  | ||||||
|   const ptrdiff_t numel = sorted_indices.numel(); |  | ||||||
|   auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>()); |  | ||||||
|   auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |  | ||||||
|   auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>()); |  | ||||||
|   auto ends = thrust::unique_by_key_copy( |  | ||||||
|           policy, |  | ||||||
|           sorted_indices_dev, |  | ||||||
|           sorted_indices_dev + numel, |  | ||||||
|           thrust::make_counting_iterator(0), |  | ||||||
|           dummy_dev, |  | ||||||
|           thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>())); |  | ||||||
|   return thrust::get<0>(ends) - dummy_dev; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template |  | ||||||
| int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets); |  | ||||||
| template |  | ||||||
| int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets); |  | ||||||
|  |  | ||||||
| } // namespace at::native |  | ||||||
| @ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify( | |||||||
| const auto digamma_string = jiterator_stringify( | const auto digamma_string = jiterator_stringify( | ||||||
|   template <typename T> |   template <typename T> | ||||||
|   T digamma(T x) { |   T digamma(T x) { | ||||||
|     static const double PI_f64 = 3.14159265358979323846; |     static constexpr double PI_f64 = 3.14159265358979323846; | ||||||
|  |  | ||||||
|     // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard |     // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard | ||||||
|     if (x == 0) { |     if (x == 0) { | ||||||
| @ -3072,9 +3072,9 @@ template <typename scalar_t> | |||||||
| static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { | static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { | ||||||
|   // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma |   // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma | ||||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; |   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||||
|   static const double PI_f64 = 3.14159265358979323846; |   static constexpr double PI_f64 = 3.14159265358979323846; | ||||||
|   const accscalar_t PSI_10 = 2.25175258906672110764; |   constexpr accscalar_t PSI_10 = 2.25175258906672110764; | ||||||
|   const accscalar_t A[] = { |   constexpr accscalar_t A[] = { | ||||||
|       8.33333333333333333333E-2, |       8.33333333333333333333E-2, | ||||||
|       -2.10927960927960927961E-2, |       -2.10927960927960927961E-2, | ||||||
|       7.57575757575757575758E-3, |       7.57575757575757575758E-3, | ||||||
|  | |||||||
| @ -146,6 +146,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel( | |||||||
|   int64_t batch_size = target.size(0); |   int64_t batch_size = target.size(0); | ||||||
|   int64_t H = target.size(1); |   int64_t H = target.size(1); | ||||||
|   int64_t W = target.size(2); |   int64_t W = target.size(2); | ||||||
|  |   int64_t n_classes = grad_input.size(1); | ||||||
|  |  | ||||||
|   CUDA_KERNEL_LOOP(index, n_threads) { |   CUDA_KERNEL_LOOP(index, n_threads) { | ||||||
|     const int64_t b = index % batch_size; |     const int64_t b = index % batch_size; | ||||||
| @ -156,6 +157,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel( | |||||||
|     if (cur_target == ignore_index) { |     if (cur_target == ignore_index) { | ||||||
|       continue; |       continue; | ||||||
|     } |     } | ||||||
|  |     CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes); | ||||||
|     scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1)); |     scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1)); | ||||||
|     grad_input[b][cur_target][h][w] = value * grad_output[b][h][w]; |     grad_input[b][cur_target][h][w] = value * grad_output[b][h][w]; | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -413,14 +413,12 @@ struct ReduceOp { | |||||||
|       value = thread_reduce<output_vec_size>(input_slice); |       value = thread_reduce<output_vec_size>(input_slice); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (config.should_block_y_reduce()) { |  | ||||||
|       value = block_y_reduce<output_vec_size>(value, shared_memory); |  | ||||||
|     } |  | ||||||
|     __syncthreads(); |  | ||||||
|     if (config.should_block_x_reduce()) { |     if (config.should_block_x_reduce()) { | ||||||
|       value = block_x_reduce<output_vec_size>(value, shared_memory); |       value = block_x_reduce<output_vec_size>(value, shared_memory); | ||||||
|     } |     } | ||||||
|  |     if (config.should_block_y_reduce()) { | ||||||
|  |       value = block_y_reduce<output_vec_size>(value, shared_memory); | ||||||
|  |     } | ||||||
|     using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>; |     using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>; | ||||||
|     using offset_vec_t = std::array<index_t, output_vec_size>; |     using offset_vec_t = std::array<index_t, output_vec_size>; | ||||||
|     offset_vec_t base_offsets; |     offset_vec_t base_offsets; | ||||||
| @ -657,8 +655,8 @@ struct ReduceOp { | |||||||
|     __syncthreads(); |     __syncthreads(); | ||||||
|     // Intra-warp reduction, fix CUDA to have offset decreasing for better numerics |     // Intra-warp reduction, fix CUDA to have offset decreasing for better numerics | ||||||
|     // matching Triton, etc. |     // matching Triton, etc. | ||||||
|     // todo for AMD |     // TODO(PaulZhang12): AMD and internal | ||||||
|     #ifdef USE_ROCM |     #if defined(USE_ROCM) || defined(FBCODE_CAFFE2) | ||||||
|     for (int offset = 1; offset < dim_x; offset <<= 1) { |     for (int offset = 1; offset < dim_x; offset <<= 1) { | ||||||
|     #else |     #else | ||||||
|     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { |     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { | ||||||
| @ -1097,11 +1095,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ | |||||||
|   // threads with different threadIdx.x are independent and will produce results for different outputs. |   // threads with different threadIdx.x are independent and will produce results for different outputs. | ||||||
|   // In such case, values in each loaded vector always correspond to different outputs. |   // In such case, values in each loaded vector always correspond to different outputs. | ||||||
|   if (fastest_moving_stride == sizeof(scalar_t)) { |   if (fastest_moving_stride == sizeof(scalar_t)) { | ||||||
| #ifdef USE_ROCM |  | ||||||
|     if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) { |     if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) { | ||||||
| #else |  | ||||||
|     if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) { |  | ||||||
| #endif |  | ||||||
|       // Case 1: "vectorize along input" |       // Case 1: "vectorize along input" | ||||||
|       // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case, |       // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case, | ||||||
|       // we should avoid vectorization. |       // we should avoid vectorization. | ||||||
|  | |||||||
| @ -39,9 +39,14 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta | |||||||
| template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> | template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> | ||||||
| void mean_kernel_impl(TensorIterator& iter) { | void mean_kernel_impl(TensorIterator& iter) { | ||||||
|   //  returns acc_t for all non-complex dtypes and returns T for c10::complex<T> |   //  returns acc_t for all non-complex dtypes and returns T for c10::complex<T> | ||||||
|  |   constexpr bool is_16_bits = sizeof(scalar_t) == 2; | ||||||
|   using factor_t = typename c10::scalar_value_type<acc_t>::type; |   using factor_t = typename c10::scalar_value_type<acc_t>::type; | ||||||
|   factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel(); |   factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel(); | ||||||
|   gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor}); |   if constexpr (is_16_bits) { | ||||||
|  |     gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor}); | ||||||
|  |   } else { | ||||||
|  |     gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor}); | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void mean_kernel_cuda(TensorIterator& iter) { | static void mean_kernel_cuda(TensorIterator& iter) { | ||||||
|  | |||||||
| @ -13,24 +13,19 @@ namespace at::native { | |||||||
| template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t> | template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t> | ||||||
| struct sum_functor { | struct sum_functor { | ||||||
|   void operator()(TensorIterator& iter) { |   void operator()(TensorIterator& iter) { | ||||||
| #ifdef USE_ROCM |     const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { | ||||||
|     // Half and BFloat16 can be packed in groups of up to 8 elements and |       return a + b; | ||||||
|     // can use *_DWORDX4 instructions to achieve that. |     }; | ||||||
|     const bool is_16_bits = |     constexpr bool is_16_bits = sizeof(scalar_t) == 2; | ||||||
|       ( (std::is_same<at::Half, scalar_t>::value) || |     if constexpr (is_16_bits) { | ||||||
|         (std::is_same<at::BFloat16, scalar_t>::value) ); |  | ||||||
|     if (is_16_bits) { |  | ||||||
|       gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>( |       gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>( | ||||||
|         iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { |         iter, func_wrapper<out_t>(sum_combine) | ||||||
|           return a + b; |       ); | ||||||
|         })); |     } else { | ||||||
|       return; |       gpu_reduce_kernel<scalar_t, out_t>( | ||||||
|  |         iter, func_wrapper<out_t>(sum_combine) | ||||||
|  |       ); | ||||||
|     } |     } | ||||||
| #endif |  | ||||||
|     gpu_reduce_kernel<scalar_t, out_t>( |  | ||||||
|         iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { |  | ||||||
|           return a + b; |  | ||||||
|         })); |  | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | |||||||
| @ -19,7 +19,6 @@ | |||||||
|  |  | ||||||
| namespace at::native { | namespace at::native { | ||||||
|  |  | ||||||
| // TODO: remove this when CUDA <11.6 is no longer supported |  | ||||||
| void topk_out_with_sort( | void topk_out_with_sort( | ||||||
|   const Tensor& self, |   const Tensor& self, | ||||||
|   int64_t k, int64_t dim, bool largest, |   int64_t k, int64_t dim, bool largest, | ||||||
| @ -31,21 +30,12 @@ void topk_out_with_sort( | |||||||
|   indices.copy_(sorted_indices.narrow(dim, 0, k)); |   indices.copy_(sorted_indices.narrow(dim, 0, k)); | ||||||
| } | } | ||||||
|  |  | ||||||
| // TODO: remove this when CUDA <11.6 is no longer supported |  | ||||||
| bool disable_sort_for_topk(); |  | ||||||
| bool should_use_sort(const Tensor& self, int64_t dim) { | bool should_use_sort(const Tensor& self, int64_t dim) { | ||||||
| #if defined(USE_ROCM) | #if defined(USE_ROCM) | ||||||
|   if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972 |   if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972 | ||||||
|   return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387 |   return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387 | ||||||
| #else | #else | ||||||
|   if (disable_sort_for_topk()) return false; |   return false; | ||||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632 |  | ||||||
|   if (self.dim() == 0) return false; |  | ||||||
|   if (self.dtype() == kBool) return false; // Bool is not support by topk |  | ||||||
|   int64_t slice_size = self.size(dim); |  | ||||||
|   if (slice_size == 0) return false; |  | ||||||
|   int64_t num_slices = self.numel() / slice_size; |  | ||||||
|   return num_slices <= 10 && slice_size >= 100000; |  | ||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
| @ -21,11 +21,6 @@ using namespace at::native; | |||||||
|  |  | ||||||
| namespace at::native { | namespace at::native { | ||||||
|  |  | ||||||
| // TODO: remove this when CUDA <11.6 is no longer supported |  | ||||||
| bool disable_sort_for_topk() { |  | ||||||
|   return CUB_SUPPORTS_SCAN_BY_KEY(); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| namespace sbtopk { // single_block_topk | namespace sbtopk { // single_block_topk | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| @ -418,10 +413,6 @@ __global__ void computeBlockwiseWithinKCounts( | |||||||
|   } |   } | ||||||
|   __syncthreads(); |   __syncthreads(); | ||||||
|  |  | ||||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
|   return; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
|   Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS); |   Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS); | ||||||
|  |  | ||||||
|   // if largest, then only threads that has tidx > desired_digit are active |   // if largest, then only threads that has tidx > desired_digit are active | ||||||
| @ -477,7 +468,6 @@ __global__ void computeBlockwiseWithinKCounts( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
| // Assumption: slice_size can not be larger than UINT32_MAX | // Assumption: slice_size can not be larger than UINT32_MAX | ||||||
| template <typename Bitwise> | template <typename Bitwise> | ||||||
| __global__ void computeBlockwiseKthCounts( | __global__ void computeBlockwiseKthCounts( | ||||||
| @ -609,7 +599,6 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> inpu | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
| #endif |  | ||||||
|  |  | ||||||
| int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) { | int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) { | ||||||
|   // occupancy of this kernel is limited by registers per threads |   // occupancy of this kernel is limited by registers per threads | ||||||
| @ -687,16 +676,12 @@ void launch( | |||||||
|   uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get()); |   uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get()); | ||||||
|   AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream)); |   AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream)); | ||||||
|  |  | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
|   auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); |   auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); | ||||||
|   uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get()); |   uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get()); | ||||||
|   AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream)); |   AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream)); | ||||||
|  |  | ||||||
|   auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); |   auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); | ||||||
|   uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get()); |   uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get()); | ||||||
| #else |  | ||||||
|   uint32_t* withinKCounts = nullptr; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
|   Bitwise desiredMask = 0; |   Bitwise desiredMask = 0; | ||||||
|   dim3 grid; |   dim3 grid; | ||||||
| @ -743,7 +728,6 @@ void launch( | |||||||
|   } |   } | ||||||
|   desired = desired_in; |   desired = desired_in; | ||||||
|  |  | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
|   computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>( |   computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>( | ||||||
|     desired, counts, num_blocks, blocks_per_slice, kthCounts); |     desired, counts, num_blocks, blocks_per_slice, kthCounts); | ||||||
|   C10_CUDA_KERNEL_LAUNCH_CHECK(); |   C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
| @ -759,28 +743,6 @@ void launch( | |||||||
|     topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread, |     topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread, | ||||||
|     blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks); |     blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks); | ||||||
|   C10_CUDA_KERNEL_LAUNCH_CHECK(); |   C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
| #else |  | ||||||
|   // Find topk values based on kth values |  | ||||||
|   { |  | ||||||
|     dim3 grid; |  | ||||||
|     TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk"); |  | ||||||
|     int warp_size = at::cuda::warp_size(); |  | ||||||
|     dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024)); |  | ||||||
|     sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>( |  | ||||||
|         input, |  | ||||||
|         inputSliceSize, |  | ||||||
|         outputSliceSize, |  | ||||||
|         largest, |  | ||||||
|         numInputSlices, |  | ||||||
|         inputWithinSliceStride, |  | ||||||
|         topK, |  | ||||||
|         topKWithinSliceStride, |  | ||||||
|         indices, |  | ||||||
|         indicesWithinSliceStride, |  | ||||||
|         kthValues); |  | ||||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); |  | ||||||
|   } |  | ||||||
| #endif |  | ||||||
| } | } | ||||||
|  |  | ||||||
| } // namespace mbtopk | } // namespace mbtopk | ||||||
| @ -788,7 +750,6 @@ void launch( | |||||||
| bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { | bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { | ||||||
|   if (num_slices > std::numeric_limits<uint32_t>::max() || |   if (num_slices > std::numeric_limits<uint32_t>::max() || | ||||||
|       slice_size > std::numeric_limits<uint32_t>::max()) return false; |       slice_size > std::numeric_limits<uint32_t>::max()) return false; | ||||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() |  | ||||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267 |   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267 | ||||||
|   return (num_slices <= 20 && slice_size >= 20000) || |   return (num_slices <= 20 && slice_size >= 20000) || | ||||||
|       (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) || |       (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) || | ||||||
| @ -797,12 +758,6 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { | |||||||
|       (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) || |       (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) || | ||||||
|       (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) || |       (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) || | ||||||
|       (num_slices > 4000 && slice_size >= 400); |       (num_slices > 4000 && slice_size >= 400); | ||||||
| #else |  | ||||||
|   // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081 |  | ||||||
|   return (num_slices <= 400 && slice_size >= 5000) || |  | ||||||
|       (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) || |  | ||||||
|       (num_slices >= 4000 && slice_size >= 300); |  | ||||||
| #endif |  | ||||||
| } | } | ||||||
|  |  | ||||||
| void launch_gather_topk_kernel( | void launch_gather_topk_kernel( | ||||||
|  | |||||||
| @ -44,7 +44,7 @@ __global__ void triu_tril_kernel( | |||||||
|     const int64_t k, |     const int64_t k, | ||||||
|     const int64_t N_padded, |     const int64_t N_padded, | ||||||
|     const IndexType last_dim_padded) { |     const IndexType last_dim_padded) { | ||||||
|   int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread; |   int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread; | ||||||
|   if (linear_idx >= N_padded) { |   if (linear_idx >= N_padded) { | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -277,7 +277,7 @@ struct BilinearFilterFunctor { | |||||||
|     return 0; |     return 0; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   static const int size = 2; |   static constexpr int size = 2; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| // taken from | // taken from | ||||||
| @ -301,7 +301,7 @@ struct BicubicFilterFunctor { | |||||||
|     return 0; |     return 0; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   static const int size = 4; |   static constexpr int size = 4; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <typename accscalar_t> | template <typename accscalar_t> | ||||||
|  | |||||||
| @ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| #ifdef USE_ROCM |  | ||||||
| // Helper function to compute output pixel range that can contribute to input pixel |  | ||||||
| template <typename accscalar_t> |  | ||||||
| __device__ __forceinline__ void compute_output_range( |  | ||||||
|     int input_pos, |  | ||||||
|     accscalar_t scale, |  | ||||||
|     int output_size, |  | ||||||
|     bool align_corners, |  | ||||||
|     int& min_output, |  | ||||||
|     int& max_output) { |  | ||||||
|   accscalar_t lo, hi; |  | ||||||
|   if (align_corners) { |  | ||||||
|       lo = static_cast<accscalar_t>(input_pos - 1) / scale; |  | ||||||
|       hi = static_cast<accscalar_t>(input_pos + 1) / scale; |  | ||||||
|   } else { |  | ||||||
|       lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5); |  | ||||||
|       hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5); |  | ||||||
|   } |  | ||||||
|   min_output = max(0, static_cast<int>(ceil(lo))); |  | ||||||
|   max_output = min(output_size - 1, static_cast<int>(floor(hi))); |  | ||||||
| } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // Backward (adjoint) operation 1 <- 2 (accumulates) | // Backward (adjoint) operation 1 <- 2 (accumulates) | ||||||
| template <typename scalar_t, typename accscalar_t> | template <typename scalar_t, typename accscalar_t> | ||||||
| C10_LAUNCH_BOUNDS_1(1024) | C10_LAUNCH_BOUNDS_1(1024) | ||||||
| @ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame( | |||||||
|     const bool align_corners, |     const bool align_corners, | ||||||
|     scalar_t* __restrict__ idata, |     scalar_t* __restrict__ idata, | ||||||
|     const scalar_t* __restrict__ odata) { |     const scalar_t* __restrict__ odata) { | ||||||
|   // In C++, integer multiplication, like in standard arithmetic, is generally commutative. |  | ||||||
|   const size_t i_numel = nc * width1 * height1; |  | ||||||
| #ifdef USE_ROCM |  | ||||||
|   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel; |  | ||||||
|        index += blockDim.x * gridDim.x) { |  | ||||||
|     // Decode input pixel coordinates |  | ||||||
|     size_t index_temp = index; |  | ||||||
|     const int w1 = index_temp % width1; |  | ||||||
|     index_temp /= width1; |  | ||||||
|     const int h1 = index_temp % height1; |  | ||||||
|     const size_t nc_idx = index_temp / height1; |  | ||||||
|  |  | ||||||
|     accscalar_t grad_sum = 0; |  | ||||||
|  |  | ||||||
|     // Find range of output pixels that could interpolate from this input pixel |  | ||||||
|     int h2_min, h2_max, w2_min, w2_max; |  | ||||||
|     compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max); |  | ||||||
|     compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max); |  | ||||||
|  |  | ||||||
|     // Iterate over potential output pixels |  | ||||||
|     for (int h2 = h2_min; h2 <= h2_max; h2++) { |  | ||||||
|       for (int w2 = w2_min; w2 <= w2_max; w2++) { |  | ||||||
|         // Compute source coordinates for this output pixel |  | ||||||
|         const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>( |  | ||||||
|             rheight, h2, align_corners, /*cubic=*/false); |  | ||||||
|         const int h1_base = (int)h1r; |  | ||||||
|         const int h1p = (h1_base < height1 - 1) ? 1 : 0; |  | ||||||
|         const accscalar_t h1lambda = h1r - h1_base; |  | ||||||
|         const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda; |  | ||||||
|  |  | ||||||
|         const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>( |  | ||||||
|             rwidth, w2, align_corners, /*cubic=*/false); |  | ||||||
|         const int w1_base = (int)w1r; |  | ||||||
|         const int w1p = (w1_base < width1 - 1) ? 1 : 0; |  | ||||||
|         const accscalar_t w1lambda = w1r - w1_base; |  | ||||||
|         const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda; |  | ||||||
|  |  | ||||||
|         // Check if our input pixel participates in this interpolation and accumulate all weights |  | ||||||
|         // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse |  | ||||||
|         // to the same pixel, so we need to accumulate weights from all matching positions |  | ||||||
|         accscalar_t weight = 0; |  | ||||||
|  |  | ||||||
|         // Check all four interpolation positions and accumulate weights |  | ||||||
|         if (h1 == h1_base && w1 == w1_base) { |  | ||||||
|           weight += h0lambda * w0lambda;  // top-left |  | ||||||
|         } |  | ||||||
|         if (h1 == h1_base && w1 == w1_base + w1p) { |  | ||||||
|           weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0) |  | ||||||
|         } |  | ||||||
|         if (h1 == h1_base + h1p && w1 == w1_base) { |  | ||||||
|           weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0) |  | ||||||
|         } |  | ||||||
|         if (h1 == h1_base + h1p && w1 == w1_base + w1p) { |  | ||||||
|           weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions) |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         if (weight > 0) { |  | ||||||
|           const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2; |  | ||||||
|           grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]); |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // Write accumulated gradient (no atomics needed) |  | ||||||
|     idata[index] = static_cast<scalar_t>(grad_sum); |  | ||||||
|   } |  | ||||||
| #else |  | ||||||
|   const size_t o_numel = nc * width2 * height2; |   const size_t o_numel = nc * width2 * height2; | ||||||
|  |   const size_t i_numel = nc * width1 * height1; | ||||||
|   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; |   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; | ||||||
|        index += blockDim.x * gridDim.x) { |        index += blockDim.x * gridDim.x) { | ||||||
|     size_t index_temp = index; |     size_t index_temp = index; | ||||||
| @ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame( | |||||||
|         static_cast<scalar_t>(h1lambda * w1lambda * d2val), |         static_cast<scalar_t>(h1lambda * w1lambda * d2val), | ||||||
|         true); |         true); | ||||||
|   } |   } | ||||||
| #endif |  | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename scalar_t, typename accscalar_t> | template <typename scalar_t, typename accscalar_t> | ||||||
| @ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template( | |||||||
|   // threads are not covering the whole input tensor. |   // threads are not covering the whole input tensor. | ||||||
|   grad_input.zero_(); |   grad_input.zero_(); | ||||||
|  |  | ||||||
|  |   const size_t num_kernels = nbatch * channels * output_height * output_width; | ||||||
|   const int num_threads = std::min( |   const int num_threads = std::min( | ||||||
|       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); |       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); | ||||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
| @ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( | |||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
| #ifdef USE_ROCM |  | ||||||
|   constexpr bool use_input = true; |  | ||||||
| #else |  | ||||||
|   constexpr bool use_input = false; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
|   AT_DISPATCH_FLOATING_TYPES_AND2( |   AT_DISPATCH_FLOATING_TYPES_AND2( | ||||||
|       at::ScalarType::Half, at::ScalarType::BFloat16, |       at::ScalarType::Half, at::ScalarType::BFloat16, | ||||||
|       grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] { |       grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] { | ||||||
| @ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( | |||||||
|       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( |       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( | ||||||
|           input_width, output_width, align_corners, scales_w); |           input_width, output_width, align_corners, scales_w); | ||||||
|  |  | ||||||
|       const size_t num_kernels = nbatch * channels * output_height * output_width; |  | ||||||
|  |  | ||||||
|       upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t> |       upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t> | ||||||
|           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>( |           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>( | ||||||
|               input_height, |               input_height, | ||||||
| @ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( | |||||||
|       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( |       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>( | ||||||
|           input_width, output_width, align_corners, scales_w); |           input_width, output_width, align_corners, scales_w); | ||||||
|  |  | ||||||
|       const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width); |  | ||||||
|  |  | ||||||
|       upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t> |       upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t> | ||||||
|           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), |           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), | ||||||
|              num_threads, |              num_threads, | ||||||
|  | |||||||
| @ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum( | |||||||
|   if constexpr (!rms_norm){ |   if constexpr (!rms_norm){ | ||||||
|     U delta = val - curr_sum.mean; |     U delta = val - curr_sum.mean; | ||||||
|     U new_count = curr_sum.count + 1.f; |     U new_count = curr_sum.count + 1.f; | ||||||
|  | #if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL) | ||||||
|  |     U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count); | ||||||
|  | #else | ||||||
|     U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster |     U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster | ||||||
|  | #endif | ||||||
|     return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; |     return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; | ||||||
|   } else{ |   } else{ | ||||||
|     return {0.f, curr_sum.sigma2 + val * val, 0}; |     return {0.f, curr_sum.sigma2 + val * val, 0}; | ||||||
| @ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine( | |||||||
|     U count = dataA.count + dataB.count; |     U count = dataA.count + dataB.count; | ||||||
|     U mean, sigma2; |     U mean, sigma2; | ||||||
|     if (count > decltype(dataB.count){0}) { |     if (count > decltype(dataB.count){0}) { | ||||||
|  | #if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL) | ||||||
|  |       auto coef = __builtin_amdgcn_rcpf(count); | ||||||
|  | #else | ||||||
|       auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division |       auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division | ||||||
|  | #endif | ||||||
|       auto nA = dataA.count * coef; |       auto nA = dataA.count * coef; | ||||||
|       auto nB = dataB.count * coef; |       auto nB = dataB.count * coef; | ||||||
|       mean = nA*dataA.mean + nB*dataB.mean; |       mean = nA*dataA.mean + nB*dataB.mean; | ||||||
|  | |||||||
| @ -466,7 +466,7 @@ struct ReduceJitOp { | |||||||
|  |  | ||||||
|     __syncthreads(); |     __syncthreads(); | ||||||
|  |  | ||||||
|     #ifdef USE_ROCM |     #if defined(USE_ROCM) || defined(FBCODE_CAFFE2) | ||||||
|     for (int offset = 1; offset < dim_x; offset <<= 1) { |     for (int offset = 1; offset < dim_x; offset <<= 1) { | ||||||
|     #else |     #else | ||||||
|     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { |     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { | ||||||
|  | |||||||
| @ -487,9 +487,7 @@ std::unique_ptr<fe::graph::Graph> build_graph( | |||||||
|   auto scaled_dot_product_flash_attention_options = |   auto scaled_dot_product_flash_attention_options = | ||||||
|       fe::graph::SDPA_attributes() |       fe::graph::SDPA_attributes() | ||||||
|           .set_name("CUDNN_SDPA") |           .set_name("CUDNN_SDPA") | ||||||
|           .set_is_inference(return_softmaxstats == false) |           .set_generate_stats(return_softmaxstats) | ||||||
|           // TODO(eqy): switch to this API once cuDNN FE is upgraded |  | ||||||
|           // .set_generate_stats(return_softmaxstats) |  | ||||||
|           .set_causal_mask(is_causal) |           .set_causal_mask(is_causal) | ||||||
|           .set_attn_scale(attn_scale); |           .set_attn_scale(attn_scale); | ||||||
|   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { |   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { | ||||||
| @ -707,9 +705,7 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor( | |||||||
|   auto scaled_dot_product_flash_attention_options = |   auto scaled_dot_product_flash_attention_options = | ||||||
|       fe::graph::SDPA_attributes() |       fe::graph::SDPA_attributes() | ||||||
|           .set_name("CUDNN_SDPA_NESTEDTENSOR") |           .set_name("CUDNN_SDPA_NESTEDTENSOR") | ||||||
|           .set_is_inference(return_softmaxstats == false) |           .set_generate_stats(return_softmaxstats) | ||||||
|           // TODO(eqy): switch to this API once cuDNN FE is upgraded |  | ||||||
|           // .set_generate_stats(return_softmaxstats) |  | ||||||
|           .set_causal_mask(is_causal) |           .set_causal_mask(is_causal) | ||||||
|           .set_attn_scale(attn_scale) |           .set_attn_scale(attn_scale) | ||||||
|           .set_seq_len_q(SEQ_LEN_Q_) |           .set_seq_len_q(SEQ_LEN_Q_) | ||||||
|  | |||||||
| @ -160,8 +160,12 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){ | |||||||
| } | } | ||||||
|  |  | ||||||
| static bool mkldnn_conv_enabled_fpmath_mode_tf32(){ | static bool mkldnn_conv_enabled_fpmath_mode_tf32(){ | ||||||
|   return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 && | #if defined(__x86_64__) || defined(_M_X64) | ||||||
|       cpuinfo_has_x86_amx_fp16(); |     return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 && | ||||||
|  |         cpuinfo_has_x86_amx_fp16(); | ||||||
|  | #else | ||||||
|  |     return false;   //TF32 not supported on power system | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) { | static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) { | ||||||
|  | |||||||
| @ -74,8 +74,12 @@ static bool use_mkldnn_bf32_linear() { | |||||||
| } | } | ||||||
|  |  | ||||||
| static bool use_mkldnn_tf32_linear() { | static bool use_mkldnn_tf32_linear() { | ||||||
|   return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && | #if defined(__x86_64__) || defined(_M_X64) | ||||||
|  |     return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && | ||||||
|       cpuinfo_has_x86_amx_fp16(); |       cpuinfo_has_x86_amx_fp16(); | ||||||
|  | #else | ||||||
|  |   return false;  // TF32 not supported on power system | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| Tensor mkldnn_linear( | Tensor mkldnn_linear( | ||||||
|  | |||||||
| @ -114,8 +114,13 @@ static bool use_mkldnn_bf32_matmul() { | |||||||
|   return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16; |   return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16; | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| static bool use_mkldnn_tf32_matmul() { | static bool use_mkldnn_tf32_matmul() { | ||||||
|   return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32; | #if defined(__x86_64__) || defined(_M_X64) | ||||||
|  |     return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32; | ||||||
|  | #else | ||||||
|  |     return false;  // TF32 not supported on power system | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| // returns an ideep::tensor | // returns an ideep::tensor | ||||||
| @ -411,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){ | |||||||
|   // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k) |   // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k) | ||||||
|   // else called from aten::mv, mat1.size = (m * n), mat2.size = (n) |   // else called from aten::mv, mat1.size = (m * n), mat2.size = (n) | ||||||
|   // only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel |   // only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel | ||||||
|   static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16; |   constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16; | ||||||
|   if (mat1.dim() == 1 && mat2.dim() == 1) { |   if (mat1.dim() == 1 && mat2.dim() == 1) { | ||||||
|     // aten::dot |     // aten::dot | ||||||
|     return mat1.size(0) > mkldnn_gemm_min_size; |     return mat1.size(0) > mkldnn_gemm_min_size; | ||||||
|  | |||||||
| @ -712,7 +712,7 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) { | |||||||
|   } else if (scalar.isBoolean()) { |   } else if (scalar.isBoolean()) { | ||||||
|     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool)); |     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool)); | ||||||
|   } else if (scalar.isComplex()) { |   } else if (scalar.isComplex()) { | ||||||
|     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexDouble)); |     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexFloat)); | ||||||
|   } else { |   } else { | ||||||
|     TORCH_INTERNAL_ASSERT(scalar.isIntegral(false)); |     TORCH_INTERNAL_ASSERT(scalar.isIntegral(false)); | ||||||
|     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong)); |     tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong)); | ||||||
|  | |||||||
| @ -441,7 +441,7 @@ kernel void applySYRK( | |||||||
|     uint3 tid [[thread_position_in_threadgroup]], |     uint3 tid [[thread_position_in_threadgroup]], | ||||||
|     uint3 tgid [[threadgroup_position_in_grid]], |     uint3 tgid [[threadgroup_position_in_grid]], | ||||||
|     uint3 tpg [[threads_per_threadgroup]], |     uint3 tpg [[threads_per_threadgroup]], | ||||||
|     uint sgitg [[simdgroup_index_in_threadgroup]]) { |     uint warp_id [[simdgroup_index_in_threadgroup]]) { | ||||||
|   const uint tx = tid.x; |   const uint tx = tid.x; | ||||||
|   const uint ty = tid.y; |   const uint ty = tid.y; | ||||||
|   const uint simdGroupsPerThreadgroup = (tpg.x * tpg.y + 31) / 32; |   const uint simdGroupsPerThreadgroup = (tpg.x * tpg.y + 31) / 32; | ||||||
| @ -474,11 +474,8 @@ kernel void applySYRK( | |||||||
|       (actSize_j % 8 == 0) && (actSize_h % 8 == 0) && (actSize_k % 8 == 0); |       (actSize_j % 8 == 0) && (actSize_h % 8 == 0) && (actSize_k % 8 == 0); | ||||||
|  |  | ||||||
|   if (use_simdgroup) { |   if (use_simdgroup) { | ||||||
|     uint warp_id = sgitg; |  | ||||||
|  |  | ||||||
|     simdgroup_matrix<float, 8, 8> negative_identity = |     simdgroup_matrix<float, 8, 8> negative_identity = | ||||||
|         simdgroup_matrix<float, 8, 8>(-1.0); |         simdgroup_matrix<float, 8, 8>(-1.0); | ||||||
|     simdgroup_matrix<float, 8, 8> identity = simdgroup_matrix<float, 8, 8>(1.0); |  | ||||||
|     simdgroup_matrix<float, 8, 8> Prod; |     simdgroup_matrix<float, 8, 8> Prod; | ||||||
|     simdgroup_matrix<float, 8, 8> Afrag; |     simdgroup_matrix<float, 8, 8> Afrag; | ||||||
|     simdgroup_matrix<float, 8, 8> Bfrag; |     simdgroup_matrix<float, 8, 8> Bfrag; | ||||||
| @ -521,8 +518,7 @@ kernel void applySYRK( | |||||||
|             /* transpose = */ upper); |             /* transpose = */ upper); | ||||||
|  |  | ||||||
|         simdgroup_multiply(Prod, Afrag, Bfrag); |         simdgroup_multiply(Prod, Afrag, Bfrag); | ||||||
|         simdgroup_multiply(Prod, Prod, negative_identity); |         simdgroup_multiply_accumulate(Cfrag, Prod, negative_identity, Cfrag); | ||||||
|         simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod); |  | ||||||
|       } |       } | ||||||
|  |  | ||||||
|       simdgroup_store( |       simdgroup_store( | ||||||
|  | |||||||
| @ -1,16 +1,16 @@ | |||||||
| #pragma once | #pragma once | ||||||
| #include <c10/metal/common.h> | #include <c10/metal/common.h> | ||||||
|  |  | ||||||
| template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t> | template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim> | ||||||
| struct CatLargeSharedParams { | struct CatSharedParams { | ||||||
|   int32_t ndim; |   int32_t ndim; | ||||||
|   int32_t cat_dim; |   int32_t cat_dim; | ||||||
|   ::c10::metal::array<idx_type_t, N> output_strides; |   ::c10::metal::array<idx_type_t, N> output_strides; | ||||||
|   ::c10::metal::array<idx_type_t, N> output_sizes; |   ::c10::metal::array<idx_type_t, N> output_sizes; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t> | template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim> | ||||||
| struct CatLargeInputParams { | struct CatInputParams { | ||||||
|   idx_type_t cat_dim_offset; |   idx_type_t cat_dim_offset; | ||||||
|   idx_type_t input_element_offset; |   idx_type_t input_element_offset; | ||||||
|   ::c10::metal::array<idx_type_t, N> input_strides; |   ::c10::metal::array<idx_type_t, N> input_strides; | ||||||
|  | |||||||
| @ -6,26 +6,25 @@ | |||||||
| using namespace metal; | using namespace metal; | ||||||
| using namespace c10::metal; | using namespace c10::metal; | ||||||
|  |  | ||||||
| template <typename T_in, typename T_out> | template <typename I, typename T_in, typename T_out> | ||||||
| kernel void cat_large( | kernel void cat( | ||||||
|     constant T_in* input [[buffer(0)]], |     constant T_in* input [[buffer(0)]], | ||||||
|     device T_out* output [[buffer(1)]], |     device T_out* output [[buffer(1)]], | ||||||
|     constant CatLargeSharedParams<>& shared_params [[buffer(2)]], |     constant CatSharedParams<I>& shared_params [[buffer(2)]], | ||||||
|     constant CatLargeInputParams<>& input_params [[buffer(3)]], |     constant CatInputParams<I>& input_params [[buffer(3)]], | ||||||
|     uint tid [[thread_position_in_grid]]) { |     uint tid [[thread_position_in_grid]]) { | ||||||
|   auto ndim = shared_params.ndim; |   auto ndim = shared_params.ndim; | ||||||
|   auto cat_dim = shared_params.cat_dim; |   auto cat_dim = shared_params.cat_dim; | ||||||
|   constant auto& output_strides = shared_params.output_strides; |   constant auto& output_strides = shared_params.output_strides; | ||||||
|   constant auto& output_sizes = shared_params.output_sizes; |  | ||||||
|  |  | ||||||
|   auto cat_dim_offset = input_params.cat_dim_offset; |   auto cat_dim_offset = input_params.cat_dim_offset; | ||||||
|   auto input_element_offset = input_params.input_element_offset; |   auto input_element_offset = input_params.input_element_offset; | ||||||
|   constant auto& input_strides = input_params.input_strides; |   constant auto& input_strides = input_params.input_strides; | ||||||
|   constant auto& input_sizes = input_params.input_sizes; |   constant auto& input_sizes = input_params.input_sizes; | ||||||
|  |  | ||||||
|   auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset; |   auto input_element_idx = static_cast<I>(tid) + input_element_offset; | ||||||
|   int64_t input_offset = 0; |   I input_offset = 0; | ||||||
|   int64_t output_offset = 0; |   I output_offset = 0; | ||||||
|  |  | ||||||
|   for (auto dim = ndim - 1; dim >= 0; dim--) { |   for (auto dim = ndim - 1; dim >= 0; dim--) { | ||||||
|     auto dim_size = input_sizes[dim]; |     auto dim_size = input_sizes[dim]; | ||||||
| @ -42,41 +41,45 @@ kernel void cat_large( | |||||||
|   output[output_offset] = static_cast<T_out>(input[input_offset]); |   output[output_offset] = static_cast<T_out>(input[input_offset]); | ||||||
| } | } | ||||||
|  |  | ||||||
| #define REGISTER_CAT_LARGE_OP(T_in, T_out)                           \ | #define REGISTER_CAT_OP(I, T_in, T_out)                          \ | ||||||
|   template [[host_name("cat_large_" #T_in "_" #T_out)]]              \ |   template [[host_name("cat_" #I "_" #T_in "_" #T_out)]]         \ | ||||||
|   kernel void cat_large<T_in, T_out>(                                \ |   kernel void cat<I, T_in, T_out>(                               \ | ||||||
|       constant T_in * input [[buffer(0)]],                           \ |       constant T_in * input [[buffer(0)]],                       \ | ||||||
|       device T_out * output [[buffer(1)]],                           \ |       device T_out * output [[buffer(1)]],                       \ | ||||||
|       constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \ |       constant CatSharedParams<I> & shared_params [[buffer(2)]], \ | ||||||
|       constant CatLargeInputParams<> & input_params [[buffer(3)]],   \ |       constant CatInputParams<I> & input_params [[buffer(3)]],   \ | ||||||
|       uint tid [[thread_position_in_grid]]); |       uint tid [[thread_position_in_grid]]); | ||||||
|  |  | ||||||
| #define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \ | #define REGISTER_CAT_OP_ALL_INPUT_TYPES(I, T_out) \ | ||||||
|   REGISTER_CAT_LARGE_OP(float, T_out);               \ |   REGISTER_CAT_OP(I, float, T_out);               \ | ||||||
|   REGISTER_CAT_LARGE_OP(half, T_out);                \ |   REGISTER_CAT_OP(I, half, T_out);                \ | ||||||
|   REGISTER_CAT_LARGE_OP(bfloat, T_out);              \ |   REGISTER_CAT_OP(I, bfloat, T_out);              \ | ||||||
|   REGISTER_CAT_LARGE_OP(int, T_out);                 \ |   REGISTER_CAT_OP(I, int, T_out);                 \ | ||||||
|   REGISTER_CAT_LARGE_OP(uint, T_out);                \ |   REGISTER_CAT_OP(I, uint, T_out);                \ | ||||||
|   REGISTER_CAT_LARGE_OP(long, T_out);                \ |   REGISTER_CAT_OP(I, long, T_out);                \ | ||||||
|   REGISTER_CAT_LARGE_OP(ulong, T_out);               \ |   REGISTER_CAT_OP(I, ulong, T_out);               \ | ||||||
|   REGISTER_CAT_LARGE_OP(short, T_out);               \ |   REGISTER_CAT_OP(I, short, T_out);               \ | ||||||
|   REGISTER_CAT_LARGE_OP(ushort, T_out);              \ |   REGISTER_CAT_OP(I, ushort, T_out);              \ | ||||||
|   REGISTER_CAT_LARGE_OP(char, T_out);                \ |   REGISTER_CAT_OP(I, char, T_out);                \ | ||||||
|   REGISTER_CAT_LARGE_OP(uchar, T_out);               \ |   REGISTER_CAT_OP(I, uchar, T_out);               \ | ||||||
|   REGISTER_CAT_LARGE_OP(bool, T_out); |   REGISTER_CAT_OP(I, bool, T_out); | ||||||
|  |  | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float); | #define REGISTER_CAT_FOR_INDEX_TYPE(I)        \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float);  \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, half);   \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bfloat); \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, int);    \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uint);   \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, long);   \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ulong);  \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, short);  \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ushort); \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, char);   \ | ||||||
| REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool); |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uchar);  \ | ||||||
|  |   REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bool);   \ | ||||||
|  |                                               \ | ||||||
|  |   REGISTER_CAT_OP(I, float2, float2);         \ | ||||||
|  |   REGISTER_CAT_OP(I, half2, half2); | ||||||
|  |  | ||||||
| REGISTER_CAT_LARGE_OP(float2, float2); | REGISTER_CAT_FOR_INDEX_TYPE(int64_t); | ||||||
| REGISTER_CAT_LARGE_OP(half2, half2); | REGISTER_CAT_FOR_INDEX_TYPE(int32_t); | ||||||
|  | |||||||
| @ -907,6 +907,8 @@ Tensor& index_fill_mps_(Tensor& self, int64_t dim, const Tensor& index, const Te | |||||||
|   TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, |   TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, | ||||||
|               "index_fill_(): Expected dtype int32 or int64 for index"); |               "index_fill_(): Expected dtype int32 or int64 for index"); | ||||||
|   TORCH_CHECK(dim == 0 || dim < self.dim(), "index_fill_(): Indexing dim ", dim, " is out of bounds of tensor"); |   TORCH_CHECK(dim == 0 || dim < self.dim(), "index_fill_(): Indexing dim ", dim, " is out of bounds of tensor"); | ||||||
|  |   TORCH_CHECK(self.is_complex() || !source.is_complex(), | ||||||
|  |               "index_fill_(): Converting complex Scalar to non-complex type is not supported"); | ||||||
|   // MPS.scatter crashes if used with complex dtypes |   // MPS.scatter crashes if used with complex dtypes | ||||||
|   TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "index_fill_(): Complex types are yet not supported"); |   TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "index_fill_(): Complex types are yet not supported"); | ||||||
|  |  | ||||||
|  | |||||||
| @ -196,6 +196,28 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) | |||||||
|        other.size(0) > max_stride_size || other.size(1) > max_stride_size); |        other.size(0) > max_stride_size || other.size(1) > max_stride_size); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void map_mps_decomposition_error_code_to_blas(const Tensor& status) { | ||||||
|  |   const auto& status_flat = status.view(-1); | ||||||
|  |  | ||||||
|  |   for (const auto i : c10::irange(status_flat.size(0))) { | ||||||
|  |     int code = status_flat[i].item<int>(); | ||||||
|  |     switch (code) { | ||||||
|  |       case MPSMatrixDecompositionStatusSuccess: | ||||||
|  |         status_flat[i] = 0; | ||||||
|  |         break; | ||||||
|  |       case MPSMatrixDecompositionStatusNonPositiveDefinite: | ||||||
|  |       case MPSMatrixDecompositionStatusSingular: | ||||||
|  |         status_flat[i] = 2; | ||||||
|  |         break; | ||||||
|  |       case MPSMatrixDecompositionStatusFailure: | ||||||
|  |         status_flat[i] = -1; | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         TORCH_INTERNAL_ASSERT(false, "Unknown MPSMatrixDecompositionStatus enum value: ", code); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| } // anonymous namespace | } // anonymous namespace | ||||||
|  |  | ||||||
| static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, | static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, | ||||||
| @ -487,6 +509,9 @@ static void linalg_solve_out_mps_impl(const Tensor& A, | |||||||
|                   "mpsmatrixdecompositionstatus for details."); |                   "mpsmatrixdecompositionstatus for details."); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   map_mps_decomposition_error_code_to_blas(info); | ||||||
|  |  | ||||||
|   if (!left) { |   if (!left) { | ||||||
|     // If this was a right solve, transpose the result back |     // If this was a right solve, transpose the result back | ||||||
|     result.copy_(result_t.transpose(-2, -1).contiguous()); |     result.copy_(result_t.transpose(-2, -1).contiguous()); | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ | |||||||
| #include <ATen/MemoryOverlap.h> | #include <ATen/MemoryOverlap.h> | ||||||
| #include <ATen/WrapDimUtils.h> | #include <ATen/WrapDimUtils.h> | ||||||
| #include <ATen/mps/MPSProfiler.h> | #include <ATen/mps/MPSProfiler.h> | ||||||
|  | #include <ATen/native/Pool.h> | ||||||
| #include <ATen/native/TensorShape.h> | #include <ATen/native/TensorShape.h> | ||||||
| #include <ATen/native/TypeProperties.h> | #include <ATen/native/TypeProperties.h> | ||||||
| #include <ATen/native/mps/OperationUtils.h> | #include <ATen/native/mps/OperationUtils.h> | ||||||
| @ -69,29 +70,40 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| // This implementation of cat is used only if one of the inputs or the output is | template <typename T> | ||||||
| // too large to use MPSGraph. | std::string get_type_str(); | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | std::string get_type_str<int64_t>() { | ||||||
|  |   return "int64_t"; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | std::string get_type_str<int32_t>() { | ||||||
|  |   return "int32_t"; | ||||||
|  | } | ||||||
|  |  | ||||||
| // NOTE: `output` is expected to already have the correct size. | // NOTE: `output` is expected to already have the correct size. | ||||||
| static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) { | template <typename idx_type_t> | ||||||
|   CatLargeSharedParams shared_params; | static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) { | ||||||
|  |   CatSharedParams<idx_type_t> shared_params; | ||||||
|  |  | ||||||
|   shared_params.ndim = output.dim(); |   shared_params.ndim = output.dim(); | ||||||
|   shared_params.cat_dim = dimension; |   shared_params.cat_dim = dimension; | ||||||
|  |  | ||||||
|   for (const auto dim : c10::irange(output.dim())) { |   for (const auto dim : c10::irange(output.dim())) { | ||||||
|     shared_params.output_strides[dim] = output.stride(dim); |     shared_params.output_strides[dim] = safe_downcast<idx_type_t, int64_t>(output.stride(dim)); | ||||||
|     shared_params.output_sizes[dim] = output.size(dim); |     shared_params.output_sizes[dim] = safe_downcast<idx_type_t, int64_t>(output.size(dim)); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   int64_t cat_dim_offset = 0; |   idx_type_t cat_dim_offset = 0; | ||||||
|   size_t input_idx = 0; |   size_t input_idx = 0; | ||||||
|   MPSStream* stream = getCurrentMPSStream(); |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   // Launch a separate kernels for each input. This will produce some overhead, |   // Launch a separate kernels for each input. This will produce some overhead. | ||||||
|   // but that should be relatively minimal since at least one of the inputs is |   // In order to launch only one kernel to process all inputs, we would have to | ||||||
|   // very large. In order to launch only one kernel to process all inputs, we |   // copy all the input tensor data into a packed buffer, which would not be | ||||||
|   // would have to copy all the input tensor data into a packed buffer, which |   // ideal. | ||||||
|   // would not be ideal. |  | ||||||
|   for (const Tensor& input : inputs) { |   for (const Tensor& input : inputs) { | ||||||
|     if (input.numel() == 0) { |     if (input.numel() == 0) { | ||||||
|       continue; |       continue; | ||||||
| @ -104,21 +116,23 @@ static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimen | |||||||
|  |  | ||||||
|     for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) { |     for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) { | ||||||
|       auto num_threads = std::min(max_num_threads, numel_remaining); |       auto num_threads = std::min(max_num_threads, numel_remaining); | ||||||
|       CatLargeInputParams input_params; |       CatInputParams<idx_type_t> input_params; | ||||||
|  |  | ||||||
|       input_params.cat_dim_offset = cat_dim_offset; |       input_params.cat_dim_offset = safe_downcast<idx_type_t, int64_t>(cat_dim_offset); | ||||||
|       input_params.input_element_offset = input.numel() - numel_remaining; |       input_params.input_element_offset = safe_downcast<idx_type_t, int64_t>(input.numel() - numel_remaining); | ||||||
|  |  | ||||||
|       for (const auto dim : c10::irange(input.dim())) { |       for (const auto dim : c10::irange(input.dim())) { | ||||||
|         input_params.input_strides[dim] = input.stride(dim); |         input_params.input_strides[dim] = safe_downcast<idx_type_t, int64_t>(input.stride(dim)); | ||||||
|         input_params.input_sizes[dim] = input.size(dim); |         input_params.input_sizes[dim] = safe_downcast<idx_type_t, int64_t>(input.size(dim)); | ||||||
|       } |       } | ||||||
|  |  | ||||||
|       dispatch_sync_with_rethrow(stream->queue(), ^() { |       dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||||
|         @autoreleasepool { |         @autoreleasepool { | ||||||
|           id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder(); |           id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder(); | ||||||
|           auto pipeline_state = lib.getPipelineStateForFunc( |           auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("cat_{}_{}_{}", | ||||||
|               fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output))); |                                                                         get_type_str<idx_type_t>(), | ||||||
|  |                                                                         scalarToMetalTypeString(input), | ||||||
|  |                                                                         scalarToMetalTypeString(output))); | ||||||
|           getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input}); |           getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input}); | ||||||
|           [computeEncoder setComputePipelineState:pipeline_state]; |           [computeEncoder setComputePipelineState:pipeline_state]; | ||||||
|           mtl_setArgs(computeEncoder, input, output, shared_params, input_params); |           mtl_setArgs(computeEncoder, input, output, shared_params, input_params); | ||||||
| @ -294,13 +308,6 @@ TORCH_IMPL_FUNC(cat_out_mps) | |||||||
|               " and out is on ", |               " and out is on ", | ||||||
|               out.device()); |               out.device()); | ||||||
|  |  | ||||||
|   // TODO: For better performance by eliminating input tensor gathering and post transpose, |  | ||||||
|   // TODO: it is better to keep the out tensor's memory format. |  | ||||||
|   // TODO: dimension needs to be recomputed as: |  | ||||||
|   // TODO: dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim()- dim; otherwise dim = dim-1 |  | ||||||
|   if (needsGather(out)) { |  | ||||||
|     out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); |  | ||||||
|   } |  | ||||||
|   std::vector<int64_t> size(notSkippedTensor.sizes().vec()); |   std::vector<int64_t> size(notSkippedTensor.sizes().vec()); | ||||||
|  |  | ||||||
|   // Compute size of the result in the cat dimension |   // Compute size of the result in the cat dimension | ||||||
| @ -331,82 +338,9 @@ TORCH_IMPL_FUNC(cat_out_mps) | |||||||
|   has_large_tensor |= isTooLargeForMPSGraph(out); |   has_large_tensor |= isTooLargeForMPSGraph(out); | ||||||
|  |  | ||||||
|   if (has_large_tensor) { |   if (has_large_tensor) { | ||||||
|     return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out); |     return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out); | ||||||
|   } |   } else { | ||||||
|  |     return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out); | ||||||
|   struct CachedGraph : public MPSCachedGraph { |  | ||||||
|     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} |  | ||||||
|     std::vector<MPSGraphTensor*> inputTensors_; |  | ||||||
|     MPSGraphTensor* outputTensor_ = nil; |  | ||||||
|   }; |  | ||||||
|  |  | ||||||
|   @autoreleasepool { |  | ||||||
|     std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" + |  | ||||||
|         (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); |  | ||||||
|     if (!all_same_dtype) { |  | ||||||
|       key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride); |  | ||||||
|     } else { |  | ||||||
|       key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size()); |  | ||||||
|     } |  | ||||||
|     for (auto idx : skipped_tensor_indices) { |  | ||||||
|       key += "," + std::to_string(idx); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |  | ||||||
|       auto len_tensor_array = inputs.size() - skipped_tensor_indices.size(); |  | ||||||
|       std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array); |  | ||||||
|       newCachedGraph->inputTensors_.reserve(len_tensor_array); |  | ||||||
|  |  | ||||||
|       for (const auto idx : c10::irange(len_tensor_array)) { |  | ||||||
|         const Tensor& tensor = input_tensors[idx]; |  | ||||||
|         auto scalar_type = getMPSScalarType(tensor.scalar_type()); |  | ||||||
|         if (tensor.scalar_type() == kBool) { |  | ||||||
|           scalar_type = MPSDataTypeInt8; |  | ||||||
|         } |  | ||||||
|         newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type); |  | ||||||
|         if (tensor.scalar_type() != out_dtype) { |  | ||||||
|           castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx] |  | ||||||
|                                                 toType:getMPSDataType(out_dtype) |  | ||||||
|                                                   name:@"castInput"]; |  | ||||||
|         } else { |  | ||||||
|           castInputTensors[idx] = newCachedGraph->inputTensors_[idx]; |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array]; |  | ||||||
|       MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray |  | ||||||
|                                                    dimension:dimension // Maybe convert this from int64_t -> int32 |  | ||||||
|                                                         name:nil]; |  | ||||||
|       if (getMPSDataType(out_dtype) == MPSDataTypeBool) { |  | ||||||
|         outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"]; |  | ||||||
|       } |  | ||||||
|       newCachedGraph->outputTensor_ = outputTensor; |  | ||||||
|     }); |  | ||||||
|  |  | ||||||
|     std::vector<Placeholder> inputPlaceholders; |  | ||||||
|     int i = 0; |  | ||||||
|     int t_idx = 0; |  | ||||||
|     for (const Tensor& tensor : materialized_inputs) { |  | ||||||
|       if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) { |  | ||||||
|         auto scalar_type = getMPSScalarType(tensor.scalar_type()); |  | ||||||
|         if (tensor.scalar_type() == kBool) { |  | ||||||
|           scalar_type = MPSDataTypeInt8; |  | ||||||
|         } |  | ||||||
|         inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, nullptr, true, scalar_type); |  | ||||||
|         t_idx++; |  | ||||||
|       } |  | ||||||
|       i++; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     auto outputDataType = getMPSScalarType(out.scalar_type()); |  | ||||||
|     Placeholder outputPlaceholder = |  | ||||||
|         Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType); |  | ||||||
|  |  | ||||||
|     NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; |  | ||||||
|     for (auto& inputPlaceholder : inputPlaceholders) { |  | ||||||
|       feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); |  | ||||||
|     } |  | ||||||
|     runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder); |  | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
| @ -706,6 +706,7 @@ | |||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_all |     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_all | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
|  |  | ||||||
| - func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor | - func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor | ||||||
| @ -715,6 +716,7 @@ | |||||||
|   cpp_no_default_args: ['dim'] |   cpp_no_default_args: ['dim'] | ||||||
|   dispatch: |   dispatch: | ||||||
|     CompositeExplicitAutograd: all_dims_default |     CompositeExplicitAutograd: all_dims_default | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -723,6 +725,7 @@ | |||||||
|     CPU, CUDA: all_out |     CPU, CUDA: all_out | ||||||
|     MPS: all_out_mps |     MPS: all_out_mps | ||||||
|     MTIA: all_out_mtia |     MTIA: all_out_mtia | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -731,13 +734,16 @@ | |||||||
|     CPU, CUDA: all_dims_out |     CPU, CUDA: all_dims_out | ||||||
|     CompositeExplicitAutograd: all_dims_out_default |     CompositeExplicitAutograd: all_dims_out_default | ||||||
|   cpp_no_default_args: ['dim'] |   cpp_no_default_args: ['dim'] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor | - func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool | - func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool | ||||||
|   variants: function, method |   variants: function, method | ||||||
| @ -749,14 +755,14 @@ | |||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   structured_delegate: any.out |   structured_delegate: any.out | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor | - func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   structured_delegate: any.dims_out |   structured_delegate: any.dims_out | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   cpp_no_default_args: ['dim'] |   cpp_no_default_args: ['dim'] | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|   dispatch: |   dispatch: | ||||||
|     CompositeExplicitAutograd: any_dims_default |     CompositeExplicitAutograd: any_dims_default | ||||||
|  |  | ||||||
| @ -766,6 +772,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: any_out |     CPU, CUDA: any_out | ||||||
|     MPS: any_out_mps |     MPS: any_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -774,13 +781,16 @@ | |||||||
|     CPU, CUDA: any_dims_out |     CPU, CUDA: any_dims_out | ||||||
|     CompositeExplicitAutograd: any_dims_out_default |     CompositeExplicitAutograd: any_dims_out_default | ||||||
|   cpp_no_default_args: ['dim'] |   cpp_no_default_args: ['dim'] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor | - func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor | - func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor | ||||||
|   dispatch: |   dispatch: | ||||||
| @ -826,25 +836,27 @@ | |||||||
|   structured_delegate: argmax.out |   structured_delegate: argmax.out | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: argmax_out |     CPU, CUDA: argmax_out | ||||||
|     MPS: argmax_out_mps |     MPS: argmax_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor | - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor | ||||||
|   structured_delegate: argmin.out |   structured_delegate: argmin.out | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: argmin_out |     CPU, CUDA: argmin_out | ||||||
|     MPS: argmin_out_mps |     MPS: argmin_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: acosh(Tensor self) -> Tensor | - func: acosh(Tensor self) -> Tensor | ||||||
|   variants: function, method |   variants: function, method | ||||||
| @ -1370,6 +1382,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     SparseCPU: bmm_sparse_cpu |     SparseCPU: bmm_sparse_cpu | ||||||
|     SparseCUDA: bmm_sparse_cuda |     SparseCUDA: bmm_sparse_cuda | ||||||
|  |     SparseMPS: bmm_sparse_mps | ||||||
|     NestedTensorCPU: bmm_nested |     NestedTensorCPU: bmm_nested | ||||||
|     NestedTensorCUDA: bmm_nested_cuda |     NestedTensorCUDA: bmm_nested_cuda | ||||||
|   tags: core |   tags: core | ||||||
| @ -1385,6 +1398,7 @@ | |||||||
|     MTIA: bmm_out_mtia |     MTIA: bmm_out_mtia | ||||||
|     SparseCPU: bmm_out_sparse_cpu |     SparseCPU: bmm_out_sparse_cpu | ||||||
|     SparseCUDA: bmm_out_sparse_cuda |     SparseCUDA: bmm_out_sparse_cuda | ||||||
|  |     SparseMPS: bmm_out_sparse_mps | ||||||
|     SparseCsrCUDA: bmm_out_sparse_csr_cuda |     SparseCsrCUDA: bmm_out_sparse_csr_cuda | ||||||
|  |  | ||||||
| - func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor | - func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor | ||||||
| @ -1867,12 +1881,14 @@ | |||||||
|     CUDA: count_nonzero_cuda |     CUDA: count_nonzero_cuda | ||||||
|     MPS: count_nonzero_mps |     MPS: count_nonzero_mps | ||||||
|   autogen: count_nonzero.dim_IntList_out |   autogen: count_nonzero.dim_IntList_out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: count_nonzero(Tensor self, int? dim=None) -> Tensor | - func: count_nonzero(Tensor self, int? dim=None) -> Tensor | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     CompositeExplicitAutograd: count_nonzero |     CompositeExplicitAutograd: count_nonzero | ||||||
|   autogen: count_nonzero.out |   autogen: count_nonzero.out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor | - func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor | ||||||
|   variants: function, method |   variants: function, method | ||||||
| @ -3793,19 +3809,23 @@ | |||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     CompositeExplicitAutograd: logsumexp |     CompositeExplicitAutograd: logsumexp | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   dispatch: |   dispatch: | ||||||
|     # calls squeeze |     # calls squeeze | ||||||
|     CompositeExplicitAutogradNonFunctional: logsumexp_out |     CompositeExplicitAutogradNonFunctional: logsumexp_out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor | - func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor | - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor | ||||||
|  |  | ||||||
| @ -3855,6 +3875,7 @@ | |||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   structured_delegate: aminmax.out |   structured_delegate: aminmax.out | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) | - func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -3862,6 +3883,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA, MTIA: aminmax_out |     CPU, CUDA, MTIA: aminmax_out | ||||||
|     MPS: aminmax_out_mps |     MPS: aminmax_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor | - func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor | ||||||
|   dispatch: |   dispatch: | ||||||
| @ -3877,7 +3899,7 @@ | |||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     QuantizedCPU, QuantizedCUDA: qmax |     QuantizedCPU, QuantizedCUDA: qmax | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) | - func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -3887,13 +3909,16 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA, MTIA: max_out |     CPU, CUDA, MTIA: max_out | ||||||
|     MPS: max_out_mps |     MPS: max_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) | - func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) | - func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor | - func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor | ||||||
|   variants: function |   variants: function | ||||||
| @ -3906,13 +3931,14 @@ | |||||||
| - func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor | - func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   structured_delegate: amax.out |   structured_delegate: amax.out | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA, MTIA: amax_out |     CPU, CUDA, MTIA: amax_out | ||||||
|     MPS: amax_out_mps |     MPS: amax_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| # Return: (Tensor output, Tensor indices) | # Return: (Tensor output, Tensor indices) | ||||||
| - func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) | - func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) | ||||||
| @ -3974,13 +4000,14 @@ | |||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     CompositeExplicitAutograd: mean |     CompositeExplicitAutograd: mean | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| # For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this. | # For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this. | ||||||
| - func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   dispatch: |   dispatch: | ||||||
|     CompositeExplicitAutograd: mean_dtype_out |     CompositeExplicitAutograd: mean_dtype_out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | - func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||||
|   structured_delegate: mean.out |   structured_delegate: mean.out | ||||||
| @ -3988,7 +4015,7 @@ | |||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     QuantizedCPU: mean_quantized_cpu |     QuantizedCPU: mean_quantized_cpu | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
| @ -3997,13 +4024,16 @@ | |||||||
|     CPU, CUDA: mean_out |     CPU, CUDA: mean_out | ||||||
|     MPS: mean_out_mps |     MPS: mean_out_mps | ||||||
|     QuantizedCPU: mean_out_quantized_cpu |     QuantizedCPU: mean_out_quantized_cpu | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | - func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | - func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||||
|   device_check: NoCheck   # Composite |   device_check: NoCheck   # Composite | ||||||
| @ -4066,7 +4096,7 @@ | |||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     QuantizedCPU, QuantizedCUDA: qmin |     QuantizedCPU, QuantizedCUDA: qmin | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) | - func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -4076,24 +4106,28 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA, MTIA: min_out |     CPU, CUDA, MTIA: min_out | ||||||
|     MPS: min_out_mps |     MPS: min_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) | - func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) | - func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor | - func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   structured_delegate: amin.out |   structured_delegate: amin.out | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA, MTIA: amin_out |     CPU, CUDA, MTIA: amin_out | ||||||
|     MPS: amin_out_mps |     MPS: amin_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| # TODO: Add this function to MPS dispatch key so that we avoid declaring it in | # TODO: Add this function to MPS dispatch key so that we avoid declaring it in | ||||||
| # native_functions.yaml | # native_functions.yaml | ||||||
| @ -4173,7 +4207,7 @@ | |||||||
|   structured_delegate: mm.out |   structured_delegate: mm.out | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     SparseCPU, SparseCUDA: _sparse_mm |     SparseCPU, SparseCUDA, SparseMPS: _sparse_mm | ||||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm |     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm | ||||||
|   tags: core |   tags: core | ||||||
|  |  | ||||||
| @ -5858,6 +5892,7 @@ | |||||||
|     SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sum_coo |     SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sum_coo | ||||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr |     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr | ||||||
|   autogen: sum.out |   autogen: sum.out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | - func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||||
|   # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype |   # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype | ||||||
| @ -5868,11 +5903,12 @@ | |||||||
|     NestedTensorCPU: NestedTensor_sum_dim_CPU |     NestedTensorCPU: NestedTensor_sum_dim_CPU | ||||||
|     SparseCPU, SparseCUDA, SparseMPS: sum_sparse_coo |     SparseCPU, SparseCUDA, SparseMPS: sum_sparse_coo | ||||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed |     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | - func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
| @ -5880,9 +5916,11 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: sum_out |     CPU, CUDA: sum_out | ||||||
|     MPS: sum_out_mps |     MPS: sum_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| # TODO: this function will be replaced once nested expand semantics have been settled on | # TODO: this function will be replaced once nested expand semantics have been settled on | ||||||
| - func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor | - func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor | ||||||
| @ -5894,11 +5932,13 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: nansum |     CPU, CUDA: nansum | ||||||
|     MPS: nansum_mps |     MPS: nansum_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: nansum_out |     CPU, CUDA: nansum_out | ||||||
|     MPS: nansum_out_mps |     MPS: nansum_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor | - func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor | ||||||
|   variants: function, method |   variants: function, method | ||||||
| @ -5962,11 +6002,13 @@ | |||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor | - func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor | - func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -5975,16 +6017,19 @@ | |||||||
|     CPU, CUDA: std |     CPU, CUDA: std | ||||||
|     MPS: std_mps |     MPS: std_mps | ||||||
|     QuantizedCPU: std_quantized_cpu |     QuantizedCPU: std_quantized_cpu | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) | - func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | - func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | - func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -5993,42 +6038,51 @@ | |||||||
|     CPU, CUDA: std_mean |     CPU, CUDA: std_mean | ||||||
|     MPS: std_mean_mps |     MPS: std_mean_mps | ||||||
|   autogen: std_mean.correction_out |   autogen: std_mean.correction_out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | - func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | - func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | - func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: std_out |     CPU, CUDA: std_out | ||||||
|     QuantizedCPU: std_out_quantized_cpu |     QuantizedCPU: std_out_quantized_cpu | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor | - func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor | - func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | - func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor | - func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -6037,13 +6091,13 @@ | |||||||
|     CPU, CUDA: prod |     CPU, CUDA: prod | ||||||
|     MPS: prod_mps |     MPS: prod_mps | ||||||
|   autogen: prod.out |   autogen: prod.out | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | - func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||||
|   structured_delegate: prod.int_out |   structured_delegate: prod.int_out | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
| @ -6051,13 +6105,16 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: prod_out |     CPU, CUDA: prod_out | ||||||
|     MPS: prod_out_mps |     MPS: prod_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | - func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: t(Tensor(a) self) -> Tensor(a) | - func: t(Tensor(a) self) -> Tensor(a) | ||||||
|   device_check: NoCheck |   device_check: NoCheck | ||||||
| @ -6518,11 +6575,12 @@ | |||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor | - func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |  | ||||||
| - func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor | - func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor | ||||||
| @ -6531,43 +6589,52 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: var |     CPU, CUDA: var | ||||||
|     MPS: var_mps |     MPS: var_mps | ||||||
|   tags: core |     MTIA: var_mtia | ||||||
|  |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | - func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: var_out |     CPU, CUDA: var_out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor | - func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor | - func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | - func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) | - func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | - func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | - func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -6576,15 +6643,18 @@ | |||||||
|     CPU, CUDA: var_mean |     CPU, CUDA: var_mean | ||||||
|     MPS: var_mean_mps |     MPS: var_mean_mps | ||||||
|   autogen: var_mean.correction_out |   autogen: var_mean.correction_out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | - func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|   cpp_no_default_args: ["unbiased"] |   cpp_no_default_args: ["unbiased"] | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | - func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function |   variants: function | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: view_as(Tensor(a) self, Tensor other) -> Tensor(a) | - func: view_as(Tensor(a) self, Tensor other) -> Tensor(a) | ||||||
|   variants: method |   variants: method | ||||||
| @ -6844,6 +6914,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CompositeExplicitAutograd: norm |     CompositeExplicitAutograd: norm | ||||||
|   autogen: norm.ScalarOpt_dtype_out |   autogen: norm.ScalarOpt_dtype_out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor | - func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -6851,6 +6922,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CompositeExplicitAutograd: norm |     CompositeExplicitAutograd: norm | ||||||
|   autogen: norm.Scalar_out |   autogen: norm.Scalar_out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor | - func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor | ||||||
|   structured_delegate: norm.dtype_out |   structured_delegate: norm.dtype_out | ||||||
| @ -6858,6 +6930,7 @@ | |||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     SparseCPU, SparseCUDA, SparseMPS: sparse_dtype_norm |     SparseCPU, SparseCUDA, SparseMPS: sparse_dtype_norm | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor | - func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor | ||||||
|   structured_delegate: norm.out |   structured_delegate: norm.out | ||||||
| @ -6865,6 +6938,7 @@ | |||||||
|   variants: function, method |   variants: function, method | ||||||
|   dispatch: |   dispatch: | ||||||
|     SparseCPU, SparseCUDA, SparseMPS: sparse_norm |     SparseCPU, SparseCUDA, SparseMPS: sparse_norm | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) | - func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
| @ -6872,6 +6946,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: norm_dtype_out |     CPU, CUDA: norm_dtype_out | ||||||
|     MPS: norm_dtype_out_mps |     MPS: norm_dtype_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   structured: True |   structured: True | ||||||
| @ -6879,21 +6954,26 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: norm_out |     CPU, CUDA: norm_out | ||||||
|     MPS: norm_out_mps |     MPS: norm_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| # These four redispatch in their implementation, so OK to be CompositeImplicitAutograd | # These four redispatch in their implementation, so OK to be CompositeImplicitAutograd | ||||||
| - func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor | - func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor | - func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   variants: function, method |   variants: function, method | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) | - func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | - func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) | - func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) | ||||||
|   variants: method, function |   variants: method, function | ||||||
| @ -7111,6 +7191,7 @@ | |||||||
|     MTIA: addmm_out_mtia |     MTIA: addmm_out_mtia | ||||||
|     SparseCPU: addmm_out_sparse_dense_cpu |     SparseCPU: addmm_out_sparse_dense_cpu | ||||||
|     SparseCUDA: addmm_out_sparse_dense_cuda |     SparseCUDA: addmm_out_sparse_dense_cuda | ||||||
|  |     SparseMPS: addmm_out_sparse_dense_mps | ||||||
|     SparseCsrCPU: addmm_out_sparse_compressed_cpu |     SparseCsrCPU: addmm_out_sparse_compressed_cpu | ||||||
|     SparseCsrCUDA: addmm_out_sparse_compressed_cuda |     SparseCsrCUDA: addmm_out_sparse_compressed_cuda | ||||||
|  |  | ||||||
| @ -7120,6 +7201,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     SparseCPU: addmm_sparse_dense_cpu |     SparseCPU: addmm_sparse_dense_cpu | ||||||
|     SparseCUDA: addmm_sparse_dense_cuda |     SparseCUDA: addmm_sparse_dense_cuda | ||||||
|  |     SparseMPS: addmm_sparse_dense_mps | ||||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense |     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense | ||||||
|   tags: core |   tags: core | ||||||
|  |  | ||||||
| @ -7384,7 +7466,7 @@ | |||||||
| - func: sparse_mask(Tensor self, Tensor mask) -> Tensor | - func: sparse_mask(Tensor self, Tensor mask) -> Tensor | ||||||
|   variants: method |   variants: method | ||||||
|   dispatch: |   dispatch: | ||||||
|     SparseCPU, SparseCUDA: sparse_mask |     SparseCPU, SparseCUDA, SparseMPS: sparse_mask | ||||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed |     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed | ||||||
|   autogen: sparse_mask.out |   autogen: sparse_mask.out | ||||||
|  |  | ||||||
| @ -10077,12 +10159,14 @@ | |||||||
|     CPU, CUDA: min |     CPU, CUDA: min | ||||||
|     MPS: min_mps |     MPS: min_mps | ||||||
|     QuantizedCPU: min_quantized_cpu |     QuantizedCPU: min_quantized_cpu | ||||||
|  |   tags: [reduction] | ||||||
|  |  | ||||||
| - func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | - func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: min_unary_out |     CPU, CUDA: min_unary_out | ||||||
|     QuantizedCPU: min_quantized_unary_out |     QuantizedCPU: min_quantized_unary_out | ||||||
|  |   tags: [reduction] | ||||||
|  |  | ||||||
| - func: fmin(Tensor self, Tensor other) -> Tensor | - func: fmin(Tensor self, Tensor other) -> Tensor | ||||||
|   structured_delegate: fmin.out |   structured_delegate: fmin.out | ||||||
| @ -10105,6 +10189,7 @@ | |||||||
|     CPU, CUDA: max |     CPU, CUDA: max | ||||||
|     MPS: max_mps |     MPS: max_mps | ||||||
|     QuantizedCPU: max_quantized_cpu |     QuantizedCPU: max_quantized_cpu | ||||||
|  |   tags: [reduction] | ||||||
|  |  | ||||||
| - func: fmax(Tensor self, Tensor other) -> Tensor | - func: fmax(Tensor self, Tensor other) -> Tensor | ||||||
|   structured_delegate: fmax.out |   structured_delegate: fmax.out | ||||||
| @ -10151,6 +10236,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: max_unary_out |     CPU, CUDA: max_unary_out | ||||||
|     QuantizedCPU: max_quantized_unary_out |     QuantizedCPU: max_quantized_unary_out | ||||||
|  |   tags: [reduction] | ||||||
|  |  | ||||||
| - func: minimum(Tensor self, Tensor other) -> Tensor | - func: minimum(Tensor self, Tensor other) -> Tensor | ||||||
|   structured_delegate: minimum.out |   structured_delegate: minimum.out | ||||||
| @ -10270,6 +10356,7 @@ | |||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
|   structured_delegate: all.all_out |   structured_delegate: all.all_out | ||||||
|   variants: method, function |   variants: method, function | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | - func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck |   device_check: NoCheck | ||||||
| @ -10278,6 +10365,7 @@ | |||||||
|     CPU, CUDA: all_all_out |     CPU, CUDA: all_all_out | ||||||
|     MTIA: all_all_out_mtia |     MTIA: all_all_out_mtia | ||||||
|     MPS: all_all_out_mps |     MPS: all_all_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: any(Tensor self) -> Tensor | - func: any(Tensor self) -> Tensor | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -10285,7 +10373,7 @@ | |||||||
|   variants: method, function |   variants: method, function | ||||||
|   dispatch: |   dispatch: | ||||||
|     SparseCPU, SparseCUDA, SparseMPS: any_sparse |     SparseCPU, SparseCUDA, SparseMPS: any_sparse | ||||||
|   tags: core |   tags: [core, reduction] | ||||||
|  |  | ||||||
| - func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | - func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck |   device_check: NoCheck | ||||||
| @ -10293,6 +10381,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: any_all_out |     CPU, CUDA: any_all_out | ||||||
|     MPS: any_all_out_mps |     MPS: any_all_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) | - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) | ||||||
|   device_check: NoCheck   # TensorIterator |   device_check: NoCheck   # TensorIterator | ||||||
| @ -14344,6 +14433,7 @@ | |||||||
|   python_module: linalg |   python_module: linalg | ||||||
|   variants: function |   variants: function | ||||||
|   structured_delegate: linalg_vector_norm.out |   structured_delegate: linalg_vector_norm.out | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | - func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) | ||||||
|   python_module: linalg |   python_module: linalg | ||||||
| @ -14351,6 +14441,7 @@ | |||||||
|   dispatch: |   dispatch: | ||||||
|     CPU, CUDA: linalg_vector_norm_out |     CPU, CUDA: linalg_vector_norm_out | ||||||
|     MPS: linalg_vector_norm_out_mps |     MPS: linalg_vector_norm_out_mps | ||||||
|  |   tags: reduction | ||||||
|  |  | ||||||
| - func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | - func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor | ||||||
|   python_module: linalg |   python_module: linalg | ||||||
|  | |||||||
| @ -184,15 +184,23 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba | |||||||
|           0 & \text{ else } |           0 & \text{ else } | ||||||
|         \end{cases} |         \end{cases} | ||||||
|   */ |   */ | ||||||
|   float scale_val = scale[0].item<float>(); |  | ||||||
|   float inv_scale_val = 1.0f / scale_val; |  | ||||||
|   int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false); |  | ||||||
|  |  | ||||||
|   TORCH_CHECK(dY.scalar_type() == ScalarType::Float); |   bool is_bfloat16 = (X.scalar_type() == at::kBFloat16); | ||||||
|   TORCH_CHECK(X.scalar_type() == ScalarType::Float); |  | ||||||
|   TORCH_CHECK(scale.scalar_type() == ScalarType::Float); |   at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X; | ||||||
|   TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float); |   at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY; | ||||||
|   TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size"); |   at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale; | ||||||
|  |   at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point; | ||||||
|  |  | ||||||
|  |   float scale_val = scale_[0].item<float>(); | ||||||
|  |   float inv_scale_val = 1.0f / scale_val; | ||||||
|  |   int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point_, quant_min, quant_max, false); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(dY_.scalar_type() == ScalarType::Float); | ||||||
|  |   TORCH_CHECK(X_.scalar_type() == ScalarType::Float); | ||||||
|  |   TORCH_CHECK(scale_.scalar_type() == ScalarType::Float); | ||||||
|  |   TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float); | ||||||
|  |   TORCH_CHECK(X_.numel() == dY_.numel(), "`X` and `dY` are not the same size"); | ||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|       quant_min <= 0 && quant_max >= 0, |       quant_min <= 0 && quant_max >= 0, | ||||||
|       "`quant_min` should be less than or \ |       "`quant_min` should be less than or \ | ||||||
| @ -200,28 +208,28 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba | |||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|       zero_point_val >= quant_min && zero_point_val <= quant_max, |       zero_point_val >= quant_min && zero_point_val <= quant_max, | ||||||
|       "`zero_point` must be between `quant_min` and `quant_max`."); |       "`zero_point` must be between `quant_min` and `quant_max`."); | ||||||
|   if (X.numel() <= 0) { |   if (X_.numel() <= 0) { | ||||||
|     return std::make_tuple(X, scale, zero_point); |     return std::make_tuple(X, scale, zero_point); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve); |   auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve); | ||||||
|   auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve); |   auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve); | ||||||
|   auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve); |   auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve); | ||||||
|  |  | ||||||
|   auto iter = TensorIteratorConfig() |   auto iter = TensorIteratorConfig() | ||||||
|     .add_output(dX) |     .add_output(dX) | ||||||
|     .add_output(dScale_vec) |     .add_output(dScale_vec) | ||||||
|     .add_output(dZeroPoint_vec) |     .add_output(dZeroPoint_vec) | ||||||
|     .add_input(X) |     .add_input(X_) | ||||||
|     .add_input(dY) |     .add_input(dY_) | ||||||
|     .build(); |     .build(); | ||||||
|  |  | ||||||
|   fake_quant_grad_learnable_tensor_stub( |   fake_quant_grad_learnable_tensor_stub( | ||||||
|     X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor); |     X_.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor); | ||||||
|  |  | ||||||
|   // The total sums over the scale and zero point gradient vectors are what will be returned in the end. |   // The total sums over the scale and zero point gradient vectors are what will be returned in the end. | ||||||
|   auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device()); |   auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device()); | ||||||
|   auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device()); |   auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device()); | ||||||
|  |  | ||||||
|   return std::make_tuple(dX, dScale, dZeroPoint); |   return std::make_tuple(dX, dScale, dZeroPoint); | ||||||
| } | } | ||||||
|  | |||||||
| @ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu( | |||||||
|  |  | ||||||
| #if defined(__ARM_NEON__) || defined(__aarch64__) | #if defined(__ARM_NEON__) || defined(__aarch64__) | ||||||
|  |  | ||||||
| const static int PARALLEL_THRESHOLD = 1 << 20; | constexpr static int PARALLEL_THRESHOLD = 1 << 20; | ||||||
|  |  | ||||||
| // Generic template defaults to naive quantize implementation | // Generic template defaults to naive quantize implementation | ||||||
| template <typename T> | template <typename T> | ||||||
|  | |||||||
| @ -1388,7 +1388,7 @@ namespace at::native { | |||||||
|     TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1, |     TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1, | ||||||
|         "onednn int8 linear: act scale/zp size should be 1/<=1"); |         "onednn int8 linear: act scale/zp size should be 1/<=1"); | ||||||
|     static std::optional<at::Tensor> other = std::nullopt; |     static std::optional<at::Tensor> other = std::nullopt; | ||||||
|     static const std::string_view binary_post_op = "none"; |     constexpr std::string_view binary_post_op = "none"; | ||||||
|     int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0; |     int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0; | ||||||
|     return linear_int8_with_onednn_weight( |     return linear_int8_with_onednn_weight( | ||||||
|         act, act_scale.item().toDouble(), act_zp, |         act, act_scale.item().toDouble(), act_zp, | ||||||
|  | |||||||
| @ -16,8 +16,8 @@ namespace { | |||||||
|  |  | ||||||
| #ifdef USE_PYTORCH_QNNPACK | #ifdef USE_PYTORCH_QNNPACK | ||||||
|  |  | ||||||
| const static float qnnpack_softmax_output_scale = 0x1.0p-8f; | constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f; | ||||||
| const static int qnnpack_softmax_output_zero_point = 0; | constexpr static int qnnpack_softmax_output_zero_point = 0; | ||||||
|  |  | ||||||
| bool is_qnnpack_compatible( | bool is_qnnpack_compatible( | ||||||
|     const Tensor& qx, |     const Tensor& qx, | ||||||
|  | |||||||
| @ -1,6 +1,9 @@ | |||||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||||
| #include <ATen/native/SparseTensorUtils.h> | #include <ATen/native/SparseTensorUtils.h> | ||||||
|  | #include <ATen/ExpandUtils.h> | ||||||
| #include <ATen/native/mps/OperationUtils.h> | #include <ATen/native/mps/OperationUtils.h> | ||||||
|  | #include <ATen/native/sparse/SparseStubs.h> | ||||||
|  | #include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h> | ||||||
|  |  | ||||||
| #ifndef AT_PER_OPERATOR_HEADERS | #ifndef AT_PER_OPERATOR_HEADERS | ||||||
| #include <ATen/Functions.h> | #include <ATen/Functions.h> | ||||||
| @ -13,7 +16,11 @@ | |||||||
| #include <ATen/ops/mul_native.h> | #include <ATen/ops/mul_native.h> | ||||||
| #include <ATen/ops/empty_native.h> | #include <ATen/ops/empty_native.h> | ||||||
| #include <ATen/ops/zeros_native.h> | #include <ATen/ops/zeros_native.h> | ||||||
|  | #include <ATen/ops/ones_like.h> | ||||||
|  | #include <ATen/ops/argsort.h> | ||||||
| #include <ATen/ops/result_type.h> | #include <ATen/ops/result_type.h> | ||||||
|  | #include <ATen/ops/bmm_native.h> | ||||||
|  | #include <ATen/ops/addmm_native.h> | ||||||
| #include <ATen/ops/copy_sparse_to_sparse.h> | #include <ATen/ops/copy_sparse_to_sparse.h> | ||||||
| #include <ATen/ops/mul.h> | #include <ATen/ops/mul.h> | ||||||
| #endif | #endif | ||||||
| @ -29,6 +36,305 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary(); | |||||||
| #include <ATen/native/mps/Mul_metallib.h> | #include <ATen/native/mps/Mul_metallib.h> | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  | static Tensor& s_addmm_out_sparse_dense_mps( | ||||||
|  |     Tensor& r, | ||||||
|  |     const Tensor& t, | ||||||
|  |     const SparseTensor& sparse_, | ||||||
|  |     const Tensor& dense, | ||||||
|  |     const Scalar& beta, | ||||||
|  |     const Scalar& alpha) { | ||||||
|  |   TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: sparse_dim must be 2, got ", sparse_.sparse_dim()); | ||||||
|  |   TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: sparse values must be 0-dense-dim, got ", sparse_.dense_dim()); | ||||||
|  |   TORCH_CHECK(dense.dim() == 2, "addmm: 'dense' must be 2D, got ", dense.dim()); | ||||||
|  |   TORCH_CHECK(t.dim() == 2, "addmm: 't' must be 2D, got ", t.dim()); | ||||||
|  |  | ||||||
|  |   const int64_t I = sparse_.size(0); | ||||||
|  |   const int64_t J = sparse_.size(1); | ||||||
|  |   const int64_t K = dense.size(1); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(dense.size(0) == J, | ||||||
|  |       "addmm: dense (mat2) dim0 must be ", J, ", got ", dense.size(0)); | ||||||
|  |   TORCH_CHECK(t.size(0) == I && t.size(1) == K, | ||||||
|  |       "addmm: 't' shape must be (", I, ", ", K, "), got (", t.size(0), ", ", t.size(1), ")"); | ||||||
|  |  | ||||||
|  |   r.resize_({I, K}); | ||||||
|  |  | ||||||
|  |   auto sparse = sparse_.coalesce(); | ||||||
|  |   const int64_t nnz = sparse._nnz(); | ||||||
|  |  | ||||||
|  |   if (nnz == 0 || I == 0 || K == 0) { | ||||||
|  |     at::mul_out(r, t, beta); | ||||||
|  |     return r; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   const auto v_dtype = sparse._values().scalar_type(); | ||||||
|  |   const auto d_dtype = dense.scalar_type(); | ||||||
|  |   const auto t_dtype = t.scalar_type(); | ||||||
|  |   auto compute_dtype = c10::promoteTypes(c10::promoteTypes(v_dtype, d_dtype), t_dtype); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(canCast(compute_dtype, r.scalar_type()), | ||||||
|  |               "Can't convert computed type ", compute_dtype, " to output ", r.scalar_type()); | ||||||
|  |  | ||||||
|  |   auto indices2d = sparse._indices().contiguous(); | ||||||
|  |   auto values = sparse._values().to(compute_dtype); | ||||||
|  |   auto dense_c = dense.to(compute_dtype).contiguous(); | ||||||
|  |   auto t_c = t.to(compute_dtype).contiguous(); | ||||||
|  |  | ||||||
|  |   const bool out_needs_cast = (r.scalar_type() != compute_dtype) || !r.is_contiguous(); | ||||||
|  |   Tensor out_buf = out_needs_cast | ||||||
|  |       ? at::empty({I, K}, r.options().dtype(compute_dtype)) | ||||||
|  |       : r; | ||||||
|  |   auto out_contig = out_buf.contiguous(); | ||||||
|  |  | ||||||
|  |   auto device = r.device(); | ||||||
|  |   auto stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|  |   const float alpha_f = alpha.to<float>(); | ||||||
|  |   const float beta_f  = beta.to<float>(); | ||||||
|  |  | ||||||
|  |   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||||
|  |     @autoreleasepool { | ||||||
|  |       const std::string func = "spmm_addmm_coo_" + mps::scalarToMetalTypeString(values); | ||||||
|  |       auto pso = lib.getPipelineStateForFunc(func); | ||||||
|  |       auto enc = stream->commandEncoder(); | ||||||
|  |       [enc setComputePipelineState:pso]; | ||||||
|  |  | ||||||
|  |       const uint32_t tew = pso.threadExecutionWidth; | ||||||
|  |       const uint32_t gridX = static_cast<uint32_t>(K); | ||||||
|  |       const uint32_t gridZ = static_cast<uint32_t>(I); | ||||||
|  |       const uint32_t tgW = std::min<uint32_t>(gridX, tew); | ||||||
|  |  | ||||||
|  |       MTLSize grid = MTLSizeMake(gridX, 1, gridZ); | ||||||
|  |       MTLSize tgs = MTLSizeMake(tgW, 1, 1); | ||||||
|  |  | ||||||
|  |       mtl_setArgs(enc, | ||||||
|  |                   indices2d, | ||||||
|  |                   values, | ||||||
|  |                   dense_c, | ||||||
|  |                   t_c, | ||||||
|  |                   out_contig, | ||||||
|  |                   std::array<uint32_t, 3>{static_cast<uint32_t>(I), | ||||||
|  |                                            static_cast<uint32_t>(J), | ||||||
|  |                                            static_cast<uint32_t>(K)}, | ||||||
|  |                   std::array<float, 2>{alpha_f, beta_f}, | ||||||
|  |                   static_cast<uint32_t>(nnz)); | ||||||
|  |       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   if (out_needs_cast) { | ||||||
|  |     r.copy_(out_contig.to(r.scalar_type())); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   return r; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | static void build_batch_ptr_mps( | ||||||
|  |     const Tensor& indices_dim0, | ||||||
|  |     int64_t B, | ||||||
|  |     Tensor& batch_ptr | ||||||
|  | ) { | ||||||
|  |   // Builds an array of pointers which point to each batches elements. Example: | ||||||
|  |   // idx_b = [0, 0, 0, 1, 1, 2, 2, 2, 2]  // 9 non-zero elements | ||||||
|  |   //          └─────┘  └──┘  └─────────┘ | ||||||
|  |   //          batch 0  batch 1  batch 2 | ||||||
|  |   // batch_ptr = [0, 3, 5, 9] | ||||||
|  |   //              │  │  │  └─ end of batch 2 (total nnz) | ||||||
|  |   //              │  │  └──── batch 2 starts at index 5 | ||||||
|  |   //              │  └─────── batch 1 starts at index 3 | ||||||
|  |   //              └────────── batch 0 starts at index 0 | ||||||
|  |   TORCH_CHECK(indices_dim0.is_mps() && batch_ptr.is_mps(), "MPS device expected"); | ||||||
|  |   auto device = indices_dim0.device(); | ||||||
|  |   auto stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|  |   const int64_t nnz = indices_dim0.numel(); | ||||||
|  |  | ||||||
|  |   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||||
|  |     @autoreleasepool { | ||||||
|  |       auto pso = lib.getPipelineStateForFunc("build_batch_ptr_from_sorted_batches"); | ||||||
|  |       auto enc = stream->commandEncoder(); | ||||||
|  |       [enc setComputePipelineState:pso]; | ||||||
|  |  | ||||||
|  |       const uint32_t tew = pso.threadExecutionWidth; | ||||||
|  |       const uint32_t Q = static_cast<uint32_t>(B + 1); | ||||||
|  |       const uint32_t tgW = std::min<uint32_t>(Q, tew); | ||||||
|  |       MTLSize grid = MTLSizeMake(Q, 1, 1); | ||||||
|  |       MTLSize tgs  = MTLSizeMake(tgW, 1, 1); | ||||||
|  |  | ||||||
|  |       mtl_setArgs(enc, | ||||||
|  |                   indices_dim0, | ||||||
|  |                   batch_ptr, | ||||||
|  |                   std::array<uint32_t, 2>{static_cast<uint32_t>(nnz), | ||||||
|  |                                           static_cast<uint32_t>(B)}); | ||||||
|  |       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void build_row_ptr_per_batch_mps( | ||||||
|  |     const Tensor& rows, | ||||||
|  |     const Tensor& batch_ptr, | ||||||
|  |     int64_t B, | ||||||
|  |     int64_t I, | ||||||
|  |     Tensor& row_ptr | ||||||
|  | ) { | ||||||
|  |   // Build per-batch CSR-style row pointer arrays from row indices sorted by batch | ||||||
|  |   // Given: | ||||||
|  |   //   rows: 1-D array of length nnz with row ids in [0, I), sorted within each batch | ||||||
|  |   //   batch_ptr: length B+1, where [batch_ptr[b], batch_ptr[b+1]) is the subrange for batch b | ||||||
|  |   // Produces: | ||||||
|  |   //   - row_ptr: shape [B, I+1] | ||||||
|  |   // | ||||||
|  |   // Example (B = 2, I = 4): | ||||||
|  |   // rows       = [0,   0,   1,  3,  0,   2,    2]   // 7 non-zero elements | ||||||
|  |   //               └─── batch 0 ──┘  └─ batch 1 ─┘ | ||||||
|  |   // batch_ptr  = [0, 4, 7] | ||||||
|  |   //               │  │  └─ end of batch 1 (total nnz) | ||||||
|  |   //               │  └──── end of batch 0/start of batch 1 | ||||||
|  |   //               └─────── start of batch 0 | ||||||
|  |   // | ||||||
|  |   // per-batch row pointers (I+1 entries each): | ||||||
|  |   //   row_ptr[0] = [0, 2, 3, 3, 4] | ||||||
|  |   //   row_ptr[1] = [0, 1, 1, 3, 3] | ||||||
|  |   // laid out in memory: [0, 2, 3, 3, 4,  0, 1, 1, 3, 3] | ||||||
|  |   TORCH_CHECK(rows.is_mps() && batch_ptr.is_mps() && row_ptr.is_mps(), "MPS device expected"); | ||||||
|  |   auto stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|  |   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||||
|  |     @autoreleasepool { | ||||||
|  |       auto pso = lib.getPipelineStateForFunc("build_row_ptr_from_sorted_rows_by_batch"); | ||||||
|  |       auto enc = stream->commandEncoder(); | ||||||
|  |       [enc setComputePipelineState:pso]; | ||||||
|  |  | ||||||
|  |       const uint32_t tew = pso.threadExecutionWidth; | ||||||
|  |       const uint32_t Qx = static_cast<uint32_t>(I + 1); | ||||||
|  |       const uint32_t Qy = static_cast<uint32_t>(B); | ||||||
|  |       const uint32_t tgW = std::min<uint32_t>(Qx, tew); | ||||||
|  |  | ||||||
|  |       MTLSize grid = MTLSizeMake(Qx, Qy, 1); | ||||||
|  |       MTLSize tgs = MTLSizeMake(tgW, 1, 1); | ||||||
|  |  | ||||||
|  |       mtl_setArgs(enc, | ||||||
|  |                   rows, | ||||||
|  |                   batch_ptr, | ||||||
|  |                   row_ptr, | ||||||
|  |                   std::array<uint32_t, 2>{static_cast<uint32_t>(I), | ||||||
|  |                                            static_cast<uint32_t>(B)}); | ||||||
|  |       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Tensor& bmm_out_sparse_mps(const SparseTensor& self_, const Tensor& mat2_, Tensor& result_) { | ||||||
|  |   TORCH_CHECK(result_.is_mps(), "bmm_sparse: expected 'out' to be MPS, got ", result_.device()); | ||||||
|  |   TORCH_CHECK(self_.is_mps(),  "bmm_sparse: expected 'self' to be MPS, got ", self_.device()); | ||||||
|  |   TORCH_CHECK(mat2_.is_mps(),  "bmm_sparse: expected 'mat2' to be MPS, got ", mat2_.device()); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(self_.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self_.dense_dim()); | ||||||
|  |   TORCH_CHECK(self_.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self_.sparse_dim()); | ||||||
|  |   TORCH_CHECK(mat2_.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2_.dim()); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(self_.size(0) == mat2_.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match"); | ||||||
|  |   TORCH_CHECK(self_.size(2) == mat2_.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match"); | ||||||
|  |  | ||||||
|  |   const int64_t B = self_.size(0); | ||||||
|  |   const int64_t I = self_.size(1); | ||||||
|  |   const int64_t J = self_.size(2); | ||||||
|  |   const int64_t K = mat2_.size(2); | ||||||
|  |  | ||||||
|  |   auto self = self_.coalesce(); | ||||||
|  |   const int64_t nnz = self._nnz(); | ||||||
|  |   if (nnz == 0) { | ||||||
|  |     return result_.zero_(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   const auto computeDtype = at::kFloat; | ||||||
|  |  | ||||||
|  |   auto indices = self._indices(); | ||||||
|  |   auto values  = self._values(); | ||||||
|  |  | ||||||
|  |   auto values_c = values.scalar_type() == computeDtype ? values : values.to(computeDtype); | ||||||
|  |   auto mat2_c = mat2_.scalar_type()   == computeDtype ? mat2_   : mat2_.to(computeDtype); | ||||||
|  |   auto mat2_contig = mat2_c.contiguous(); | ||||||
|  |  | ||||||
|  |   auto idx_b = indices.select(0, 0).contiguous(); | ||||||
|  |   auto idx_i = indices.select(0, 1).contiguous(); | ||||||
|  |   auto idx_j = indices.select(0, 2).contiguous(); | ||||||
|  |  | ||||||
|  |   // builds an array of pointers of where the batch_idx's pointer starts and ends | ||||||
|  |   // look in function for better explanation | ||||||
|  |   auto batch_ptr = at::empty({B + 1}, at::device(result_.device()).dtype(kLong)); | ||||||
|  |   build_batch_ptr_mps(idx_b, B, batch_ptr); | ||||||
|  |   // build row_ptr per batch: for each (b, i) get [start, end) into rows/cols/vals | ||||||
|  |   auto row_ptr = at::empty({B * (I + 1)}, at::device(result_.device()).dtype(kLong)); | ||||||
|  |   build_row_ptr_per_batch_mps(idx_i, batch_ptr, B, I, row_ptr); | ||||||
|  |  | ||||||
|  |   const bool out_needs_cast = (result_.scalar_type() != computeDtype) || !result_.is_contiguous(); | ||||||
|  |   Tensor out_buf = out_needs_cast | ||||||
|  |       ? at::empty({B, I, K}, result_.options().dtype(computeDtype)) | ||||||
|  |       : result_; | ||||||
|  |   auto out_contig = out_buf.contiguous(); | ||||||
|  |  | ||||||
|  |   auto stream = getCurrentMPSStream(); | ||||||
|  |   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||||
|  |     @autoreleasepool { | ||||||
|  |       auto pso = lib.getPipelineStateForFunc("spmm_bmm_coo_rows_grouped_" + mps::scalarToMetalTypeString(values)); | ||||||
|  |       auto enc = stream->commandEncoder(); | ||||||
|  |       [enc setComputePipelineState:pso]; | ||||||
|  |  | ||||||
|  |       const uint32_t tew = pso.threadExecutionWidth; | ||||||
|  |       const uint32_t tgW = std::min<uint32_t>((uint32_t)K, tew); | ||||||
|  |  | ||||||
|  |       // One threadgroup per (row i, batch b), lanes cover K | ||||||
|  |       MTLSize grid = MTLSizeMake(tgW, (uint32_t)I, (uint32_t)B); | ||||||
|  |       MTLSize tgs  = MTLSizeMake(tgW, 1, 1); | ||||||
|  |  | ||||||
|  |       mtl_setArgs(enc, | ||||||
|  |                   idx_i, | ||||||
|  |                   idx_j, | ||||||
|  |                   values_c, | ||||||
|  |                   mat2_contig, | ||||||
|  |                   out_contig, | ||||||
|  |                   row_ptr, | ||||||
|  |                   std::array<uint32_t, 4>{(uint32_t)B, (uint32_t)I, (uint32_t)J, (uint32_t)K}); | ||||||
|  |       [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |   if (out_needs_cast) { | ||||||
|  |     result_.copy_(out_contig.to(result_.scalar_type())); | ||||||
|  |   } | ||||||
|  |   return result_; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Tensor bmm_sparse_mps(const Tensor& self, const Tensor& mat2) { | ||||||
|  |   Tensor result = at::zeros({self.size(0), self.size(1), mat2.size(2)}, mat2.options()); | ||||||
|  |   return bmm_out_sparse_mps(self, mat2, result); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Tensor& addmm_out_sparse_dense_mps( | ||||||
|  |     const Tensor& self, | ||||||
|  |     const SparseTensor& mat1, | ||||||
|  |     const Tensor& mat2, | ||||||
|  |     const Scalar& beta, | ||||||
|  |     const Scalar& alpha, | ||||||
|  |     Tensor& result) { | ||||||
|  |   c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); | ||||||
|  |   return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Tensor addmm_sparse_dense_mps( | ||||||
|  |     const Tensor& self, | ||||||
|  |     const SparseTensor& mat1, | ||||||
|  |     const Tensor& mat2, | ||||||
|  |     const Scalar& beta, | ||||||
|  |     const Scalar& alpha | ||||||
|  | ) { | ||||||
|  |   c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); | ||||||
|  |   Tensor result = at::empty({0}, self.options()); | ||||||
|  |   return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha); | ||||||
|  | } | ||||||
|  |  | ||||||
| static SparseTensor& mul_out_dense_sparse_mps( | static SparseTensor& mul_out_dense_sparse_mps( | ||||||
|     const Tensor& dense, |     const Tensor& dense, | ||||||
|     const Tensor& sparse, |     const Tensor& sparse, | ||||||
| @ -436,4 +742,137 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self, | |||||||
|   return out; |   return out; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | using OptTensor = std::optional<Tensor>; | ||||||
|  |  | ||||||
|  |  | ||||||
|  | static void sparse_mask_apply_out_mps_kernel( | ||||||
|  |     Tensor& result, | ||||||
|  |     const Tensor& src_in, | ||||||
|  |     const Tensor& mask_in, | ||||||
|  |     bool accumulate_matches, | ||||||
|  |     bool require_same_sizes, | ||||||
|  |     bool coalesce_mask) { | ||||||
|  |   TORCH_CHECK(src_in.is_sparse() && mask_in.is_sparse(), | ||||||
|  |               "sparse_mask: expected both inputs to be sparse COO"); | ||||||
|  |   TORCH_CHECK(src_in.is_mps() && mask_in.is_mps(), | ||||||
|  |               "sparse_mask: expected tensors to be on MPS device"); | ||||||
|  |   TORCH_CHECK(src_in.sparse_dim() == mask_in.sparse_dim(), | ||||||
|  |               "sparse_mask: sparse_dim mismatch: ", src_in.sparse_dim(), " vs ", mask_in.sparse_dim()); | ||||||
|  |   if (require_same_sizes) { | ||||||
|  |     TORCH_CHECK(src_in.sizes().equals(mask_in.sizes()), | ||||||
|  |                 "sparse_mask: sizes must match exactly (no broadcasting)"); | ||||||
|  |   } | ||||||
|  |   auto src  = src_in.coalesce(); | ||||||
|  |   auto mask = coalesce_mask ? mask_in.coalesce() : mask_in; | ||||||
|  |  | ||||||
|  |   const int64_t src_nnz = src._nnz(); | ||||||
|  |   const int64_t mask_nnz = mask._nnz(); | ||||||
|  |   const int64_t sd = src.sparse_dim(); | ||||||
|  |   result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim()); | ||||||
|  |  | ||||||
|  |   auto commonDtype = at::result_type(src, mask); | ||||||
|  |   TORCH_CHECK(canCast(commonDtype, result.scalar_type()), | ||||||
|  |               "Can't convert result type ", commonDtype, " to output ", result.scalar_type()); | ||||||
|  |  | ||||||
|  |   if (mask_nnz == 0) { | ||||||
|  |     alias_into_sparse( | ||||||
|  |         result, | ||||||
|  |         mask._indices().narrow(1, 0, 0), | ||||||
|  |         at::empty({0}, result.options().dtype(result.scalar_type()))); | ||||||
|  |     result._coalesced_(mask.is_coalesced()); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(sd > 0 || (src_nnz <= 1 && mask_nnz <= 1), | ||||||
|  |               "sparse_mask: invalid sparse_dim or nnz"); | ||||||
|  |  | ||||||
|  |   if (sd == 0) { | ||||||
|  |     auto out_indices = mask._indices().narrow(1, 0, 1); | ||||||
|  |     auto out_values = src_nnz | ||||||
|  |       ? src._values().narrow(0, 0, 1).to(commonDtype) | ||||||
|  |       : at::zeros({1}, at::device(result.device()).dtype(commonDtype)); | ||||||
|  |     alias_into_sparse(result, out_indices, out_values); | ||||||
|  |     result._coalesced_(mask.is_coalesced()); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (src_nnz == 0) { | ||||||
|  |     auto out_indices = mask._indices().contiguous(); | ||||||
|  |     auto src_values  = src._values().to(commonDtype); | ||||||
|  |     auto out_val_sizes = src_values.sizes().vec(); | ||||||
|  |     out_val_sizes[0] = mask_nnz; | ||||||
|  |     auto out_values = at::zeros(out_val_sizes, src_values.options()); | ||||||
|  |     alias_into_sparse(result, out_indices, out_values); | ||||||
|  |     result._coalesced_(mask.is_coalesced()); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   auto mask_indices = mask._indices().contiguous(); | ||||||
|  |   auto src_indices = src._indices().contiguous(); | ||||||
|  |   auto src_values = src._values().to(commonDtype).contiguous(); | ||||||
|  |  | ||||||
|  |   auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous(); | ||||||
|  |   auto src_keys  = flatten_indices(src_indices,  src.sizes().slice(0, sd)).contiguous(); | ||||||
|  |  | ||||||
|  |   const bool A_is_src = (src_nnz <= mask_nnz); | ||||||
|  |   const int64_t lenA = A_is_src ? src_nnz  : mask_nnz; | ||||||
|  |   const int64_t lenB = A_is_src ? mask_nnz : src_nnz; | ||||||
|  |   auto A_keys = A_is_src ? src_keys  : mask_keys; | ||||||
|  |   auto B_keys = A_is_src ? mask_keys : src_keys; | ||||||
|  |  | ||||||
|  |   const auto device = result.device(); | ||||||
|  |   auto stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|  |   auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong)); | ||||||
|  |   auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong)); | ||||||
|  |   auto counter = at::zeros({1}, at::device(device).dtype(at::kInt)); | ||||||
|  |  | ||||||
|  |   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||||
|  |     @autoreleasepool { | ||||||
|  |       auto pso = lib.getPipelineStateForFunc("intersect_binary_search"); | ||||||
|  |       auto enc = stream->commandEncoder(); | ||||||
|  |       [enc setComputePipelineState:pso]; | ||||||
|  |       mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter, | ||||||
|  |                   static_cast<uint32_t>(lenB), A_is_src); | ||||||
|  |       mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA)); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   const int64_t M = static_cast<int64_t>(counter.item<int32_t>()); | ||||||
|  |  | ||||||
|  |   auto out_val_sizes = src_values.sizes().vec(); | ||||||
|  |   out_val_sizes[0] = mask_nnz; | ||||||
|  |   auto out_values = at::zeros(out_val_sizes, src_values.options()); | ||||||
|  |  | ||||||
|  |   if (M > 0) { | ||||||
|  |     auto src_match = outA_idx.narrow(0, 0, M); | ||||||
|  |     auto mask_match = outB_idx.narrow(0, 0, M); | ||||||
|  |  | ||||||
|  |     auto src_rows = src_values.index_select(0, src_match); | ||||||
|  |     if (accumulate_matches) { | ||||||
|  |       out_values.index_add_(0, mask_match, src_rows); | ||||||
|  |     } else { | ||||||
|  |       out_values.index_copy_(0, mask_match, src_rows); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   alias_into_sparse(result, mask_indices, out_values); | ||||||
|  |   result._coalesced_(mask.is_coalesced()); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void sparse_mask_intersection_out_mps_kernel( | ||||||
|  |     Tensor& result, | ||||||
|  |     const Tensor& lhs, | ||||||
|  |     const Tensor& rhs, | ||||||
|  |     const OptTensor& = std::nullopt) { | ||||||
|  |   sparse_mask_apply_out_mps_kernel( | ||||||
|  |       result, | ||||||
|  |       /*src_in=*/lhs, | ||||||
|  |       /*mask_in=*/rhs, | ||||||
|  |       /*accumulate_matches=*/false, | ||||||
|  |       /*require_same_sizes=*/false, | ||||||
|  |       /*coalesce_mask=*/false); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel); | ||||||
| } // namespace at::native | } // namespace at::native | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	