mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Compare commits
	
		
			142 Commits
		
	
	
		
			annotate_f
			...
			gh/laithsa
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| b994e3dfe2 | |||
| a4ca86145d | |||
| a0be597ea4 | |||
| 13593796cc | |||
| 21bb1455e5 | |||
| 2e7752e6ce | |||
| 4b87543b65 | |||
| bb526733ef | |||
| 19050dd509 | |||
| 16686fda12 | |||
| 1d79ebb3da | |||
| a788092fb8 | |||
| a8cfab1e46 | |||
| 3161e2b61f | |||
| 15530d81d3 | |||
| f3c4db3804 | |||
| cc47c109a1 | |||
| 476de516f3 | |||
| 3b3eece835 | |||
| a9776fa0cc | |||
| b225bcaa97 | |||
| aaac8cb0f5 | |||
| 0f0b4bf029 | |||
| b8194268a6 | |||
| f02e3947f6 | |||
| 9095a9dfae | |||
| d9f94e0d7d | |||
| 23417ae50f | |||
| e4d6c56ffb | |||
| 017d2985f3 | |||
| c6a8db0b9a | |||
| de09bab4b6 | |||
| c137e222d4 | |||
| cf3a787bbc | |||
| de3da77cf7 | |||
| 543ddbf44c | |||
| e9f4999985 | |||
| 29b029648e | |||
| a25a649e70 | |||
| 69c33898fa | |||
| 1b397420f2 | |||
| fe80f03726 | |||
| e50dc40d28 | |||
| 2e22b1a61e | |||
| 616c6bdf8f | |||
| c18ddfc572 | |||
| 86ebce1766 | |||
| 8cb2fb44f2 | |||
| ab65498d71 | |||
| 06d324365c | |||
| 6c9c6e0936 | |||
| 2bcd892c86 | |||
| 75e2a9fae3 | |||
| a16fd6b488 | |||
| 382b0150de | |||
| a664b299ac | |||
| 9c12651417 | |||
| 08c97b4a1f | |||
| fae74cd52f | |||
| 7a65770013 | |||
| e4454947e2 | |||
| 3806e9767b | |||
| b08d8c2e50 | |||
| ca5b7f8ded | |||
| 9a71d96256 | |||
| 0d4c2b71e8 | |||
| d659bbde62 | |||
| 58879bfafa | |||
| a032510db3 | |||
| 39e0a832c9 | |||
| dd3b48e85d | |||
| cff1b20771 | |||
| da8517fa63 | |||
| 45afaf08a1 | |||
| 080365b7d8 | |||
| 2928c5c572 | |||
| 630520b346 | |||
| 1dc9a05d03 | |||
| bfcdbd0a97 | |||
| faff826a46 | |||
| 85c5433d38 | |||
| 935ccdbe75 | |||
| 3af2f0c12a | |||
| 6ece527fc5 | |||
| ce29d0d796 | |||
| 7231118db3 | |||
| 5d4da26ed0 | |||
| 574c9fc950 | |||
| 80d2ca7566 | |||
| 4a22139eea | |||
| cb6e4d7d82 | |||
| 202f83dc4e | |||
| 9fe3b2afbe | |||
| d0c24b392c | |||
| b44fb14906 | |||
| 51348c0219 | |||
| fdd560afd1 | |||
| e925dfcc6b | |||
| f1d882212a | |||
| 24879f0de9 | |||
| 9e94ec76b8 | |||
| 364624e209 | |||
| 7e150467f7 | |||
| 43d78423ac | |||
| fcbde24c1c | |||
| 861cdb887b | |||
| 3154482072 | |||
| 9fccbdd4f0 | |||
| 7dabfb07cb | |||
| d0add0be43 | |||
| 11e2084308 | |||
| 9726553653 | |||
| d82527b32a | |||
| 5d9b024276 | |||
| 5b2afe4c5d | |||
| b2953f5643 | |||
| 470e2f61c3 | |||
| e0fe37fa68 | |||
| d2c82bafb7 | |||
| 98a488c9aa | |||
| 5b3ea75895 | |||
| 556fc09a9f | |||
| ce109b3f79 | |||
| 4d833f859b | |||
| d7e275d4b4 | |||
| d5db3aee0d | |||
| 5641de7b6b | |||
| cbc08c8993 | |||
| 1a54d3333d | |||
| 4c1c341fa0 | |||
| 5f21cc786a | |||
| e86942f422 | |||
| 2cd5fd1588 | |||
| 7d0f872cb3 | |||
| fb06e49ce8 | |||
| 27a98e6ae9 | |||
| b10f463b1a | |||
| 431c13cf61 | |||
| aead9270f5 | |||
| 9bf5b38c14 | |||
| aba8c43594 | |||
| 37f3ba274a | 
@ -113,6 +113,7 @@ case "$tag" in
 | 
			
		||||
    UCX_COMMIT=${_UCX_COMMIT}
 | 
			
		||||
    UCC_COMMIT=${_UCC_COMMIT}
 | 
			
		||||
    TRITON=yes
 | 
			
		||||
    INSTALL_MINGW=yes
 | 
			
		||||
    ;;
 | 
			
		||||
  pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
 | 
			
		||||
    CUDA_VERSION=13.0.0
 | 
			
		||||
@ -361,6 +362,7 @@ docker build \
 | 
			
		||||
       --build-arg "OPENBLAS=${OPENBLAS:-}" \
 | 
			
		||||
       --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
 | 
			
		||||
       --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
 | 
			
		||||
       --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \
 | 
			
		||||
       -f $(dirname ${DOCKERFILE})/Dockerfile \
 | 
			
		||||
       -t "$tmp_tag" \
 | 
			
		||||
       "$@" \
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										10
									
								
								.ci/docker/common/install_mingw.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								.ci/docker/common/install_mingw.sh
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
#!/bin/bash
 | 
			
		||||
 | 
			
		||||
set -ex
 | 
			
		||||
 | 
			
		||||
# Install MinGW-w64 for Windows cross-compilation
 | 
			
		||||
apt-get update
 | 
			
		||||
apt-get install -y g++-mingw-w64-x86-64-posix
 | 
			
		||||
 | 
			
		||||
echo "MinGW-w64 installed successfully"
 | 
			
		||||
x86_64-w64-mingw32-g++ --version
 | 
			
		||||
@ -20,7 +20,7 @@ pip_install \
 | 
			
		||||
 | 
			
		||||
pip_install coloredlogs packaging
 | 
			
		||||
pip_install onnxruntime==1.23.0
 | 
			
		||||
pip_install onnxscript==0.5.3
 | 
			
		||||
pip_install onnxscript==0.5.4
 | 
			
		||||
 | 
			
		||||
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 | 
			
		||||
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
 | 
			
		||||
 | 
			
		||||
@ -39,9 +39,13 @@ case ${DOCKER_TAG_PREFIX} in
 | 
			
		||||
        DOCKER_GPU_BUILD_ARG=""
 | 
			
		||||
        ;;
 | 
			
		||||
    rocm*)
 | 
			
		||||
        # we want the patch version of 7.0 instead
 | 
			
		||||
        if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
 | 
			
		||||
            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
 | 
			
		||||
        fi
 | 
			
		||||
        # we want the patch version of 6.4 instead
 | 
			
		||||
        if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
 | 
			
		||||
            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
 | 
			
		||||
            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
 | 
			
		||||
        fi
 | 
			
		||||
        BASE_TARGET=rocm
 | 
			
		||||
        GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
 | 
			
		||||
 | 
			
		||||
@ -75,9 +75,13 @@ case ${image} in
 | 
			
		||||
        DOCKERFILE_SUFFIX="_cuda_aarch64"
 | 
			
		||||
        ;;
 | 
			
		||||
    manylinux2_28-builder:rocm*)
 | 
			
		||||
        # we want the patch version of 7.0 instead
 | 
			
		||||
        if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
 | 
			
		||||
            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
 | 
			
		||||
        fi
 | 
			
		||||
        # we want the patch version of 6.4 instead
 | 
			
		||||
        if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
 | 
			
		||||
            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
 | 
			
		||||
            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
 | 
			
		||||
        fi
 | 
			
		||||
        TARGET=rocm_final
 | 
			
		||||
        MANY_LINUX_VERSION="2_28"
 | 
			
		||||
 | 
			
		||||
@ -103,6 +103,11 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt
 | 
			
		||||
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
 | 
			
		||||
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
 | 
			
		||||
 | 
			
		||||
ARG INSTALL_MINGW
 | 
			
		||||
COPY ./common/install_mingw.sh install_mingw.sh
 | 
			
		||||
RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi
 | 
			
		||||
RUN rm install_mingw.sh
 | 
			
		||||
 | 
			
		||||
ARG TRITON
 | 
			
		||||
ARG TRITON_CPU
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -485,6 +485,22 @@ test_inductor_aoti() {
 | 
			
		||||
  /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
test_inductor_aoti_cross_compile_for_windows() {
 | 
			
		||||
 | 
			
		||||
  TEST_REPORTS_DIR=$(pwd)/test/test-reports
 | 
			
		||||
  mkdir -p "$TEST_REPORTS_DIR"
 | 
			
		||||
 | 
			
		||||
  # Set WINDOWS_CUDA_HOME environment variable
 | 
			
		||||
  WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted"
 | 
			
		||||
  export WINDOWS_CUDA_HOME
 | 
			
		||||
 | 
			
		||||
  echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"
 | 
			
		||||
  echo "Contents:"
 | 
			
		||||
  ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true
 | 
			
		||||
 | 
			
		||||
  python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
test_inductor_cpp_wrapper_shard() {
 | 
			
		||||
  if [[ -z "$NUM_TEST_SHARDS" ]]; then
 | 
			
		||||
    echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
 | 
			
		||||
@ -900,7 +916,7 @@ test_inductor_set_cpu_affinity(){
 | 
			
		||||
  export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
 | 
			
		||||
  export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
 | 
			
		||||
 | 
			
		||||
  if [[ "${TEST_CONFIG}" != *aarch64* ]]; then
 | 
			
		||||
  if [[ "$(uname -m)" != "aarch64" ]]; then
 | 
			
		||||
    # Use Intel OpenMP for x86
 | 
			
		||||
    IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so"
 | 
			
		||||
    export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD"
 | 
			
		||||
@ -914,7 +930,7 @@ test_inductor_set_cpu_affinity(){
 | 
			
		||||
  cores=$((cpus / thread_per_core))
 | 
			
		||||
 | 
			
		||||
  # Set number of cores to 16 on aarch64 for performance runs
 | 
			
		||||
  if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
 | 
			
		||||
  if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then
 | 
			
		||||
    cores=16
 | 
			
		||||
  fi
 | 
			
		||||
  export OMP_NUM_THREADS=$cores
 | 
			
		||||
@ -1667,7 +1683,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then
 | 
			
		||||
    python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0
 | 
			
		||||
  fi
 | 
			
		||||
  python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py
 | 
			
		||||
elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then
 | 
			
		||||
elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then
 | 
			
		||||
  test_linux_aarch64
 | 
			
		||||
elif [[ "${TEST_CONFIG}" == *backward* ]]; then
 | 
			
		||||
  test_forward_backward_compatibility
 | 
			
		||||
@ -1718,6 +1734,8 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
 | 
			
		||||
  test_inductor_triton_cpu
 | 
			
		||||
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
 | 
			
		||||
  test_inductor_micro_benchmark
 | 
			
		||||
elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
 | 
			
		||||
  test_inductor_aoti_cross_compile_for_windows
 | 
			
		||||
elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
 | 
			
		||||
  install_torchvision
 | 
			
		||||
  id=$((SHARD_NUMBER-1))
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.flake8
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								.flake8
									
									
									
									
									
								
							@ -13,8 +13,6 @@ ignore =
 | 
			
		||||
    EXE001,
 | 
			
		||||
    # these ignores are from flake8-bugbear; please fix!
 | 
			
		||||
    B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
 | 
			
		||||
    # these ignores are from flake8-comprehensions; please fix!
 | 
			
		||||
    C407,
 | 
			
		||||
    # these ignores are from flake8-logging-format; please fix!
 | 
			
		||||
    G100,G101,G200
 | 
			
		||||
    # these ignores are from flake8-simplify. please fix or ignore with commented reason
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							@ -133,3 +133,32 @@
 | 
			
		||||
 | 
			
		||||
"ciflow/vllm":
 | 
			
		||||
- .github/ci_commit_pins/vllm.txt
 | 
			
		||||
 | 
			
		||||
"ciflow/b200":
 | 
			
		||||
- test/test_matmul_cuda.py
 | 
			
		||||
- test/test_scaled_matmul_cuda.py
 | 
			
		||||
- test/inductor/test_fp8.py
 | 
			
		||||
- aten/src/ATen/native/cuda/Blas.cpp
 | 
			
		||||
- torch/**/*cublas*
 | 
			
		||||
- torch/_inductor/kernel/mm.py
 | 
			
		||||
- test/inductor/test_max_autotune.py
 | 
			
		||||
- third_party/fbgemm
 | 
			
		||||
 | 
			
		||||
"ciflow/h100":
 | 
			
		||||
- test/test_matmul_cuda.py
 | 
			
		||||
- test/test_scaled_matmul_cuda.py
 | 
			
		||||
- test/inductor/test_fp8.py
 | 
			
		||||
- aten/src/ATen/native/cuda/Blas.cpp
 | 
			
		||||
- torch/**/*cublas*
 | 
			
		||||
- torch/_inductor/kernel/mm.py
 | 
			
		||||
- test/inductor/test_max_autotune.py
 | 
			
		||||
- third_party/fbgemm
 | 
			
		||||
 | 
			
		||||
"ciflow/rocm":
 | 
			
		||||
- test/test_matmul_cuda.py
 | 
			
		||||
- test/test_scaled_matmul_cuda.py
 | 
			
		||||
- test/inductor/test_fp8.py
 | 
			
		||||
- aten/src/ATen/native/cuda/Blas.cpp
 | 
			
		||||
- torch/_inductor/kernel/mm.py
 | 
			
		||||
- test/inductor/test_max_autotune.py
 | 
			
		||||
- third_party/fbgemm
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							@ -3,6 +3,7 @@ ciflow_tracking_issue: 64124
 | 
			
		||||
ciflow_push_tags:
 | 
			
		||||
- ciflow/b200
 | 
			
		||||
- ciflow/b200-symm-mem
 | 
			
		||||
- ciflow/b200-distributed
 | 
			
		||||
- ciflow/binaries
 | 
			
		||||
- ciflow/binaries_libtorch
 | 
			
		||||
- ciflow/binaries_wheel
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										30
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										30
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							@ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
 | 
			
		||||
        "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'"
 | 
			
		||||
    ),
 | 
			
		||||
    "12.9": (
 | 
			
		||||
        "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
 | 
			
		||||
        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'"
 | 
			
		||||
        "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | "
 | 
			
		||||
        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'"
 | 
			
		||||
    ),
 | 
			
		||||
    "13.0": (
 | 
			
		||||
        "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/scripts/trymerge.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/scripts/trymerge.py
									
									
									
									
										vendored
									
									
								
							@ -1092,7 +1092,7 @@ class GitHubPR:
 | 
			
		||||
        editor = node["editor"]
 | 
			
		||||
        return GitHubComment(
 | 
			
		||||
            body_text=node["bodyText"],
 | 
			
		||||
            created_at=node["createdAt"] if "createdAt" in node else "",
 | 
			
		||||
            created_at=node.get("createdAt", ""),
 | 
			
		||||
            author_login=node["author"]["login"],
 | 
			
		||||
            author_url=node["author"].get("url", None),
 | 
			
		||||
            author_association=node["authorAssociation"],
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										40
									
								
								.github/workflows/_linux-test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										40
									
								
								.github/workflows/_linux-test.yml
									
									
									
									
										vendored
									
									
								
							@ -224,6 +224,46 @@ jobs:
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        uses: ./.github/actions/download-td-artifacts
 | 
			
		||||
 | 
			
		||||
      - name: Download Windows torch wheel for cross-compilation
 | 
			
		||||
        if: matrix.win_torch_wheel_artifact != ''
 | 
			
		||||
        uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
 | 
			
		||||
        with:
 | 
			
		||||
          name: ${{ matrix.win_torch_wheel_artifact }}
 | 
			
		||||
          path: win-torch-wheel
 | 
			
		||||
 | 
			
		||||
      - name: Extract Windows wheel and setup CUDA libraries
 | 
			
		||||
        if: matrix.win_torch_wheel_artifact != ''
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          set -x
 | 
			
		||||
 | 
			
		||||
          # Find the wheel file
 | 
			
		||||
          WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
 | 
			
		||||
          if [ -z "$WHEEL_FILE" ]; then
 | 
			
		||||
            echo "Error: No wheel file found in win-torch-wheel directory"
 | 
			
		||||
            exit 1
 | 
			
		||||
          fi
 | 
			
		||||
          echo "Found wheel file: $WHEEL_FILE"
 | 
			
		||||
 | 
			
		||||
          # Unzip the wheel file
 | 
			
		||||
          unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
 | 
			
		||||
          echo "Extracted wheel contents"
 | 
			
		||||
 | 
			
		||||
          # Setup CUDA libraries (cuda.lib and cudart.lib) directory
 | 
			
		||||
          mkdir -p win-torch-wheel-extracted/lib/x64
 | 
			
		||||
          if [ -f "win-torch-wheel/cuda.lib" ]; then
 | 
			
		||||
            mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/
 | 
			
		||||
            echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/"
 | 
			
		||||
          fi
 | 
			
		||||
          if [ -f "win-torch-wheel/cudart.lib" ]; then
 | 
			
		||||
            mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/
 | 
			
		||||
            echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          # Verify CUDA libraries are present
 | 
			
		||||
          echo "CUDA libraries:"
 | 
			
		||||
          ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found"
 | 
			
		||||
 | 
			
		||||
      - name: Parse ref
 | 
			
		||||
        id: parse-ref
 | 
			
		||||
        run: .github/scripts/parse_ref.py
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										25
									
								
								.github/workflows/_win-build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										25
									
								
								.github/workflows/_win-build.yml
									
									
									
									
										vendored
									
									
								
							@ -168,6 +168,31 @@ jobs:
 | 
			
		||||
        run: |
 | 
			
		||||
          .ci/pytorch/win-build.sh
 | 
			
		||||
 | 
			
		||||
      # Collect Windows torch libs and CUDA libs for cross-compilation
 | 
			
		||||
      - name: Collect Windows CUDA libs for cross-compilation
 | 
			
		||||
        if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu'
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          set -ex
 | 
			
		||||
 | 
			
		||||
          # Create directory structure if does not exist
 | 
			
		||||
          mkdir -p /c/${{ github.run_id }}/build-results
 | 
			
		||||
 | 
			
		||||
          # Copy CUDA libs
 | 
			
		||||
          CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
 | 
			
		||||
 | 
			
		||||
          if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
 | 
			
		||||
            cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
 | 
			
		||||
            cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          # List collected files
 | 
			
		||||
          echo "Collected CUDA libs:"
 | 
			
		||||
          ls -lah /c/${{ github.run_id }}/build-results/*.lib
 | 
			
		||||
 | 
			
		||||
      # Upload to github so that people can click and download artifacts
 | 
			
		||||
      - name: Upload artifacts to s3
 | 
			
		||||
        if: steps.build.outcome != 'skipped'
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										62
									
								
								.github/workflows/b200-distributed.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								.github/workflows/b200-distributed.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,62 @@
 | 
			
		||||
name: CI for distributed tests on B200
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  pull_request:
 | 
			
		||||
    paths:
 | 
			
		||||
      - .github/workflows/b200-distributed.yml
 | 
			
		||||
  workflow_dispatch:
 | 
			
		||||
  push:
 | 
			
		||||
    tags:
 | 
			
		||||
      - ciflow/b200-distributed/*
 | 
			
		||||
  schedule:
 | 
			
		||||
    - cron: 46 8 * * *  # about 1:46am PDT
 | 
			
		||||
 | 
			
		||||
concurrency:
 | 
			
		||||
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
 | 
			
		||||
  cancel-in-progress: true
 | 
			
		||||
 | 
			
		||||
permissions:
 | 
			
		||||
  id-token: write
 | 
			
		||||
  contents: read
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
 | 
			
		||||
  get-label-type:
 | 
			
		||||
    if: github.repository_owner == 'pytorch'
 | 
			
		||||
    name: get-label-type
 | 
			
		||||
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
 | 
			
		||||
    with:
 | 
			
		||||
      triggering_actor: ${{ github.triggering_actor }}
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runner: linux.12xlarge.memory
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
 | 
			
		||||
      cuda-arch-list: '10.0'
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
 | 
			
		||||
          { config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs:
 | 
			
		||||
      - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
 | 
			
		||||
    with:
 | 
			
		||||
      timeout-minutes: 1200
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
 | 
			
		||||
      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -224,7 +224,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_10-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -473,7 +473,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_11-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -722,7 +722,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -971,7 +971,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_13-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1220,7 +1220,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_13t-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1469,7 +1469,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_14-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1718,7 +1718,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_14t-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -259,7 +259,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_10-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_10-cuda12_9-test:  # Testing
 | 
			
		||||
@ -925,7 +925,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_11-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_11-cuda12_9-test:  # Testing
 | 
			
		||||
@ -1591,7 +1591,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_12-cuda12_9-test:  # Testing
 | 
			
		||||
@ -2257,7 +2257,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_13-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_13-cuda12_9-test:  # Testing
 | 
			
		||||
@ -2923,7 +2923,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_13t-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_13t-cuda12_9-test:  # Testing
 | 
			
		||||
@ -3589,7 +3589,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_14-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_14-cuda12_9-test:  # Testing
 | 
			
		||||
@ -4255,7 +4255,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_14t-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_14t-cuda12_9-test:  # Testing
 | 
			
		||||
 | 
			
		||||
@ -88,27 +88,27 @@ jobs:
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										24
									
								
								.github/workflows/operator_benchmark.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										24
									
								
								.github/workflows/operator_benchmark.yml
									
									
									
									
										vendored
									
									
								
							@ -52,3 +52,27 @@ jobs:
 | 
			
		||||
      docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  aarch64-opbenchmark-build:
 | 
			
		||||
    if: github.repository_owner == 'pytorch'
 | 
			
		||||
    name: aarch64-opbenchmark-build
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-aarch64-py3.10
 | 
			
		||||
      runner: linux.arm64.m7g.4xlarge
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  aarch64-opbenchmark-test:
 | 
			
		||||
    name: aarch64-opbenchmark-test
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs: aarch64-opbenchmark-build
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-aarch64-py3.10
 | 
			
		||||
      docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							@ -45,12 +45,12 @@ jobs:
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										17
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										17
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							@ -200,6 +200,23 @@ jobs:
 | 
			
		||||
      cuda-arch-list: '8.0'
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  # Test cross-compiled models with Windows libs extracted from wheel
 | 
			
		||||
  cross-compile-linux-test:
 | 
			
		||||
    name: cross-compile-linux-test
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs:
 | 
			
		||||
      - linux-jammy-cuda12_8-py3_10-gcc11-build
 | 
			
		||||
      - get-label-type
 | 
			
		||||
      - win-vs2022-cuda12_8-py3-build
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  verify-cachebench-cpu-build:
 | 
			
		||||
    name: verify-cachebench-cpu-build
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -374,6 +374,7 @@ third_party/ruy/
 | 
			
		||||
third_party/glog/
 | 
			
		||||
 | 
			
		||||
# Virtualenv
 | 
			
		||||
.venv/
 | 
			
		||||
venv/
 | 
			
		||||
 | 
			
		||||
# Log files
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							@ -201,3 +201,17 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
 | 
			
		||||
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
 | 
			
		||||
/torch/headeronly/ @janeyx99
 | 
			
		||||
/torch/header_only_apis.txt @janeyx99
 | 
			
		||||
 | 
			
		||||
# FlexAttention
 | 
			
		||||
/torch/nn/attention/flex_attention.py @drisspg
 | 
			
		||||
/torch/_higher_order_ops/flex_attention.py @drisspg
 | 
			
		||||
/torch/_inductor/kernel/flex/ @drisspg
 | 
			
		||||
/torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg
 | 
			
		||||
/test/inductor/test_flex_attention.py @drisspg
 | 
			
		||||
/test/inductor/test_flex_decoding.py @drisspg
 | 
			
		||||
 | 
			
		||||
# Low Precision GEMMs
 | 
			
		||||
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
 | 
			
		||||
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
 | 
			
		||||
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
 | 
			
		||||
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
 | 
			
		||||
 | 
			
		||||
@ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI)
 | 
			
		||||
 | 
			
		||||
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
 | 
			
		||||
 | 
			
		||||
    set(fbgemm_genai_mx8mx8bf16_grouped
 | 
			
		||||
    set(fbgemm_genai_cuh
 | 
			
		||||
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
 | 
			
		||||
      "${FBGEMM_GENAI_SRCS}/"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    target_include_directories(fbgemm_genai PRIVATE
 | 
			
		||||
      ${FBGEMM_THIRD_PARTY}/cutlass/include
 | 
			
		||||
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
 | 
			
		||||
      ${fbgemm_genai_mx8mx8bf16_grouped}
 | 
			
		||||
      ${fbgemm_genai_cuh}
 | 
			
		||||
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
 | 
			
		||||
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,6 @@
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
#include <ATen/CachedTensorUtils.h>
 | 
			
		||||
#include <c10/core/GradMode.h>
 | 
			
		||||
#include <c10/util/flat_hash_map.h>
 | 
			
		||||
 | 
			
		||||
namespace at::autocast {
 | 
			
		||||
@ -37,29 +36,10 @@ namespace {
 | 
			
		||||
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 | 
			
		||||
using val_type = std::tuple<weakref_type, Tensor>;
 | 
			
		||||
 | 
			
		||||
// We maintain separate caches for gradient-enabled and gradient-disabled modes.
 | 
			
		||||
// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
 | 
			
		||||
// are not incorrectly reused in gradient-enabled contexts.
 | 
			
		||||
// This fixes issue #158232 while maintaining optimal performance for both modes.
 | 
			
		||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
 | 
			
		||||
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
 | 
			
		||||
  return cached_casts_grad_enabled;
 | 
			
		||||
ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
 | 
			
		||||
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
 | 
			
		||||
  return cached_casts;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
 | 
			
		||||
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
 | 
			
		||||
  return cached_casts_grad_disabled;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helper function to get the appropriate cache based on current gradient mode.
 | 
			
		||||
// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
 | 
			
		||||
// preventing incorrect cache hits when gradient mode changes.
 | 
			
		||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
 | 
			
		||||
  return at::GradMode::is_enabled() ?
 | 
			
		||||
    get_cached_casts_grad_enabled() :
 | 
			
		||||
    get_cached_casts_grad_disabled();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::mutex cached_casts_mutex;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -106,9 +86,7 @@ thread_local bool cache_enabled = true;
 | 
			
		||||
 | 
			
		||||
void clear_cache() {
 | 
			
		||||
  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
 | 
			
		||||
  // Clear both caches to ensure consistent behavior regardless of current gradient mode
 | 
			
		||||
  get_cached_casts_grad_enabled().clear();
 | 
			
		||||
  get_cached_casts_grad_disabled().clear();
 | 
			
		||||
  get_cached_casts().clear();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int increment_nesting() {
 | 
			
		||||
@ -143,11 +121,6 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
 | 
			
		||||
  if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
 | 
			
		||||
    // Heuristic:  Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
 | 
			
		||||
    // See cached_casts declaration above for detailed strategy.
 | 
			
		||||
    //
 | 
			
		||||
    // We maintain separate caches for gradient-enabled and gradient-disabled modes
 | 
			
		||||
    // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
 | 
			
		||||
    // with torch.autocast(), while maintaining optimal performance for both training and inference.
 | 
			
		||||
    // This fixes issue #158232 without any performance regression.
 | 
			
		||||
    bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
 | 
			
		||||
                         arg.scalar_type() == at::kFloat && arg.requires_grad() &&
 | 
			
		||||
                         arg.is_leaf() && !arg.is_view() && cache_enabled &&
 | 
			
		||||
 | 
			
		||||
@ -229,10 +229,10 @@ private:
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  static const uint32_t kPhilox10A = 0x9E3779B9;
 | 
			
		||||
  static const uint32_t kPhilox10B = 0xBB67AE85;
 | 
			
		||||
  static const uint32_t kPhiloxSA = 0xD2511F53;
 | 
			
		||||
  static const uint32_t kPhiloxSB = 0xCD9E8D57;
 | 
			
		||||
  static constexpr uint32_t kPhilox10A = 0x9E3779B9;
 | 
			
		||||
  static constexpr uint32_t kPhilox10B = 0xBB67AE85;
 | 
			
		||||
  static constexpr uint32_t kPhiloxSA = 0xD2511F53;
 | 
			
		||||
  static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
typedef philox_engine Philox4_32;
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										794
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_int_aarch64.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										794
									
								
								aten/src/ATen/cpu/vec/vec128/vec128_int_aarch64.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,794 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/intrinsics.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec_base.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
// Note [CPU_CAPABILITY namespace]
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
// This header, and all of its subheaders, will be compiled with
 | 
			
		||||
// different architecture flags for each supported set of vector
 | 
			
		||||
// intrinsics. So we need to make sure they aren't inadvertently
 | 
			
		||||
// linked together. We do this by declaring objects in an `inline
 | 
			
		||||
// namespace` which changes the name mangling, but can still be
 | 
			
		||||
// accessed as `at::vec`.
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
 | 
			
		||||
// Expands to the Vectorized<int##bit##_t> specialization backed by a single
// 128-bit NEON register, plus the element-wise operators and 0/1-valued
// comparison helpers whose intrinsic names follow the uniform `..._s##bit`
// naming pattern.
//   vl  - lane count of the 128-bit register (2, 4, 8 or 16)
//   bit - element width in bits (64, 32, 16 or 8)
// Members only declared here (scalar ctor, loadu, store, blend, arange, set,
// reduce_max, operator!=) are defined per-width later in this header because
// their implementations differ between widths.
// (Comments inside the macro use /* */ form: a // comment would swallow the
// trailing line-continuation backslash.)
#define VEC_INT_NEON_TEMPLATE(vl, bit)                                        \
  template <>                                                                 \
  struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {};  \
                                                                              \
  template <>                                                                 \
  class Vectorized<int##bit##_t> {                                            \
    using neon_type = int##bit##x##vl##_t;                                    \
                                                                              \
   private:                                                                   \
    neon_type values;                                                         \
                                                                              \
   public:                                                                    \
    using value_type = int##bit##_t;                                          \
    using size_type = int;                                                    \
    static constexpr size_type size() {                                       \
      return vl;                                                              \
    }                                                                         \
    /* Default-construct as all zeros. */                                     \
    Vectorized() {                                                            \
      values = vdupq_n_s##bit(0);                                             \
    }                                                                         \
    Vectorized(neon_type v) : values(v) {}                                    \
    Vectorized(int##bit##_t val);                                             \
    /* Per-lane constructor; requires exactly size() arguments. */            \
    template <                                                                \
        typename... Args,                                                     \
        typename = std::enable_if_t<(sizeof...(Args) == size())>>             \
    Vectorized(Args... vals) {                                                \
      __at_align__ int##bit##_t buffer[size()] = {vals...};                   \
      values = vld1q_s##bit(buffer);                                          \
    }                                                                         \
    operator neon_type() const {                                              \
      return values;                                                          \
    }                                                                         \
    static Vectorized<int##bit##_t> loadu(                                    \
        const void* ptr,                                                      \
        int64_t count = size());                                              \
    void store(void* ptr, int64_t count = size()) const;                      \
    template <int64_t mask>                                                   \
    static Vectorized<int##bit##_t> blend(                                    \
        const Vectorized<int##bit##_t>& a,                                    \
        const Vectorized<int##bit##_t>& b);                                   \
    /* Lanewise select: take b where the mask_ lane is all-ones, else a. */   \
    static Vectorized<int##bit##_t> blendv(                                   \
        const Vectorized<int##bit##_t>& a,                                    \
        const Vectorized<int##bit##_t>& b,                                    \
        const Vectorized<int##bit##_t>& mask_) {                              \
      return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \
    }                                                                         \
    template <typename step_t>                                                \
    static Vectorized<int##bit##_t> arange(                                   \
        value_type base = 0,                                                  \
        step_t step = static_cast<step_t>(1));                                \
    static Vectorized<int##bit##_t> set(                                      \
        const Vectorized<int##bit##_t>& a,                                    \
        const Vectorized<int##bit##_t>& b,                                    \
        int64_t count = size());                                              \
    /* Direct lane access is deliberately unsupported; use store(). */        \
    const int##bit##_t& operator[](int idx) const = delete;                   \
    int##bit##_t& operator[](int idx) = delete;                               \
    Vectorized<int##bit##_t> abs() const {                                    \
      return vabsq_s##bit(values);                                            \
    }                                                                         \
    /* real/imag/conj treat the integer vector as a purely real value. */     \
    Vectorized<int##bit##_t> real() const {                                   \
      return values;                                                          \
    }                                                                         \
    Vectorized<int##bit##_t> imag() const {                                   \
      return vdupq_n_s##bit(0);                                               \
    }                                                                         \
    Vectorized<int##bit##_t> conj() const {                                   \
      return values;                                                          \
    }                                                                         \
    Vectorized<int##bit##_t> neg() const {                                    \
      return vnegq_s##bit(values);                                            \
    }                                                                         \
    int##bit##_t reduce_add() const {                                         \
      return vaddvq_s##bit(values);                                           \
    }                                                                         \
    int##bit##_t reduce_max() const;                                          \
    /* Comparison operators return all-ones (-1) lanes for true, 0 else. */   \
    Vectorized<int##bit##_t> operator==(                                      \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> operator!=(                                      \
        const Vectorized<int##bit##_t>& other) const;                         \
    Vectorized<int##bit##_t> operator<(                                       \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> operator<=(                                      \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> operator>(                                       \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> operator>=(                                      \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \
    }                                                                         \
    /* eq/ne/gt/ge/lt/le return 0/1 per lane (boolean form), not masks. */    \
    Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \
  };                                                                          \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator+(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return vaddq_s##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator-(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return vsubq_s##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator&(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return vandq_s##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator|(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return vorrq_s##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator^(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return veorq_s##bit(a, b);                                                \
  }                                                                           \
  /* Boolean comparison helpers: mask-AND-1 converts all-ones lanes to 1. */  \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this == other) & Vectorized<int##bit##_t>(1);                    \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this != other) & Vectorized<int##bit##_t>(1);                    \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this > other) & Vectorized<int##bit##_t>(1);                     \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this >= other) & Vectorized<int##bit##_t>(1);                    \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this < other) & Vectorized<int##bit##_t>(1);                     \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this <= other) & Vectorized<int##bit##_t>(1);                    \
  }
 | 
			
		||||
 | 
			
		||||
// Instantiate the Vectorized specializations for every 128-bit lane layout:
// 2x64, 4x32, 8x16 and 16x8 signed integers.
VEC_INT_NEON_TEMPLATE(2, 64)
VEC_INT_NEON_TEMPLATE(4, 32)
VEC_INT_NEON_TEMPLATE(8, 16)
VEC_INT_NEON_TEMPLATE(16, 8)
 | 
			
		||||
 | 
			
		||||
// Horizontal max across all lanes, via the AArch64 across-vector max
// intrinsics. There is no vmaxvq_s64, so the int64 variant is defined
// separately later in this header.
inline int32_t Vectorized<int32_t>::reduce_max() const {
  return vmaxvq_s32(values);
}

inline int16_t Vectorized<int16_t>::reduce_max() const {
  return vmaxvq_s16(values);
}

inline int8_t Vectorized<int8_t>::reduce_max() const {
  return vmaxvq_s8(values);
}
 | 
			
		||||
 | 
			
		||||
// Element-wise 32-bit integer multiply across all four lanes.
template <>
Vectorized<int32_t> inline operator*(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  // Materialize the raw NEON registers explicitly, then multiply.
  int32x4_t lhs = a;
  int32x4_t rhs = b;
  return Vectorized<int32_t>(vmulq_s32(lhs, rhs));
}
 | 
			
		||||
 | 
			
		||||
// Element-wise 16-bit integer multiply across all eight lanes.
template <>
Vectorized<int16_t> inline operator*(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  // Materialize the raw NEON registers explicitly, then multiply.
  int16x8_t lhs = a;
  int16x8_t rhs = b;
  return Vectorized<int16_t>(vmulq_s16(lhs, rhs));
}
 | 
			
		||||
 | 
			
		||||
// Element-wise 8-bit integer multiply across all sixteen lanes.
template <>
Vectorized<int8_t> inline operator*(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  // Materialize the raw NEON registers explicitly, then multiply.
  int8x16_t lhs = a;
  int8x16_t rhs = b;
  return Vectorized<int8_t>(vmulq_s8(lhs, rhs));
}
 | 
			
		||||
 | 
			
		||||
// Bitwise NOT. NEON provides no vmvnq_s64, so the 64-bit variant falls back
// to the GCC/Clang vector-extension operator~ on the raw register; the
// narrower widths use the vmvnq intrinsic directly.
template <>
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
  int64x2_t val = a;
  return ~val;
}

template <>
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
  return vmvnq_s32(a);
}

template <>
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
  return vmvnq_s16(a);
}

template <>
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
  return vmvnq_s8(a);
}
 | 
			
		||||
 | 
			
		||||
// operator!= is the bitwise complement of operator==: all-ones lanes where
// elements differ, zero lanes where they match.
inline Vectorized<int64_t> Vectorized<int64_t>::operator!=(
    const Vectorized<int64_t>& other) const {
  return ~(*this == other);
}

inline Vectorized<int32_t> Vectorized<int32_t>::operator!=(
    const Vectorized<int32_t>& other) const {
  return ~(*this == other);
}

inline Vectorized<int16_t> Vectorized<int16_t>::operator!=(
    const Vectorized<int16_t>& other) const {
  return ~(*this == other);
}

inline Vectorized<int8_t> Vectorized<int8_t>::operator!=(
    const Vectorized<int8_t>& other) const {
  return ~(*this == other);
}
 | 
			
		||||
 | 
			
		||||
// Lane-wise minimum via the NEON vminq intrinsics (no vminq_s64 exists;
// the int64 variant is defined separately later in this header).
template <>
Vectorized<int32_t> inline minimum(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  return vminq_s32(a, b);
}

template <>
Vectorized<int16_t> inline minimum(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  return vminq_s16(a, b);
}

template <>
Vectorized<int8_t> inline minimum(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  return vminq_s8(a, b);
}
 | 
			
		||||
 | 
			
		||||
// Lane-wise maximum via the NEON vmaxq intrinsics (no vmaxq_s64 exists;
// the int64 variant is defined separately later in this header).
template <>
Vectorized<int32_t> inline maximum(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  return vmaxq_s32(a, b);
}

template <>
Vectorized<int16_t> inline maximum(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  return vmaxq_s16(a, b);
}

template <>
Vectorized<int8_t> inline maximum(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  return vmaxq_s8(a, b);
}
 | 
			
		||||
 | 
			
		||||
template <int64_t mask>
 | 
			
		||||
Vectorized<int64_t> Vectorized<int64_t>::blend(
 | 
			
		||||
    const Vectorized<int64_t>& a,
 | 
			
		||||
    const Vectorized<int64_t>& b) {
 | 
			
		||||
  // Build an array of flags: each bit of element is 1 if the corresponding bit
 | 
			
		||||
  // in 'mask' is set, 0 otherwise.
 | 
			
		||||
  uint64x2_t maskArray = {
 | 
			
		||||
      (mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0,
 | 
			
		||||
      (mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0};
 | 
			
		||||
  // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
  return vbslq_s64(maskArray, b.values, a.values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <int64_t mask>
 | 
			
		||||
Vectorized<int32_t> Vectorized<int32_t>::blend(
 | 
			
		||||
    const Vectorized<int32_t>& a,
 | 
			
		||||
    const Vectorized<int32_t>& b) {
 | 
			
		||||
  // Build an array of flags: each bit of element is 1 if the corresponding bit
 | 
			
		||||
  // in 'mask' is set, 0 otherwise.
 | 
			
		||||
  uint32x4_t maskArray = {
 | 
			
		||||
      (mask & 1LL) ? 0xFFFFFFFF : 0,
 | 
			
		||||
      (mask & 2LL) ? 0xFFFFFFFF : 0,
 | 
			
		||||
      (mask & 4LL) ? 0xFFFFFFFF : 0,
 | 
			
		||||
      (mask & 8LL) ? 0xFFFFFFFF : 0};
 | 
			
		||||
  // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
  return vbslq_s32(maskArray, b.values, a.values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <int64_t mask>
 | 
			
		||||
Vectorized<int16_t> Vectorized<int16_t>::blend(
 | 
			
		||||
    const Vectorized<int16_t>& a,
 | 
			
		||||
    const Vectorized<int16_t>& b) {
 | 
			
		||||
  // Build an array of flags: each bit of element is 1 if the corresponding bit
 | 
			
		||||
  // in 'mask' is set, 0 otherwise.
 | 
			
		||||
  uint16x8_t maskArray = {
 | 
			
		||||
      (mask & 1LL) ? 0xFFFF : 0,
 | 
			
		||||
      (mask & 2LL) ? 0xFFFF : 0,
 | 
			
		||||
      (mask & 4LL) ? 0xFFFF : 0,
 | 
			
		||||
      (mask & 8LL) ? 0xFFFF : 0,
 | 
			
		||||
      (mask & 16LL) ? 0xFFFF : 0,
 | 
			
		||||
      (mask & 32LL) ? 0xFFFF : 0,
 | 
			
		||||
      (mask & 64LL) ? 0xFFFF : 0,
 | 
			
		||||
      (mask & 128LL) ? 0xFFFF : 0};
 | 
			
		||||
  // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
  return vbslq_s16(maskArray, b.values, a.values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <int64_t mask>
 | 
			
		||||
Vectorized<int8_t> Vectorized<int8_t>::blend(
 | 
			
		||||
    const Vectorized<int8_t>& a,
 | 
			
		||||
    const Vectorized<int8_t>& b) {
 | 
			
		||||
  // Build an array of flags: each bit of element is 1 if the corresponding bit
 | 
			
		||||
  // in 'mask' is set, 0 otherwise.
 | 
			
		||||
  uint8x16_t maskArray = {
 | 
			
		||||
      (mask & 1LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 2LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 4LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 8LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 16LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 32LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 64LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 128LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 256LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 512LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 1024LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 2048LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 4096LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 8192LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 16384LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 32768LL) ? 0xFF : 0};
 | 
			
		||||
  // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
  return vbslq_s8(maskArray, b.values, a.values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Expands to the per-width out-of-line member definitions that share a
// uniform intrinsic pattern: the scalar-broadcast constructor and the
// loadu/store pair (which handle partial counts through an aligned stack
// temporary). Comments inside the macro use /* */ form so they don't
// swallow the trailing continuation backslash.
#define VEC_INT_NEON_OPS(vl, bit)                                             \
  /* Broadcast a single scalar to every lane. */                              \
  inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) {             \
    values = vdupq_n_s##bit(val);                                             \
  }                                                                           \
  /* Load `count` elements from ptr; lanes past count are zero-filled. */     \
  inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu(            \
      const void* ptr, int64_t count) {                                       \
    if (count == size()) {                                                    \
      return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr));        \
    } else {                                                                  \
      __at_align__ int##bit##_t tmp_values[size()];                           \
      for (const auto i : c10::irange(size())) {                              \
        tmp_values[i] = 0;                                                    \
      }                                                                       \
      std::memcpy(                                                            \
          tmp_values,                                                         \
          reinterpret_cast<const int##bit##_t*>(ptr),                         \
          count * sizeof(int##bit##_t));                                      \
      return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \
    }                                                                         \
  }                                                                           \
  /* Store only the first `count` lanes to ptr. */                            \
  inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count)       \
      const {                                                                 \
    if (count == size()) {                                                    \
      vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values);             \
    } else {                                                                  \
      int##bit##_t tmp_values[size()];                                        \
      vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values);      \
      std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t));             \
    }                                                                         \
  }
 | 
			
		||||
 | 
			
		||||
// Instantiate the shared member definitions for every lane layout.
VEC_INT_NEON_OPS(2, 64)
VEC_INT_NEON_OPS(4, 32)
VEC_INT_NEON_OPS(8, 16)
VEC_INT_NEON_OPS(16, 8)
 | 
			
		||||
 | 
			
		||||
// Element-wise 64-bit multiply. NEON has no vmulq_s64, so this relies on
// the GCC/Clang vector-extension operator* on the raw registers.
template <>
Vectorized<int64_t> inline operator*(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  return x * y;
}
 | 
			
		||||
 | 
			
		||||
// Element-wise integer division via the GCC/Clang vector-extension
// operator/ (NEON has no integer division intrinsics). As with scalar
// integer division, a zero divisor lane is undefined behavior.
template <>
Vectorized<int64_t> inline operator/(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  return x / y;
}

template <>
Vectorized<int32_t> inline operator/(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  int32x4_t x = a;
  int32x4_t y = b;
  return x / y;
}
 | 
			
		||||
 | 
			
		||||
// Horizontal max for the 2x64 layout. There is no vmaxvq_s64 intrinsic,
// so compare the two lanes directly (vector-extension lane subscripting).
inline int64_t Vectorized<int64_t>::reduce_max() const {
  return std::max(values[0], values[1]);
}
 | 
			
		||||
 | 
			
		||||
// Lane-wise min/max for the 2x64 layout. NEON lacks vminq_s64/vmaxq_s64,
// so both lanes are handled with scalar std::min/std::max.
template <>
Vectorized<int64_t> inline minimum(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  return {std::min(x[0], y[0]), std::min(x[1], y[1])};
}

template <>
Vectorized<int64_t> inline maximum(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  return {std::max(x[0], y[0]), std::max(x[1], y[1])};
}
 | 
			
		||||
 | 
			
		||||
// arange: lane i receives base + i * step. The narrower widths fuse the
// multiply-add with vmlaq; int64 has no vmlaq_s64 and uses the compiler's
// vector-extension arithmetic instead.
template <typename step_t>
inline Vectorized<int64_t> Vectorized<int64_t>::arange(
    int64_t base,
    step_t step) {
  const Vectorized<int64_t> base_vec(base);
  const Vectorized<int64_t> step_vec(step);
  const int64x2_t step_sizes = {0, 1};
  return base_vec.values + step_sizes * step_vec.values;
}

template <typename step_t>
inline Vectorized<int32_t> Vectorized<int32_t>::arange(
    int32_t base,
    step_t step) {
  const Vectorized<int32_t> base_vec(base);
  const Vectorized<int32_t> step_vec(step);
  const int32x4_t step_sizes = {0, 1, 2, 3};
  // vmlaq: base_vec + step_sizes * step_vec, fused per lane.
  return vmlaq_s32(base_vec, step_sizes, step_vec);
}

template <typename step_t>
inline Vectorized<int16_t> Vectorized<int16_t>::arange(
    int16_t base,
    step_t step) {
  const Vectorized<int16_t> base_vec(base);
  const Vectorized<int16_t> step_vec(step);
  const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7};
  return vmlaq_s16(base_vec, step_sizes, step_vec);
}

template <typename step_t>
inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) {
  const Vectorized<int8_t> base_vec(base);
  const Vectorized<int8_t> step_vec(step);
  const int8x16_t step_sizes = {
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
  return vmlaq_s8(base_vec, step_sizes, step_vec);
}
 | 
			
		||||
 | 
			
		||||
// Arithmetic right shift with per-lane shift counts. Counts are
// reinterpreted as unsigned and clamped to (width - 1), so any count
// >= width — including negative counts, which wrap to huge unsigned
// values — behaves like shifting by width - 1 (full sign fill) instead
// of hitting C++'s undefined out-of-range shift.
template <>
Vectorized<int64_t> inline operator>>(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  uint64x2_t u = vreinterpretq_u64_s64(y);
  // No vminq for 64-bit lanes: clamp each lane with scalar std::min.
  uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)};
  return x >> vreinterpretq_s64_u64(z);
}

template <>
Vectorized<int32_t> inline operator>>(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  int32x4_t x = a;
  int32x4_t y = b;
  uint32x4_t bound = vdupq_n_u32(31);
  uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
  return x >> vreinterpretq_s32_u32(z);
}

template <>
Vectorized<int16_t> inline operator>>(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  int16x8_t x = a;
  int16x8_t y = b;
  uint16x8_t bound = vdupq_n_u16(15);
  uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
  return x >> vreinterpretq_s16_u16(z);
}

template <>
Vectorized<int8_t> inline operator>>(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  int8x16_t x = a;
  int8x16_t y = b;
  uint8x16_t bound = vdupq_n_u8(7);
  int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
  return x >> z;
}
 | 
			
		||||
 | 
			
		||||
// Left shift with per-lane shift counts, implemented with vshlq. Counts are
// reinterpreted as unsigned and clamped to the element width (not width-1),
// so oversized or wrapped-negative counts shift by the full width — which
// under vshlq yields 0 rather than undefined behavior.
// NOTE(review): relies on vshlq producing 0 for a shift equal to the lane
// width — confirm against the Arm intrinsics reference.
template <>
Vectorized<int64_t> inline operator<<(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t y = b;
  uint64x2_t u = vreinterpretq_u64_s64(y);
  // No vminq for 64-bit lanes: clamp each lane with scalar std::min.
  uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)};
  return vshlq_s64(a, vreinterpretq_s64_u64(z));
}

template <>
Vectorized<int32_t> inline operator<<(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  int32x4_t y = b;
  uint32x4_t bound = vdupq_n_u32(32);
  uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
  return vshlq_s32(a, vreinterpretq_s32_u32(z));
}

template <>
Vectorized<int16_t> inline operator<<(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  int16x8_t y = b;
  uint16x8_t bound = vdupq_n_u16(16);
  uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
  return vshlq_s16(a, vreinterpretq_s16_u16(z));
}

template <>
Vectorized<int8_t> inline operator<<(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  int8x16_t y = b;
  uint8x16_t bound = vdupq_n_u8(8);
  int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
  return vshlq_s8(a, z);
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<int64_t> Vectorized<int64_t>::set(
 | 
			
		||||
    const Vectorized<int64_t>& a,
 | 
			
		||||
    const Vectorized<int64_t>& b,
 | 
			
		||||
    int64_t count) {
 | 
			
		||||
  if (count == 0) {
 | 
			
		||||
    return a;
 | 
			
		||||
  } else if (count >= 2) {
 | 
			
		||||
    return b;
 | 
			
		||||
  } else {
 | 
			
		||||
    int64x2_t c = {b.values[0], a.values[1]};
 | 
			
		||||
    return c;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<int32_t> Vectorized<int32_t>::set(
 | 
			
		||||
    const Vectorized<int32_t>& a,
 | 
			
		||||
    const Vectorized<int32_t>& b,
 | 
			
		||||
    int64_t count) {
 | 
			
		||||
  if (count == 0) {
 | 
			
		||||
    return a;
 | 
			
		||||
  } else if (count >= 4) {
 | 
			
		||||
    return b;
 | 
			
		||||
  } else {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint32x4_t maskArray = {
 | 
			
		||||
        (count >= 1LL) ? 0xFFFFFFFF : 0,
 | 
			
		||||
        (count >= 2LL) ? 0xFFFFFFFF : 0,
 | 
			
		||||
        (count >= 3LL) ? 0xFFFFFFFF : 0,
 | 
			
		||||
        0};
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_s32(maskArray, b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<int16_t> Vectorized<int16_t>::set(
 | 
			
		||||
    const Vectorized<int16_t>& a,
 | 
			
		||||
    const Vectorized<int16_t>& b,
 | 
			
		||||
    int64_t count) {
 | 
			
		||||
  if (count == 0) {
 | 
			
		||||
    return a;
 | 
			
		||||
  } else if (count >= 8) {
 | 
			
		||||
    return b;
 | 
			
		||||
  } else {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint16x8_t maskArray = {
 | 
			
		||||
        static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0),
 | 
			
		||||
        static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0),
 | 
			
		||||
        static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0),
 | 
			
		||||
        static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0),
 | 
			
		||||
        static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0),
 | 
			
		||||
        static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0),
 | 
			
		||||
        static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0),
 | 
			
		||||
        0};
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_s16(maskArray, b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<int8_t> Vectorized<int8_t>::set(
 | 
			
		||||
    const Vectorized<int8_t>& a,
 | 
			
		||||
    const Vectorized<int8_t>& b,
 | 
			
		||||
    int64_t count) {
 | 
			
		||||
  if (count == 0) {
 | 
			
		||||
    return a;
 | 
			
		||||
  } else if (count >= 16) {
 | 
			
		||||
    return b;
 | 
			
		||||
  } else {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint8x16_t maskArray = {
 | 
			
		||||
        static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
 | 
			
		||||
        0};
 | 
			
		||||
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_s8(maskArray, b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int16_t> inline operator/(
 | 
			
		||||
    const Vectorized<int16_t>& a,
 | 
			
		||||
    const Vectorized<int16_t>& b) {
 | 
			
		||||
  Vectorized<int32_t> highBitsA = vmovl_high_s16(a);
 | 
			
		||||
  Vectorized<int32_t> highBitsB = vmovl_high_s16(b);
 | 
			
		||||
  Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a));
 | 
			
		||||
  Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b));
 | 
			
		||||
  int32x4_t highBitsResult = highBitsA / highBitsB;
 | 
			
		||||
  int32x4_t lowBitsResult = lowBitsA / lowBitsB;
 | 
			
		||||
  return vuzp1q_s16(
 | 
			
		||||
      vreinterpretq_s16_s32(lowBitsResult),
 | 
			
		||||
      vreinterpretq_s16_s32(highBitsResult));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int8_t> inline operator/(
 | 
			
		||||
    const Vectorized<int8_t>& a,
 | 
			
		||||
    const Vectorized<int8_t>& b) {
 | 
			
		||||
  Vectorized<int16_t> highBitsA = vmovl_high_s8(a);
 | 
			
		||||
  Vectorized<int16_t> highBitsB = vmovl_high_s8(b);
 | 
			
		||||
  Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a));
 | 
			
		||||
  Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b));
 | 
			
		||||
  int16x8_t highBitsResult = highBitsA / highBitsB;
 | 
			
		||||
  int16x8_t lowBitsResult = lowBitsA / lowBitsB;
 | 
			
		||||
  return vuzp1q_s8(
 | 
			
		||||
      vreinterpretq_s8_s16(lowBitsResult),
 | 
			
		||||
      vreinterpretq_s8_s16(highBitsResult));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int64_t> inline clamp(
 | 
			
		||||
    const Vectorized<int64_t>& a,
 | 
			
		||||
    const Vectorized<int64_t>& min,
 | 
			
		||||
    const Vectorized<int64_t>& max) {
 | 
			
		||||
  return minimum(max, maximum(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int32_t> inline clamp(
 | 
			
		||||
    const Vectorized<int32_t>& a,
 | 
			
		||||
    const Vectorized<int32_t>& min,
 | 
			
		||||
    const Vectorized<int32_t>& max) {
 | 
			
		||||
  return minimum(max, maximum(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int16_t> inline clamp(
 | 
			
		||||
    const Vectorized<int16_t>& a,
 | 
			
		||||
    const Vectorized<int16_t>& min,
 | 
			
		||||
    const Vectorized<int16_t>& max) {
 | 
			
		||||
  return minimum(max, maximum(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int8_t> inline clamp(
 | 
			
		||||
    const Vectorized<int8_t>& a,
 | 
			
		||||
    const Vectorized<int8_t>& min,
 | 
			
		||||
    const Vectorized<int8_t>& max) {
 | 
			
		||||
  return minimum(max, maximum(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int64_t> inline clamp_max(
 | 
			
		||||
    const Vectorized<int64_t>& a,
 | 
			
		||||
    const Vectorized<int64_t>& max) {
 | 
			
		||||
  return minimum(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int32_t> inline clamp_max(
 | 
			
		||||
    const Vectorized<int32_t>& a,
 | 
			
		||||
    const Vectorized<int32_t>& max) {
 | 
			
		||||
  return minimum(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int16_t> inline clamp_max(
 | 
			
		||||
    const Vectorized<int16_t>& a,
 | 
			
		||||
    const Vectorized<int16_t>& max) {
 | 
			
		||||
  return minimum(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int8_t> inline clamp_max(
 | 
			
		||||
    const Vectorized<int8_t>& a,
 | 
			
		||||
    const Vectorized<int8_t>& max) {
 | 
			
		||||
  return minimum(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int64_t> inline clamp_min(
 | 
			
		||||
    const Vectorized<int64_t>& a,
 | 
			
		||||
    const Vectorized<int64_t>& min) {
 | 
			
		||||
  return maximum(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int32_t> inline clamp_min(
 | 
			
		||||
    const Vectorized<int32_t>& a,
 | 
			
		||||
    const Vectorized<int32_t>& min) {
 | 
			
		||||
  return maximum(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int16_t> inline clamp_min(
 | 
			
		||||
    const Vectorized<int16_t>& a,
 | 
			
		||||
    const Vectorized<int16_t>& min) {
 | 
			
		||||
  return maximum(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<int8_t> inline clamp_min(
 | 
			
		||||
    const Vectorized<int8_t>& a,
 | 
			
		||||
    const Vectorized<int8_t>& min) {
 | 
			
		||||
  return maximum(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace CPU_CAPABILITY
 | 
			
		||||
} // namespace at::vec
 | 
			
		||||
@ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum(
 | 
			
		||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
 | 
			
		||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
 | 
			
		||||
    at::vec::Vectorized<int8_t> src) {
 | 
			
		||||
  auto s8x8 = vld1_s8(src.operator const int8_t*());
 | 
			
		||||
  auto s8x8 = vget_low_s8(src);
 | 
			
		||||
  auto s16x8 = vmovl_s8(s8x8);
 | 
			
		||||
 | 
			
		||||
  auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
 | 
			
		||||
@ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
 | 
			
		||||
 | 
			
		||||
Vectorized<float> inline convert_int8_half_register_to_float(
 | 
			
		||||
    at::vec::Vectorized<int8_t> src) {
 | 
			
		||||
  auto s8x8 = vld1_s8(src.operator const int8_t*());
 | 
			
		||||
  auto s8x8 = vget_low_s8(src);
 | 
			
		||||
  auto s16x8 = vmovl_s8(s8x8);
 | 
			
		||||
 | 
			
		||||
  auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,8 @@
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <c10/core/ScalarType.h>
 | 
			
		||||
 | 
			
		||||
#include <ATen/cuda/detail/BLASConstants.h>
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
#include <c10/cuda/CUDAStream.h>
 | 
			
		||||
#include <hipblaslt/hipblaslt-ext.hpp>
 | 
			
		||||
@ -1954,13 +1956,15 @@ void scaled_gemm(
 | 
			
		||||
    const void *result_scale_ptr,
 | 
			
		||||
    int64_t result_ld,
 | 
			
		||||
    ScalarType result_dtype,
 | 
			
		||||
    bool use_fast_accum) {
 | 
			
		||||
    bool use_fast_accum,
 | 
			
		||||
    const std::optional<Tensor>& alpha) {
 | 
			
		||||
  // Note: see `cublasCommonArgs` for various non-intuitive manupulations
 | 
			
		||||
  // of input arguments to this function.
 | 
			
		||||
  const auto computeType = CUBLAS_COMPUTE_32F;
 | 
			
		||||
  const auto scaleType = CUDA_R_32F;
 | 
			
		||||
  const float alpha_val = 1.0;
 | 
			
		||||
  const float beta_val = 0.0;
 | 
			
		||||
  // Note: alpha_val may change later depending on user-passed argument
 | 
			
		||||
  float alpha_val = 1.0;
 | 
			
		||||
  float beta_val = 0.0;
 | 
			
		||||
  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
 | 
			
		||||
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
 | 
			
		||||
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
 | 
			
		||||
@ -2031,6 +2035,33 @@ void scaled_gemm(
 | 
			
		||||
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
 | 
			
		||||
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Handle user-passed alpha
 | 
			
		||||
  float *alpha_ptr = &alpha_val;
 | 
			
		||||
  float *beta_ptr = &beta_val;
 | 
			
		||||
 | 
			
		||||
  if (alpha.has_value()) {
 | 
			
		||||
    auto& a = alpha.value();
 | 
			
		||||
 | 
			
		||||
    // if device-tensor
 | 
			
		||||
    if (a.is_cuda()) {
 | 
			
		||||
      // NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be
 | 
			
		||||
      //       valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus
 | 
			
		||||
      //       we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically
 | 
			
		||||
      //       managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory
 | 
			
		||||
      //       for beta respectively.
 | 
			
		||||
      float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr();
 | 
			
		||||
      at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat));
 | 
			
		||||
      user_alpha.copy_(a);
 | 
			
		||||
      // Tell cublasLt we're using device-side pointers for alpha/beta
 | 
			
		||||
      auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
 | 
			
		||||
      computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode);
 | 
			
		||||
      alpha_ptr = user_alpha.data_ptr<float>();
 | 
			
		||||
      beta_ptr = at::cuda::detail::get_cublas_device_zero();
 | 
			
		||||
    } else {
 | 
			
		||||
      alpha_val = a.item<float>();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
    // For other data types, use the get_scale_mode function based on scaling type
 | 
			
		||||
    // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
 | 
			
		||||
    // but we must invoke get_scale_mode anyways to trigger the version checks.
 | 
			
		||||
@ -2048,6 +2079,7 @@ void scaled_gemm(
 | 
			
		||||
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
 | 
			
		||||
  int returnedResult = 0;
 | 
			
		||||
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
 | 
			
		||||
 | 
			
		||||
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
 | 
			
		||||
      ltHandle,
 | 
			
		||||
      computeDesc.descriptor(),
 | 
			
		||||
@ -2088,10 +2120,10 @@ void scaled_gemm(
 | 
			
		||||
        auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
 | 
			
		||||
                ltHandle,
 | 
			
		||||
                computeDesc.descriptor(),
 | 
			
		||||
                &alpha_val,
 | 
			
		||||
                alpha_ptr,
 | 
			
		||||
                Adesc.descriptor(),
 | 
			
		||||
                Bdesc.descriptor(),
 | 
			
		||||
                &beta_val,
 | 
			
		||||
                beta_ptr,
 | 
			
		||||
                Cdesc.descriptor(),
 | 
			
		||||
                Ddesc.descriptor(),
 | 
			
		||||
                all_algos[i].algo,
 | 
			
		||||
@ -2110,17 +2142,14 @@ void scaled_gemm(
 | 
			
		||||
  cublasStatus_t cublasStatus = cublasLtMatmul(
 | 
			
		||||
      ltHandle,
 | 
			
		||||
      computeDesc.descriptor(),
 | 
			
		||||
      &alpha_val,
 | 
			
		||||
      alpha_ptr,
 | 
			
		||||
      mat1_ptr,
 | 
			
		||||
      Adesc.descriptor(),
 | 
			
		||||
      mat2_ptr,
 | 
			
		||||
      Bdesc.descriptor(),
 | 
			
		||||
      &beta_val,
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
      beta_ptr,
 | 
			
		||||
      // NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either
 | 
			
		||||
      result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
 | 
			
		||||
#else
 | 
			
		||||
      nullptr,
 | 
			
		||||
#endif // ifdef USE_ROCM
 | 
			
		||||
      Cdesc.descriptor(),
 | 
			
		||||
      result_ptr,
 | 
			
		||||
      Ddesc.descriptor(),
 | 
			
		||||
 | 
			
		||||
@ -161,7 +161,8 @@ void scaled_gemm(
 | 
			
		||||
    const void* result_scale_ptr,
 | 
			
		||||
    int64_t result_ld,
 | 
			
		||||
    ScalarType result_dtype,
 | 
			
		||||
    bool use_fast_accum);
 | 
			
		||||
    bool use_fast_accum,
 | 
			
		||||
    const std::optional<Tensor>& alpha);
 | 
			
		||||
 | 
			
		||||
#define CUDABLAS_BGEMM_ARGTYPES(Dtype)  CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
 | 
			
		||||
 */
 | 
			
		||||
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
 | 
			
		||||
  // The RNG state comprises the seed, and an offset used for Philox.
 | 
			
		||||
  static const size_t seed_size = sizeof(uint64_t);
 | 
			
		||||
  static const size_t offset_size = sizeof(int64_t);
 | 
			
		||||
  static const size_t total_size = seed_size + offset_size;
 | 
			
		||||
  constexpr size_t seed_size = sizeof(uint64_t);
 | 
			
		||||
  constexpr size_t offset_size = sizeof(int64_t);
 | 
			
		||||
  constexpr size_t total_size = seed_size + offset_size;
 | 
			
		||||
 | 
			
		||||
  auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
 | 
			
		||||
  auto rng_state = state_tensor.data_ptr<uint8_t>();
 | 
			
		||||
@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
 | 
			
		||||
 * and size of the internal state.
 | 
			
		||||
 */
 | 
			
		||||
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
 | 
			
		||||
  static const size_t seed_size = sizeof(uint64_t);
 | 
			
		||||
  static const size_t offset_size = sizeof(int64_t);
 | 
			
		||||
  static const size_t total_size = seed_size + offset_size;
 | 
			
		||||
  constexpr size_t seed_size = sizeof(uint64_t);
 | 
			
		||||
  constexpr size_t offset_size = sizeof(int64_t);
 | 
			
		||||
  constexpr size_t total_size = seed_size + offset_size;
 | 
			
		||||
 | 
			
		||||
  detail::check_rng_state(new_state);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										54
									
								
								aten/src/ATen/cuda/detail/BLASConstants.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								aten/src/ATen/cuda/detail/BLASConstants.cu
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,54 @@
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
#include <ATen/Tensor.h>
 | 
			
		||||
#include <ATen/cuda/Exceptions.h>
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
 | 
			
		||||
namespace at {
 | 
			
		||||
namespace cuda {
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
__device__ __constant__ float cublas_one_device;
 | 
			
		||||
__device__ __constant__ float cublas_zero_device;
 | 
			
		||||
 | 
			
		||||
float *get_cublas_device_one() {
 | 
			
		||||
  static c10::once_flag init_flag;
 | 
			
		||||
 | 
			
		||||
  c10::call_once(init_flag, []() {
 | 
			
		||||
    const float one = 1.f;
 | 
			
		||||
    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  float *ptr;
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
 | 
			
		||||
  return ptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
float *get_cublas_device_zero() {
 | 
			
		||||
  static c10::once_flag init_flag;
 | 
			
		||||
 | 
			
		||||
  c10::call_once(init_flag, []() {
 | 
			
		||||
    const float zero = 0.f;
 | 
			
		||||
    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  float *ptr;
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
 | 
			
		||||
  return ptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
float *get_user_alpha_ptr() {
 | 
			
		||||
  static float *alpha_ptr;
 | 
			
		||||
 | 
			
		||||
  static c10::once_flag init_flag;
 | 
			
		||||
 | 
			
		||||
  c10::call_once(init_flag, []() {
 | 
			
		||||
    AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return alpha_ptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace detail
 | 
			
		||||
} // namespace cuda
 | 
			
		||||
} // namespace at
 | 
			
		||||
							
								
								
									
										11
									
								
								aten/src/ATen/cuda/detail/BLASConstants.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								aten/src/ATen/cuda/detail/BLASConstants.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,11 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/core/TensorBase.h>
 | 
			
		||||
 | 
			
		||||
namespace at::cuda::detail {
 | 
			
		||||
 | 
			
		||||
float *get_cublas_device_one();
 | 
			
		||||
float *get_cublas_device_zero();
 | 
			
		||||
float *get_user_alpha_ptr();
 | 
			
		||||
 | 
			
		||||
} // namespace at::cuda::detail
 | 
			
		||||
@ -109,7 +109,8 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
 | 
			
		||||
          params->c_scale_ptr,
 | 
			
		||||
          params->ldc,
 | 
			
		||||
          params->c_dtype,
 | 
			
		||||
          params->use_fast_accum);
 | 
			
		||||
          params->use_fast_accum,
 | 
			
		||||
          std::nullopt /* alpha */);
 | 
			
		||||
      return OK;
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
 | 
			
		||||
  DispatchKey::CUDA,
 | 
			
		||||
  DispatchKey::CPU,
 | 
			
		||||
  DispatchKey::PrivateUse1,
 | 
			
		||||
  DispatchKey::SparseCPU,
 | 
			
		||||
  DispatchKey::SparseCUDA,
 | 
			
		||||
  DispatchKey::SparseCsrCPU,
 | 
			
		||||
  DispatchKey::SparseCsrCUDA,
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
 | 
			
		||||
 | 
			
		||||
@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (
 | 
			
		||||
 | 
			
		||||
namespace at::native {
 | 
			
		||||
 | 
			
		||||
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
 | 
			
		||||
static const double SELU_SCALE = 1.0507009873554804934193349852946;
 | 
			
		||||
static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
 | 
			
		||||
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
 | 
			
		||||
 | 
			
		||||
DEFINE_DISPATCH(elu_stub);
 | 
			
		||||
DEFINE_DISPATCH(elu_backward_stub);
 | 
			
		||||
 | 
			
		||||
@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
 | 
			
		||||
#if AT_BUILD_WITH_BLAS()
 | 
			
		||||
template <>
 | 
			
		||||
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
 | 
			
		||||
  auto intmax = std::numeric_limits<int>::max();
 | 
			
		||||
  auto constexpr intmax = std::numeric_limits<int>::max();
 | 
			
		||||
  return n <= intmax && incx <= intmax;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
 | 
			
		||||
    int64_t incx,
 | 
			
		||||
    [[maybe_unused]] float beta,
 | 
			
		||||
    int64_t incy) {
 | 
			
		||||
  auto intmax = std::numeric_limits<int>::max();
 | 
			
		||||
  auto constexpr intmax = std::numeric_limits<int>::max();
 | 
			
		||||
  return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
 | 
			
		||||
         (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -658,6 +658,7 @@ static void check_shape_forward(const at::Tensor& input,
 | 
			
		||||
  TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
 | 
			
		||||
  TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
 | 
			
		||||
  TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
 | 
			
		||||
  TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups);
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(weight_dim == k,
 | 
			
		||||
           "Expected ", weight_dim, "-dimensional input for ", weight_dim,
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,6 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <array>
 | 
			
		||||
#include <ATen/native/Math.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/MathConstants.h>
 | 
			
		||||
@ -127,7 +128,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
 | 
			
		||||
 | 
			
		||||
template<typename scalar_t>
 | 
			
		||||
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
 | 
			
		||||
  const static scalar_t kTailValues[] = {
 | 
			
		||||
  constexpr static scalar_t kTailValues[] = {
 | 
			
		||||
    0.0810614667953272,
 | 
			
		||||
    0.0413406959554092,
 | 
			
		||||
    0.0276779256849983,
 | 
			
		||||
@ -139,7 +140,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
 | 
			
		||||
    0.00925546218271273,
 | 
			
		||||
    0.00833056343336287
 | 
			
		||||
  };
 | 
			
		||||
  if (k <= 9) {
 | 
			
		||||
  if (k < std::size(kTailValues)) {
 | 
			
		||||
    return kTailValues[static_cast<size_t>(k)];
 | 
			
		||||
  }
 | 
			
		||||
  scalar_t kp1sq = (k + 1) * (k + 1);
 | 
			
		||||
 | 
			
		||||
@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
 | 
			
		||||
  // lanczos approximation
 | 
			
		||||
  static const scalar_t lanczos_sum_expg_scaled_num[13] = {
 | 
			
		||||
  static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
 | 
			
		||||
    0.006061842346248906525783753964555936883222,
 | 
			
		||||
    0.5098416655656676188125178644804694509993,
 | 
			
		||||
    19.51992788247617482847860966235652136208,
 | 
			
		||||
@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
 | 
			
		||||
    103794043.1163445451906271053616070238554,
 | 
			
		||||
    56906521.91347156388090791033559122686859
 | 
			
		||||
  };
 | 
			
		||||
  static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
 | 
			
		||||
  static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
 | 
			
		||||
    1.,
 | 
			
		||||
    66.,
 | 
			
		||||
    1925.,
 | 
			
		||||
@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
 | 
			
		||||
  // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
 | 
			
		||||
  static const scalar_t d[25][25] =
 | 
			
		||||
  static constexpr scalar_t d[25][25] =
 | 
			
		||||
    {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
 | 
			
		||||
      1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
 | 
			
		||||
      3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,
 | 
			
		||||
 | 
			
		||||
@ -62,7 +62,7 @@
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
static const int MIOPEN_DIM_MAX = 5;
 | 
			
		||||
static constexpr int MIOPEN_DIM_MAX = 5;
 | 
			
		||||
 | 
			
		||||
namespace at::meta {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -77,7 +77,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
 | 
			
		||||
  // next broadcast all index tensors together
 | 
			
		||||
  try {
 | 
			
		||||
    indices = expand_outplace(indices);
 | 
			
		||||
  } catch (std::exception& e) {
 | 
			
		||||
  } catch (std::exception&) {
 | 
			
		||||
    TORCH_CHECK_INDEX(
 | 
			
		||||
        false,
 | 
			
		||||
        "shape mismatch: indexing tensors could not be broadcast together"
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,6 @@
 | 
			
		||||
#include <ATen/core/ATen_fwd.h>
 | 
			
		||||
#include <c10/core/ScalarType.h>
 | 
			
		||||
#include <c10/core/SymInt.h>
 | 
			
		||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 | 
			
		||||
#include <ATen/AccumulateType.h>
 | 
			
		||||
#include <ATen/Dispatch.h>
 | 
			
		||||
@ -1710,11 +1711,14 @@ Tensor narrow_symint(
 | 
			
		||||
      "], but got ",
 | 
			
		||||
      start,
 | 
			
		||||
      ")")
 | 
			
		||||
  if (start < 0) {
 | 
			
		||||
    start = start + cur_size;
 | 
			
		||||
  }
 | 
			
		||||
  // Bounds check without converting start:
 | 
			
		||||
  // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
 | 
			
		||||
  // length <= 0
 | 
			
		||||
  // - If start >= 0: need start + length <= cur_size
 | 
			
		||||
  auto end = start + length;
 | 
			
		||||
  TORCH_SYM_CHECK(
 | 
			
		||||
      start.sym_le(cur_size - length),
 | 
			
		||||
      (start.sym_lt(0).sym_and((end).sym_le(0)))
 | 
			
		||||
          .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
 | 
			
		||||
      "start (",
 | 
			
		||||
      start,
 | 
			
		||||
      ") + length (",
 | 
			
		||||
@ -1722,7 +1726,31 @@ Tensor narrow_symint(
 | 
			
		||||
      ") exceeds dimension size (",
 | 
			
		||||
      cur_size,
 | 
			
		||||
      ").");
 | 
			
		||||
  return at::slice_symint(self, dim, start, start + length, 1);
 | 
			
		||||
 | 
			
		||||
  if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) {
 | 
			
		||||
    return at::slice_symint(self, dim, start, end, 1);
 | 
			
		||||
  } else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) {
 | 
			
		||||
    // Avoid the complex symbolic expressions path for non-unbacked.
 | 
			
		||||
    return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1);
 | 
			
		||||
  } else {
 | 
			
		||||
    // Cannot statically determine the condition due to unbacked.
 | 
			
		||||
    // This is an interesting situation; when start is negative and
 | 
			
		||||
    // start + length == 0, slice and narrow do different things.
 | 
			
		||||
    // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
 | 
			
		||||
    // pass curr_size instead of 0. Otherwise, they would do the same thing.
 | 
			
		||||
    // This says at runtime: if start < 0 and end == 0, then pass curr_size
 | 
			
		||||
    // instead of 0.
 | 
			
		||||
 | 
			
		||||
    auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
 | 
			
		||||
    auto result =
 | 
			
		||||
        at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
 | 
			
		||||
 | 
			
		||||
    // Ensure slice allocated unbacked size is specialized to length.
 | 
			
		||||
    SymInt new_size = result.sym_size(dim);
 | 
			
		||||
    TORCH_SYM_CHECK(new_size.sym_eq(length), "")
 | 
			
		||||
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This overload exists purely for XLA, because they wanted to pass in
 | 
			
		||||
@ -1736,8 +1764,8 @@ Tensor narrow_tensor_symint(
 | 
			
		||||
      start.dim() == 0 &&
 | 
			
		||||
          isIntegralType(start.scalar_type(), /*includeBool=*/false),
 | 
			
		||||
      "start must be an 0-dim integral Tensor.");
 | 
			
		||||
  int64_t st = start.item<int64_t>();
 | 
			
		||||
  return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length));
 | 
			
		||||
  c10::SymInt st = start.item().toSymInt();
 | 
			
		||||
  return at::narrow_symint(self, dim, std::move(st), std::move(length));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::
 | 
			
		||||
 | 
			
		||||
@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
 | 
			
		||||
  // We keep this structure for BC and consider as deprecated.
 | 
			
		||||
  // See HelperInterpNearestExact as replacement
 | 
			
		||||
 | 
			
		||||
  static const int interp_size = 1;
 | 
			
		||||
  static constexpr int interp_size = 1;
 | 
			
		||||
 | 
			
		||||
  static inline void init_indices_weights(
 | 
			
		||||
    at::ScalarType output_type,
 | 
			
		||||
@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {
 | 
			
		||||
 | 
			
		||||
struct HelperInterpLinear : public HelperInterpBase {
 | 
			
		||||
 | 
			
		||||
  static const int interp_size = 2;
 | 
			
		||||
  static constexpr int interp_size = 2;
 | 
			
		||||
 | 
			
		||||
  // Compute indices and weights for each interpolated dimension
 | 
			
		||||
  // indices_weights = {
 | 
			
		||||
@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {
 | 
			
		||||
 | 
			
		||||
struct HelperInterpCubic : public HelperInterpBase {
 | 
			
		||||
 | 
			
		||||
  static const int interp_size = 4;
 | 
			
		||||
  static constexpr int interp_size = 4;
 | 
			
		||||
 | 
			
		||||
  // Compute indices and weights for each interpolated dimension
 | 
			
		||||
  // indices_weights = {
 | 
			
		||||
 | 
			
		||||
@ -1359,7 +1359,8 @@ _scaled_gemm(
 | 
			
		||||
          const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
 | 
			
		||||
          const std::optional<Tensor>& bias,
 | 
			
		||||
          const bool use_fast_accum,
 | 
			
		||||
          Tensor& out) {
 | 
			
		||||
          Tensor& out,
 | 
			
		||||
          const std::optional<Tensor>& alpha = std::nullopt) {
 | 
			
		||||
  cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
 | 
			
		||||
  const auto out_dtype_ = args.result->scalar_type();
 | 
			
		||||
  TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
 | 
			
		||||
@ -1410,7 +1411,8 @@ _scaled_gemm(
 | 
			
		||||
          args.scale_result_ptr,
 | 
			
		||||
          args.result_ld,
 | 
			
		||||
          out_dtype_,
 | 
			
		||||
          use_fast_accum);
 | 
			
		||||
          use_fast_accum,
 | 
			
		||||
          alpha);
 | 
			
		||||
      return out;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
static const int BLOCK_THREADS = 256;
 | 
			
		||||
static constexpr int BLOCK_THREADS = 256;
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, typename accscalar_t>
 | 
			
		||||
#if defined (USE_ROCM)
 | 
			
		||||
 | 
			
		||||
@ -36,9 +36,9 @@ namespace at::native {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
static const int BLOCKDIMY = 16;
 | 
			
		||||
static constexpr int BLOCKDIMY = 16;
 | 
			
		||||
#else
 | 
			
		||||
static const int BLOCKDIMY = 32;
 | 
			
		||||
static constexpr int BLOCKDIMY = 32;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template
 | 
			
		||||
 | 
			
		||||
@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
 | 
			
		||||
  // lanczos approximation
 | 
			
		||||
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 | 
			
		||||
 | 
			
		||||
  static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
 | 
			
		||||
  constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
 | 
			
		||||
    0.006061842346248906525783753964555936883222,
 | 
			
		||||
    0.5098416655656676188125178644804694509993,
 | 
			
		||||
    19.51992788247617482847860966235652136208,
 | 
			
		||||
@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
 | 
			
		||||
    103794043.1163445451906271053616070238554,
 | 
			
		||||
    56906521.91347156388090791033559122686859
 | 
			
		||||
  };
 | 
			
		||||
  static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
 | 
			
		||||
  constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
 | 
			
		||||
    1.,
 | 
			
		||||
    66.,
 | 
			
		||||
    1925.,
 | 
			
		||||
@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
 | 
			
		||||
 | 
			
		||||
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 | 
			
		||||
  accscalar_t ax, fac, res, num, numfac;
 | 
			
		||||
  static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
 | 
			
		||||
  constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
 | 
			
		||||
    7.09782712893383996843E2 : 88.72283905206835;
 | 
			
		||||
  static const accscalar_t EXP1 = 2.718281828459045;
 | 
			
		||||
  static const accscalar_t lanczos_g = 6.024680040776729583740234375;
 | 
			
		||||
  constexpr accscalar_t EXP1 = 2.718281828459045;
 | 
			
		||||
  constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
 | 
			
		||||
 | 
			
		||||
  if (::fabs(a - x) > 0.4 * ::fabs(a)) {
 | 
			
		||||
    ax = a * ::log(x) - x - ::lgamma(a);
 | 
			
		||||
@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
 | 
			
		||||
  // Compute igam using DLMF 8.11.4. [igam1]
 | 
			
		||||
 | 
			
		||||
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 | 
			
		||||
  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
 | 
			
		||||
  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
 | 
			
		||||
    1.11022302462515654042E-16 : 5.9604644775390625E-8;
 | 
			
		||||
  static const int MAXITER = 2000;
 | 
			
		||||
  constexpr int MAXITER = 2000;
 | 
			
		||||
 | 
			
		||||
  int i;
 | 
			
		||||
  accscalar_t ans, ax, c, r;
 | 
			
		||||
@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
 | 
			
		||||
  accscalar_t fac = 1;
 | 
			
		||||
  accscalar_t sum = 0;
 | 
			
		||||
  accscalar_t term, logx;
 | 
			
		||||
  static const int MAXITER = 2000;
 | 
			
		||||
  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
 | 
			
		||||
  constexpr int MAXITER = 2000;
 | 
			
		||||
  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
 | 
			
		||||
    1.11022302462515654042E-16 : 5.9604644775390625E-8;
 | 
			
		||||
 | 
			
		||||
  for (n = 1; n < MAXITER; n++) {
 | 
			
		||||
@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
 | 
			
		||||
  // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
 | 
			
		||||
 | 
			
		||||
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 | 
			
		||||
  static const accscalar_t d[25][25] =
 | 
			
		||||
  constexpr accscalar_t d[25][25] =
 | 
			
		||||
    {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
 | 
			
		||||
    {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
 | 
			
		||||
    {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
 | 
			
		||||
@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
 | 
			
		||||
 | 
			
		||||
  int k, n, sgn;
 | 
			
		||||
  int maxpow = 0;
 | 
			
		||||
  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
 | 
			
		||||
  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
 | 
			
		||||
    1.11022302462515654042E-16 : 5.9604644775390625E-8;
 | 
			
		||||
  accscalar_t lambda = x / a;
 | 
			
		||||
  accscalar_t sigma = (x - a) / a;
 | 
			
		||||
@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
 | 
			
		||||
  int i;
 | 
			
		||||
  accscalar_t ans, ax, c, yc, r, t, y, z;
 | 
			
		||||
  accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
 | 
			
		||||
  static const int MAXITER = 2000;
 | 
			
		||||
  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
 | 
			
		||||
  constexpr int MAXITER = 2000;
 | 
			
		||||
  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
 | 
			
		||||
    1.11022302462515654042E-16 : 5.9604644775390625E-8;
 | 
			
		||||
  static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
 | 
			
		||||
  constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
 | 
			
		||||
    4.503599627370496e15 : 16777216.;
 | 
			
		||||
  static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
 | 
			
		||||
  constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
 | 
			
		||||
    2.22044604925031308085e-16 : 5.9604644775390625E-8;
 | 
			
		||||
 | 
			
		||||
  ax = _igam_helper_fac(a, x);
 | 
			
		||||
@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
 | 
			
		||||
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 | 
			
		||||
  accscalar_t absxma_a;
 | 
			
		||||
 | 
			
		||||
  static const accscalar_t SMALL = 20.0;
 | 
			
		||||
  static const accscalar_t LARGE = 200.0;
 | 
			
		||||
  static const accscalar_t SMALLRATIO = 0.3;
 | 
			
		||||
  static const accscalar_t LARGERATIO = 4.5;
 | 
			
		||||
  constexpr accscalar_t SMALL = 20.0;
 | 
			
		||||
  constexpr accscalar_t LARGE = 200.0;
 | 
			
		||||
  constexpr accscalar_t SMALLRATIO = 0.3;
 | 
			
		||||
  constexpr accscalar_t LARGERATIO = 4.5;
 | 
			
		||||
 | 
			
		||||
  if ((x < 0) || (a < 0)) {
 | 
			
		||||
    // out of defined-region of the function
 | 
			
		||||
@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
 | 
			
		||||
 | 
			
		||||
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 | 
			
		||||
  accscalar_t absxma_a;
 | 
			
		||||
  static const accscalar_t SMALL = 20.0;
 | 
			
		||||
  static const accscalar_t LARGE = 200.0;
 | 
			
		||||
  static const accscalar_t SMALLRATIO = 0.3;
 | 
			
		||||
  static const accscalar_t LARGERATIO = 4.5;
 | 
			
		||||
  constexpr accscalar_t SMALL = 20.0;
 | 
			
		||||
  constexpr accscalar_t LARGE = 200.0;
 | 
			
		||||
  constexpr accscalar_t SMALLRATIO = 0.3;
 | 
			
		||||
  constexpr accscalar_t LARGERATIO = 4.5;
 | 
			
		||||
 | 
			
		||||
  // boundary values following SciPy
 | 
			
		||||
  if ((x < 0) || (a < 0)) {
 | 
			
		||||
 | 
			
		||||
@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
 | 
			
		||||
const auto digamma_string = jiterator_stringify(
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  T digamma(T x) {
 | 
			
		||||
    static const double PI_f64 = 3.14159265358979323846;
 | 
			
		||||
    static constexpr double PI_f64 = 3.14159265358979323846;
 | 
			
		||||
 | 
			
		||||
    // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
 | 
			
		||||
    if (x == 0) {
 | 
			
		||||
@ -3072,9 +3072,9 @@ template <typename scalar_t>
 | 
			
		||||
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
 | 
			
		||||
  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
 | 
			
		||||
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 | 
			
		||||
  static const double PI_f64 = 3.14159265358979323846;
 | 
			
		||||
  const accscalar_t PSI_10 = 2.25175258906672110764;
 | 
			
		||||
  const accscalar_t A[] = {
 | 
			
		||||
  static constexpr double PI_f64 = 3.14159265358979323846;
 | 
			
		||||
  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
 | 
			
		||||
  constexpr accscalar_t A[] = {
 | 
			
		||||
      8.33333333333333333333E-2,
 | 
			
		||||
      -2.10927960927960927961E-2,
 | 
			
		||||
      7.57575757575757575758E-3,
 | 
			
		||||
 | 
			
		||||
@ -1097,11 +1097,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
 | 
			
		||||
  // threads with different threadIdx.x are independent and will produce results for different outputs.
 | 
			
		||||
  // In such case, values in each loaded vector always correspond to different outputs.
 | 
			
		||||
  if (fastest_moving_stride == sizeof(scalar_t)) {
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
 | 
			
		||||
#else
 | 
			
		||||
    if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
 | 
			
		||||
#endif
 | 
			
		||||
      // Case 1: "vectorize along input"
 | 
			
		||||
      // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case,
 | 
			
		||||
      // we should avoid vectorization.
 | 
			
		||||
 | 
			
		||||
@ -39,9 +39,14 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta
 | 
			
		||||
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
 | 
			
		||||
void mean_kernel_impl(TensorIterator& iter) {
 | 
			
		||||
  //  returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
 | 
			
		||||
  constexpr bool is_16_bits = sizeof(scalar_t) == 2;
 | 
			
		||||
  using factor_t = typename c10::scalar_value_type<acc_t>::type;
 | 
			
		||||
  factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
 | 
			
		||||
  gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
 | 
			
		||||
  if constexpr (is_16_bits) {
 | 
			
		||||
    gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
 | 
			
		||||
  } else {
 | 
			
		||||
    gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void mean_kernel_cuda(TensorIterator& iter) {
 | 
			
		||||
 | 
			
		||||
@ -13,24 +13,19 @@ namespace at::native {
 | 
			
		||||
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
 | 
			
		||||
struct sum_functor {
 | 
			
		||||
  void operator()(TensorIterator& iter) {
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    // Half and BFloat16 can be packed in groups of up to 8 elements and
 | 
			
		||||
    // can use *_DWORDX4 instructions to achieve that.
 | 
			
		||||
    const bool is_16_bits =
 | 
			
		||||
      ( (std::is_same<at::Half, scalar_t>::value) ||
 | 
			
		||||
        (std::is_same<at::BFloat16, scalar_t>::value) );
 | 
			
		||||
    if (is_16_bits) {
 | 
			
		||||
    const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
 | 
			
		||||
      return a + b;
 | 
			
		||||
    };
 | 
			
		||||
    constexpr bool is_16_bits = sizeof(scalar_t) == 2;
 | 
			
		||||
    if constexpr (is_16_bits) {
 | 
			
		||||
      gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(
 | 
			
		||||
        iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
 | 
			
		||||
          return a + b;
 | 
			
		||||
        }));
 | 
			
		||||
      return;
 | 
			
		||||
        iter, func_wrapper<out_t>(sum_combine)
 | 
			
		||||
      );
 | 
			
		||||
    } else {
 | 
			
		||||
      gpu_reduce_kernel<scalar_t, out_t>(
 | 
			
		||||
        iter, func_wrapper<out_t>(sum_combine)
 | 
			
		||||
      );
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
    gpu_reduce_kernel<scalar_t, out_t>(
 | 
			
		||||
        iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
 | 
			
		||||
          return a + b;
 | 
			
		||||
        }));
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
 | 
			
		||||
    return 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  static const int size = 2;
 | 
			
		||||
  static constexpr int size = 2;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// taken from
 | 
			
		||||
@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
 | 
			
		||||
    return 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  static const int size = 4;
 | 
			
		||||
  static constexpr int size = 4;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <typename accscalar_t>
 | 
			
		||||
 | 
			
		||||
@ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
// Helper function to compute output pixel range that can contribute to input pixel
 | 
			
		||||
template <typename accscalar_t>
 | 
			
		||||
__device__ __forceinline__ void compute_output_range(
 | 
			
		||||
    int input_pos,
 | 
			
		||||
    accscalar_t scale,
 | 
			
		||||
    int output_size,
 | 
			
		||||
    bool align_corners,
 | 
			
		||||
    int& min_output,
 | 
			
		||||
    int& max_output) {
 | 
			
		||||
  accscalar_t lo, hi;
 | 
			
		||||
  if (align_corners) {
 | 
			
		||||
      lo = static_cast<accscalar_t>(input_pos - 1) / scale;
 | 
			
		||||
      hi = static_cast<accscalar_t>(input_pos + 1) / scale;
 | 
			
		||||
  } else {
 | 
			
		||||
      lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
 | 
			
		||||
      hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
 | 
			
		||||
  }
 | 
			
		||||
  min_output = max(0, static_cast<int>(ceil(lo)));
 | 
			
		||||
  max_output = min(output_size - 1, static_cast<int>(floor(hi)));
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
 | 
			
		||||
template <typename scalar_t, typename accscalar_t>
 | 
			
		||||
C10_LAUNCH_BOUNDS_1(1024)
 | 
			
		||||
@ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame(
 | 
			
		||||
    const bool align_corners,
 | 
			
		||||
    scalar_t* __restrict__ idata,
 | 
			
		||||
    const scalar_t* __restrict__ odata) {
 | 
			
		||||
  // In C++, integer multiplication, like in standard arithmetic, is generally commutative.
 | 
			
		||||
  const size_t i_numel = nc * width1 * height1;
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
 | 
			
		||||
       index += blockDim.x * gridDim.x) {
 | 
			
		||||
    // Decode input pixel coordinates
 | 
			
		||||
    size_t index_temp = index;
 | 
			
		||||
    const int w1 = index_temp % width1;
 | 
			
		||||
    index_temp /= width1;
 | 
			
		||||
    const int h1 = index_temp % height1;
 | 
			
		||||
    const size_t nc_idx = index_temp / height1;
 | 
			
		||||
 | 
			
		||||
    accscalar_t grad_sum = 0;
 | 
			
		||||
 | 
			
		||||
    // Find range of output pixels that could interpolate from this input pixel
 | 
			
		||||
    int h2_min, h2_max, w2_min, w2_max;
 | 
			
		||||
    compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
 | 
			
		||||
    compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
 | 
			
		||||
 | 
			
		||||
    // Iterate over potential output pixels
 | 
			
		||||
    for (int h2 = h2_min; h2 <= h2_max; h2++) {
 | 
			
		||||
      for (int w2 = w2_min; w2 <= w2_max; w2++) {
 | 
			
		||||
        // Compute source coordinates for this output pixel
 | 
			
		||||
        const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
 | 
			
		||||
            rheight, h2, align_corners, /*cubic=*/false);
 | 
			
		||||
        const int h1_base = (int)h1r;
 | 
			
		||||
        const int h1p = (h1_base < height1 - 1) ? 1 : 0;
 | 
			
		||||
        const accscalar_t h1lambda = h1r - h1_base;
 | 
			
		||||
        const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
 | 
			
		||||
 | 
			
		||||
        const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
 | 
			
		||||
            rwidth, w2, align_corners, /*cubic=*/false);
 | 
			
		||||
        const int w1_base = (int)w1r;
 | 
			
		||||
        const int w1p = (w1_base < width1 - 1) ? 1 : 0;
 | 
			
		||||
        const accscalar_t w1lambda = w1r - w1_base;
 | 
			
		||||
        const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
 | 
			
		||||
 | 
			
		||||
        // Check if our input pixel participates in this interpolation and accumulate all weights
 | 
			
		||||
        // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
 | 
			
		||||
        // to the same pixel, so we need to accumulate weights from all matching positions
 | 
			
		||||
        accscalar_t weight = 0;
 | 
			
		||||
 | 
			
		||||
        // Check all four interpolation positions and accumulate weights
 | 
			
		||||
        if (h1 == h1_base && w1 == w1_base) {
 | 
			
		||||
          weight += h0lambda * w0lambda;  // top-left
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base && w1 == w1_base + w1p) {
 | 
			
		||||
          weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0)
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base + h1p && w1 == w1_base) {
 | 
			
		||||
          weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0)
 | 
			
		||||
        }
 | 
			
		||||
        if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
 | 
			
		||||
          weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions)
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (weight > 0) {
 | 
			
		||||
          const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
 | 
			
		||||
          grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Write accumulated gradient (no atomics needed)
 | 
			
		||||
    idata[index] = static_cast<scalar_t>(grad_sum);
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  const size_t o_numel = nc * width2 * height2;
 | 
			
		||||
  const size_t i_numel = nc * width1 * height1;
 | 
			
		||||
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
 | 
			
		||||
       index += blockDim.x * gridDim.x) {
 | 
			
		||||
    size_t index_temp = index;
 | 
			
		||||
@ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame(
 | 
			
		||||
        static_cast<scalar_t>(h1lambda * w1lambda * d2val),
 | 
			
		||||
        true);
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, typename accscalar_t>
 | 
			
		||||
@ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
  // threads are not covering the whole input tensor.
 | 
			
		||||
  grad_input.zero_();
 | 
			
		||||
 | 
			
		||||
  const size_t num_kernels = nbatch * channels * output_height * output_width;
 | 
			
		||||
  const int num_threads = std::min(
 | 
			
		||||
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
 | 
			
		||||
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
@ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  constexpr bool use_input = true;
 | 
			
		||||
#else
 | 
			
		||||
  constexpr bool use_input = false;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  AT_DISPATCH_FLOATING_TYPES_AND2(
 | 
			
		||||
      at::ScalarType::Half, at::ScalarType::BFloat16,
 | 
			
		||||
      grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
 | 
			
		||||
@ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
 | 
			
		||||
          input_width, output_width, align_corners, scales_w);
 | 
			
		||||
 | 
			
		||||
      const size_t num_kernels = nbatch * channels * output_height * output_width;
 | 
			
		||||
 | 
			
		||||
      upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
 | 
			
		||||
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
 | 
			
		||||
              input_height,
 | 
			
		||||
@ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 | 
			
		||||
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
 | 
			
		||||
          input_width, output_width, align_corners, scales_w);
 | 
			
		||||
 | 
			
		||||
      const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
 | 
			
		||||
 | 
			
		||||
      upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
 | 
			
		||||
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
 | 
			
		||||
             num_threads,
 | 
			
		||||
 | 
			
		||||
@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum(
 | 
			
		||||
  if constexpr (!rms_norm){
 | 
			
		||||
    U delta = val - curr_sum.mean;
 | 
			
		||||
    U new_count = curr_sum.count + 1.f;
 | 
			
		||||
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
 | 
			
		||||
    U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
 | 
			
		||||
#else
 | 
			
		||||
    U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
 | 
			
		||||
#endif
 | 
			
		||||
    return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
 | 
			
		||||
  } else{
 | 
			
		||||
    return {0.f, curr_sum.sigma2 + val * val, 0};
 | 
			
		||||
@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine(
 | 
			
		||||
    U count = dataA.count + dataB.count;
 | 
			
		||||
    U mean, sigma2;
 | 
			
		||||
    if (count > decltype(dataB.count){0}) {
 | 
			
		||||
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
 | 
			
		||||
      auto coef = __builtin_amdgcn_rcpf(count);
 | 
			
		||||
#else
 | 
			
		||||
      auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
 | 
			
		||||
#endif
 | 
			
		||||
      auto nA = dataA.count * coef;
 | 
			
		||||
      auto nB = dataB.count * coef;
 | 
			
		||||
      mean = nA*dataA.mean + nB*dataB.mean;
 | 
			
		||||
 | 
			
		||||
@ -487,9 +487,7 @@ std::unique_ptr<fe::graph::Graph> build_graph(
 | 
			
		||||
  auto scaled_dot_product_flash_attention_options =
 | 
			
		||||
      fe::graph::SDPA_attributes()
 | 
			
		||||
          .set_name("CUDNN_SDPA")
 | 
			
		||||
          .set_is_inference(return_softmaxstats == false)
 | 
			
		||||
          // TODO(eqy): switch to this API once cuDNN FE is upgraded
 | 
			
		||||
          // .set_generate_stats(return_softmaxstats)
 | 
			
		||||
          .set_generate_stats(return_softmaxstats)
 | 
			
		||||
          .set_causal_mask(is_causal)
 | 
			
		||||
          .set_attn_scale(attn_scale);
 | 
			
		||||
  if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
 | 
			
		||||
@ -707,9 +705,7 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
 | 
			
		||||
  auto scaled_dot_product_flash_attention_options =
 | 
			
		||||
      fe::graph::SDPA_attributes()
 | 
			
		||||
          .set_name("CUDNN_SDPA_NESTEDTENSOR")
 | 
			
		||||
          .set_is_inference(return_softmaxstats == false)
 | 
			
		||||
          // TODO(eqy): switch to this API once cuDNN FE is upgraded
 | 
			
		||||
          // .set_generate_stats(return_softmaxstats)
 | 
			
		||||
          .set_generate_stats(return_softmaxstats)
 | 
			
		||||
          .set_causal_mask(is_causal)
 | 
			
		||||
          .set_attn_scale(attn_scale)
 | 
			
		||||
          .set_seq_len_q(SEQ_LEN_Q_)
 | 
			
		||||
 | 
			
		||||
@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
 | 
			
		||||
  // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
 | 
			
		||||
  // else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
 | 
			
		||||
  // only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
 | 
			
		||||
  static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
 | 
			
		||||
  constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
 | 
			
		||||
  if (mat1.dim() == 1 && mat2.dim() == 1) {
 | 
			
		||||
    // aten::dot
 | 
			
		||||
    return mat1.size(0) > mkldnn_gemm_min_size;
 | 
			
		||||
 | 
			
		||||
@ -1,16 +1,16 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
#include <c10/metal/common.h>
 | 
			
		||||
 | 
			
		||||
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
 | 
			
		||||
struct CatLargeSharedParams {
 | 
			
		||||
template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
 | 
			
		||||
struct CatSharedParams {
 | 
			
		||||
  int32_t ndim;
 | 
			
		||||
  int32_t cat_dim;
 | 
			
		||||
  ::c10::metal::array<idx_type_t, N> output_strides;
 | 
			
		||||
  ::c10::metal::array<idx_type_t, N> output_sizes;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
 | 
			
		||||
struct CatLargeInputParams {
 | 
			
		||||
template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
 | 
			
		||||
struct CatInputParams {
 | 
			
		||||
  idx_type_t cat_dim_offset;
 | 
			
		||||
  idx_type_t input_element_offset;
 | 
			
		||||
  ::c10::metal::array<idx_type_t, N> input_strides;
 | 
			
		||||
 | 
			
		||||
@ -6,26 +6,25 @@
 | 
			
		||||
using namespace metal;
 | 
			
		||||
using namespace c10::metal;
 | 
			
		||||
 | 
			
		||||
template <typename T_in, typename T_out>
 | 
			
		||||
kernel void cat_large(
 | 
			
		||||
template <typename I, typename T_in, typename T_out>
 | 
			
		||||
kernel void cat(
 | 
			
		||||
    constant T_in* input [[buffer(0)]],
 | 
			
		||||
    device T_out* output [[buffer(1)]],
 | 
			
		||||
    constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
 | 
			
		||||
    constant CatLargeInputParams<>& input_params [[buffer(3)]],
 | 
			
		||||
    constant CatSharedParams<I>& shared_params [[buffer(2)]],
 | 
			
		||||
    constant CatInputParams<I>& input_params [[buffer(3)]],
 | 
			
		||||
    uint tid [[thread_position_in_grid]]) {
 | 
			
		||||
  auto ndim = shared_params.ndim;
 | 
			
		||||
  auto cat_dim = shared_params.cat_dim;
 | 
			
		||||
  constant auto& output_strides = shared_params.output_strides;
 | 
			
		||||
  constant auto& output_sizes = shared_params.output_sizes;
 | 
			
		||||
 | 
			
		||||
  auto cat_dim_offset = input_params.cat_dim_offset;
 | 
			
		||||
  auto input_element_offset = input_params.input_element_offset;
 | 
			
		||||
  constant auto& input_strides = input_params.input_strides;
 | 
			
		||||
  constant auto& input_sizes = input_params.input_sizes;
 | 
			
		||||
 | 
			
		||||
  auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
 | 
			
		||||
  int64_t input_offset = 0;
 | 
			
		||||
  int64_t output_offset = 0;
 | 
			
		||||
  auto input_element_idx = static_cast<I>(tid) + input_element_offset;
 | 
			
		||||
  I input_offset = 0;
 | 
			
		||||
  I output_offset = 0;
 | 
			
		||||
 | 
			
		||||
  for (auto dim = ndim - 1; dim >= 0; dim--) {
 | 
			
		||||
    auto dim_size = input_sizes[dim];
 | 
			
		||||
@ -42,41 +41,45 @@ kernel void cat_large(
 | 
			
		||||
  output[output_offset] = static_cast<T_out>(input[input_offset]);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define REGISTER_CAT_LARGE_OP(T_in, T_out)                           \
 | 
			
		||||
  template [[host_name("cat_large_" #T_in "_" #T_out)]]              \
 | 
			
		||||
  kernel void cat_large<T_in, T_out>(                                \
 | 
			
		||||
      constant T_in * input [[buffer(0)]],                           \
 | 
			
		||||
      device T_out * output [[buffer(1)]],                           \
 | 
			
		||||
      constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
 | 
			
		||||
      constant CatLargeInputParams<> & input_params [[buffer(3)]],   \
 | 
			
		||||
#define REGISTER_CAT_OP(I, T_in, T_out)                          \
 | 
			
		||||
  template [[host_name("cat_" #I "_" #T_in "_" #T_out)]]         \
 | 
			
		||||
  kernel void cat<I, T_in, T_out>(                               \
 | 
			
		||||
      constant T_in * input [[buffer(0)]],                       \
 | 
			
		||||
      device T_out * output [[buffer(1)]],                       \
 | 
			
		||||
      constant CatSharedParams<I> & shared_params [[buffer(2)]], \
 | 
			
		||||
      constant CatInputParams<I> & input_params [[buffer(3)]],   \
 | 
			
		||||
      uint tid [[thread_position_in_grid]]);
 | 
			
		||||
 | 
			
		||||
#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(float, T_out);               \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(half, T_out);                \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(bfloat, T_out);              \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(int, T_out);                 \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(uint, T_out);                \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(long, T_out);                \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(ulong, T_out);               \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(short, T_out);               \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(ushort, T_out);              \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(char, T_out);                \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(uchar, T_out);               \
 | 
			
		||||
  REGISTER_CAT_LARGE_OP(bool, T_out);
 | 
			
		||||
#define REGISTER_CAT_OP_ALL_INPUT_TYPES(I, T_out) \
 | 
			
		||||
  REGISTER_CAT_OP(I, float, T_out);               \
 | 
			
		||||
  REGISTER_CAT_OP(I, half, T_out);                \
 | 
			
		||||
  REGISTER_CAT_OP(I, bfloat, T_out);              \
 | 
			
		||||
  REGISTER_CAT_OP(I, int, T_out);                 \
 | 
			
		||||
  REGISTER_CAT_OP(I, uint, T_out);                \
 | 
			
		||||
  REGISTER_CAT_OP(I, long, T_out);                \
 | 
			
		||||
  REGISTER_CAT_OP(I, ulong, T_out);               \
 | 
			
		||||
  REGISTER_CAT_OP(I, short, T_out);               \
 | 
			
		||||
  REGISTER_CAT_OP(I, ushort, T_out);              \
 | 
			
		||||
  REGISTER_CAT_OP(I, char, T_out);                \
 | 
			
		||||
  REGISTER_CAT_OP(I, uchar, T_out);               \
 | 
			
		||||
  REGISTER_CAT_OP(I, bool, T_out);
 | 
			
		||||
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
 | 
			
		||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
 | 
			
		||||
#define REGISTER_CAT_FOR_INDEX_TYPE(I)        \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float);  \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, half);   \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bfloat); \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, int);    \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uint);   \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, long);   \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ulong);  \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, short);  \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ushort); \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, char);   \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uchar);  \
 | 
			
		||||
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bool);   \
 | 
			
		||||
                                              \
 | 
			
		||||
  REGISTER_CAT_OP(I, float2, float2);         \
 | 
			
		||||
  REGISTER_CAT_OP(I, half2, half2);
 | 
			
		||||
 | 
			
		||||
REGISTER_CAT_LARGE_OP(float2, float2);
 | 
			
		||||
REGISTER_CAT_LARGE_OP(half2, half2);
 | 
			
		||||
REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
 | 
			
		||||
REGISTER_CAT_FOR_INDEX_TYPE(int32_t);
 | 
			
		||||
 | 
			
		||||
@ -196,6 +196,28 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)
 | 
			
		||||
       other.size(0) > max_stride_size || other.size(1) > max_stride_size);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void map_mps_decomposition_error_code_to_blas(const Tensor& status) {
 | 
			
		||||
  const auto& status_flat = status.view(-1);
 | 
			
		||||
 | 
			
		||||
  for (const auto i : c10::irange(status_flat.size(0))) {
 | 
			
		||||
    int code = status_flat[i].item<int>();
 | 
			
		||||
    switch (code) {
 | 
			
		||||
      case MPSMatrixDecompositionStatusSuccess:
 | 
			
		||||
        status_flat[i] = 0;
 | 
			
		||||
        break;
 | 
			
		||||
      case MPSMatrixDecompositionStatusNonPositiveDefinite:
 | 
			
		||||
      case MPSMatrixDecompositionStatusSingular:
 | 
			
		||||
        status_flat[i] = 2;
 | 
			
		||||
        break;
 | 
			
		||||
      case MPSMatrixDecompositionStatusFailure:
 | 
			
		||||
        status_flat[i] = -1;
 | 
			
		||||
        break;
 | 
			
		||||
      default:
 | 
			
		||||
        TORCH_INTERNAL_ASSERT(false, "Unknown MPSMatrixDecompositionStatus enum value: ", code);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
 | 
			
		||||
static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
 | 
			
		||||
@ -487,6 +509,9 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
 | 
			
		||||
                  "mpsmatrixdecompositionstatus for details.");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  map_mps_decomposition_error_code_to_blas(info);
 | 
			
		||||
 | 
			
		||||
  if (!left) {
 | 
			
		||||
    // If this was a right solve, transpose the result back
 | 
			
		||||
    result.copy_(result_t.transpose(-2, -1).contiguous());
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@
 | 
			
		||||
#include <ATen/MemoryOverlap.h>
 | 
			
		||||
#include <ATen/WrapDimUtils.h>
 | 
			
		||||
#include <ATen/mps/MPSProfiler.h>
 | 
			
		||||
#include <ATen/native/Pool.h>
 | 
			
		||||
#include <ATen/native/TensorShape.h>
 | 
			
		||||
#include <ATen/native/TypeProperties.h>
 | 
			
		||||
#include <ATen/native/mps/OperationUtils.h>
 | 
			
		||||
@ -69,29 +70,40 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This implementation of cat is used only if one of the inputs or the output is
 | 
			
		||||
// too large to use MPSGraph.
 | 
			
		||||
template <typename T>
 | 
			
		||||
std::string get_type_str();
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
std::string get_type_str<int64_t>() {
 | 
			
		||||
  return "int64_t";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
std::string get_type_str<int32_t>() {
 | 
			
		||||
  return "int32_t";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NOTE: `output` is expected to already have the correct size.
 | 
			
		||||
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
 | 
			
		||||
  CatLargeSharedParams shared_params;
 | 
			
		||||
template <typename idx_type_t>
 | 
			
		||||
static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
 | 
			
		||||
  CatSharedParams<idx_type_t> shared_params;
 | 
			
		||||
 | 
			
		||||
  shared_params.ndim = output.dim();
 | 
			
		||||
  shared_params.cat_dim = dimension;
 | 
			
		||||
 | 
			
		||||
  for (const auto dim : c10::irange(output.dim())) {
 | 
			
		||||
    shared_params.output_strides[dim] = output.stride(dim);
 | 
			
		||||
    shared_params.output_sizes[dim] = output.size(dim);
 | 
			
		||||
    shared_params.output_strides[dim] = safe_downcast<idx_type_t, int64_t>(output.stride(dim));
 | 
			
		||||
    shared_params.output_sizes[dim] = safe_downcast<idx_type_t, int64_t>(output.size(dim));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int64_t cat_dim_offset = 0;
 | 
			
		||||
  idx_type_t cat_dim_offset = 0;
 | 
			
		||||
  size_t input_idx = 0;
 | 
			
		||||
  MPSStream* stream = getCurrentMPSStream();
 | 
			
		||||
 | 
			
		||||
  // Launch a separate kernels for each input. This will produce some overhead,
 | 
			
		||||
  // but that should be relatively minimal since at least one of the inputs is
 | 
			
		||||
  // very large. In order to launch only one kernel to process all inputs, we
 | 
			
		||||
  // would have to copy all the input tensor data into a packed buffer, which
 | 
			
		||||
  // would not be ideal.
 | 
			
		||||
  // Launch a separate kernels for each input. This will produce some overhead.
 | 
			
		||||
  // In order to launch only one kernel to process all inputs, we would have to
 | 
			
		||||
  // copy all the input tensor data into a packed buffer, which would not be
 | 
			
		||||
  // ideal.
 | 
			
		||||
  for (const Tensor& input : inputs) {
 | 
			
		||||
    if (input.numel() == 0) {
 | 
			
		||||
      continue;
 | 
			
		||||
@ -104,21 +116,23 @@ static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimen
 | 
			
		||||
 | 
			
		||||
    for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
 | 
			
		||||
      auto num_threads = std::min(max_num_threads, numel_remaining);
 | 
			
		||||
      CatLargeInputParams input_params;
 | 
			
		||||
      CatInputParams<idx_type_t> input_params;
 | 
			
		||||
 | 
			
		||||
      input_params.cat_dim_offset = cat_dim_offset;
 | 
			
		||||
      input_params.input_element_offset = input.numel() - numel_remaining;
 | 
			
		||||
      input_params.cat_dim_offset = safe_downcast<idx_type_t, int64_t>(cat_dim_offset);
 | 
			
		||||
      input_params.input_element_offset = safe_downcast<idx_type_t, int64_t>(input.numel() - numel_remaining);
 | 
			
		||||
 | 
			
		||||
      for (const auto dim : c10::irange(input.dim())) {
 | 
			
		||||
        input_params.input_strides[dim] = input.stride(dim);
 | 
			
		||||
        input_params.input_sizes[dim] = input.size(dim);
 | 
			
		||||
        input_params.input_strides[dim] = safe_downcast<idx_type_t, int64_t>(input.stride(dim));
 | 
			
		||||
        input_params.input_sizes[dim] = safe_downcast<idx_type_t, int64_t>(input.size(dim));
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      dispatch_sync_with_rethrow(stream->queue(), ^() {
 | 
			
		||||
        @autoreleasepool {
 | 
			
		||||
          id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
 | 
			
		||||
          auto pipeline_state = lib.getPipelineStateForFunc(
 | 
			
		||||
              fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
 | 
			
		||||
          auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("cat_{}_{}_{}",
 | 
			
		||||
                                                                        get_type_str<idx_type_t>(),
 | 
			
		||||
                                                                        scalarToMetalTypeString(input),
 | 
			
		||||
                                                                        scalarToMetalTypeString(output)));
 | 
			
		||||
          getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
 | 
			
		||||
          [computeEncoder setComputePipelineState:pipeline_state];
 | 
			
		||||
          mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
 | 
			
		||||
@ -294,13 +308,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
 | 
			
		||||
              " and out is on ",
 | 
			
		||||
              out.device());
 | 
			
		||||
 | 
			
		||||
  // TODO: For better performance by eliminating input tensor gathering and post transpose,
 | 
			
		||||
  // TODO: it is better to keep the out tensor's memory format.
 | 
			
		||||
  // TODO: dimension needs to be recomputed as:
 | 
			
		||||
  // TODO: dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim()- dim; otherwise dim = dim-1
 | 
			
		||||
  if (needsGather(out)) {
 | 
			
		||||
    out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
 | 
			
		||||
  }
 | 
			
		||||
  std::vector<int64_t> size(notSkippedTensor.sizes().vec());
 | 
			
		||||
 | 
			
		||||
  // Compute size of the result in the cat dimension
 | 
			
		||||
@ -331,82 +338,9 @@ TORCH_IMPL_FUNC(cat_out_mps)
 | 
			
		||||
  has_large_tensor |= isTooLargeForMPSGraph(out);
 | 
			
		||||
 | 
			
		||||
  if (has_large_tensor) {
 | 
			
		||||
    return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  struct CachedGraph : public MPSCachedGraph {
 | 
			
		||||
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
 | 
			
		||||
    std::vector<MPSGraphTensor*> inputTensors_;
 | 
			
		||||
    MPSGraphTensor* outputTensor_ = nil;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  @autoreleasepool {
 | 
			
		||||
    std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
 | 
			
		||||
        (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
 | 
			
		||||
    if (!all_same_dtype) {
 | 
			
		||||
      key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
 | 
			
		||||
    } else {
 | 
			
		||||
      key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
 | 
			
		||||
    }
 | 
			
		||||
    for (auto idx : skipped_tensor_indices) {
 | 
			
		||||
      key += "," + std::to_string(idx);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
 | 
			
		||||
      auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
 | 
			
		||||
      std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array);
 | 
			
		||||
      newCachedGraph->inputTensors_.reserve(len_tensor_array);
 | 
			
		||||
 | 
			
		||||
      for (const auto idx : c10::irange(len_tensor_array)) {
 | 
			
		||||
        const Tensor& tensor = input_tensors[idx];
 | 
			
		||||
        auto scalar_type = getMPSScalarType(tensor.scalar_type());
 | 
			
		||||
        if (tensor.scalar_type() == kBool) {
 | 
			
		||||
          scalar_type = MPSDataTypeInt8;
 | 
			
		||||
        }
 | 
			
		||||
        newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type);
 | 
			
		||||
        if (tensor.scalar_type() != out_dtype) {
 | 
			
		||||
          castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
 | 
			
		||||
                                                toType:getMPSDataType(out_dtype)
 | 
			
		||||
                                                  name:@"castInput"];
 | 
			
		||||
        } else {
 | 
			
		||||
          castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
 | 
			
		||||
      MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
 | 
			
		||||
                                                   dimension:dimension // Maybe convert this from int64_t -> int32
 | 
			
		||||
                                                        name:nil];
 | 
			
		||||
      if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
 | 
			
		||||
        outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
 | 
			
		||||
      }
 | 
			
		||||
      newCachedGraph->outputTensor_ = outputTensor;
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    std::vector<Placeholder> inputPlaceholders;
 | 
			
		||||
    int i = 0;
 | 
			
		||||
    int t_idx = 0;
 | 
			
		||||
    for (const Tensor& tensor : materialized_inputs) {
 | 
			
		||||
      if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
 | 
			
		||||
        auto scalar_type = getMPSScalarType(tensor.scalar_type());
 | 
			
		||||
        if (tensor.scalar_type() == kBool) {
 | 
			
		||||
          scalar_type = MPSDataTypeInt8;
 | 
			
		||||
        }
 | 
			
		||||
        inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, nullptr, true, scalar_type);
 | 
			
		||||
        t_idx++;
 | 
			
		||||
      }
 | 
			
		||||
      i++;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto outputDataType = getMPSScalarType(out.scalar_type());
 | 
			
		||||
    Placeholder outputPlaceholder =
 | 
			
		||||
        Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
 | 
			
		||||
 | 
			
		||||
    NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
 | 
			
		||||
    for (auto& inputPlaceholder : inputPlaceholders) {
 | 
			
		||||
      feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
 | 
			
		||||
    }
 | 
			
		||||
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
 | 
			
		||||
    return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out);
 | 
			
		||||
  } else {
 | 
			
		||||
    return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6531,6 +6531,7 @@
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: var
 | 
			
		||||
    MPS: var_mps
 | 
			
		||||
    MTIA: var_mtia
 | 
			
		||||
  tags: core
 | 
			
		||||
 | 
			
		||||
- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
 | 
			
		||||
 | 
			
		||||
@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
 | 
			
		||||
 | 
			
		||||
#if defined(__ARM_NEON__) || defined(__aarch64__)
 | 
			
		||||
 | 
			
		||||
const static int PARALLEL_THRESHOLD = 1 << 20;
 | 
			
		||||
constexpr static int PARALLEL_THRESHOLD = 1 << 20;
 | 
			
		||||
 | 
			
		||||
// Generic template defaults to naive quantize implementation
 | 
			
		||||
template <typename T>
 | 
			
		||||
 | 
			
		||||
@ -1388,7 +1388,7 @@ namespace at::native {
 | 
			
		||||
    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
 | 
			
		||||
        "onednn int8 linear: act scale/zp size should be 1/<=1");
 | 
			
		||||
    static std::optional<at::Tensor> other = std::nullopt;
 | 
			
		||||
    static const std::string_view binary_post_op = "none";
 | 
			
		||||
    constexpr std::string_view binary_post_op = "none";
 | 
			
		||||
    int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
 | 
			
		||||
    return linear_int8_with_onednn_weight(
 | 
			
		||||
        act, act_scale.item().toDouble(), act_zp,
 | 
			
		||||
 | 
			
		||||
@ -16,8 +16,8 @@ namespace {
 | 
			
		||||
 | 
			
		||||
#ifdef USE_PYTORCH_QNNPACK
 | 
			
		||||
 | 
			
		||||
const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
 | 
			
		||||
const static int qnnpack_softmax_output_zero_point = 0;
 | 
			
		||||
constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
 | 
			
		||||
constexpr static int qnnpack_softmax_output_zero_point = 0;
 | 
			
		||||
 | 
			
		||||
bool is_qnnpack_compatible(
 | 
			
		||||
    const Tensor& qx,
 | 
			
		||||
 | 
			
		||||
@ -110,9 +110,9 @@ class ApplyLogSumExp {
 | 
			
		||||
  using ElementCompute = ElementCompute_;
 | 
			
		||||
  using ElementLSE = ElementLSE_;
 | 
			
		||||
 | 
			
		||||
  static int const kElementsPerAccess = ElementsPerAccess;
 | 
			
		||||
  static int const kCount = kElementsPerAccess;
 | 
			
		||||
  static const ScaleType::Kind kScale =
 | 
			
		||||
  static int constexpr kElementsPerAccess = ElementsPerAccess;
 | 
			
		||||
  static int constexpr kCount = kElementsPerAccess;
 | 
			
		||||
  static constexpr ScaleType::Kind kScale =
 | 
			
		||||
      cutlass::epilogue::thread::ScaleType::NoBetaScaling;
 | 
			
		||||
 | 
			
		||||
  using FragmentOutput = Array<ElementOutput, kCount>;
 | 
			
		||||
 | 
			
		||||
@ -14,16 +14,16 @@ using namespace at;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
const auto int_min = std::numeric_limits<int>::min();
 | 
			
		||||
const auto int_max = std::numeric_limits<int>::max();
 | 
			
		||||
const auto long_min = std::numeric_limits<int64_t>::min();
 | 
			
		||||
const auto long_max = std::numeric_limits<int64_t>::max();
 | 
			
		||||
const auto float_lowest = std::numeric_limits<float>::lowest();
 | 
			
		||||
const auto float_min = std::numeric_limits<float>::min();
 | 
			
		||||
const auto float_max = std::numeric_limits<float>::max();
 | 
			
		||||
const auto double_lowest = std::numeric_limits<double>::lowest();
 | 
			
		||||
const auto double_min = std::numeric_limits<double>::min();
 | 
			
		||||
const auto double_max = std::numeric_limits<double>::max();
 | 
			
		||||
constexpr auto int_min = std::numeric_limits<int>::min();
 | 
			
		||||
constexpr auto int_max = std::numeric_limits<int>::max();
 | 
			
		||||
constexpr auto long_min = std::numeric_limits<int64_t>::min();
 | 
			
		||||
constexpr auto long_max = std::numeric_limits<int64_t>::max();
 | 
			
		||||
constexpr auto float_lowest = std::numeric_limits<float>::lowest();
 | 
			
		||||
constexpr auto float_min = std::numeric_limits<float>::min();
 | 
			
		||||
constexpr auto float_max = std::numeric_limits<float>::max();
 | 
			
		||||
constexpr auto double_lowest = std::numeric_limits<double>::lowest();
 | 
			
		||||
constexpr auto double_min = std::numeric_limits<double>::min();
 | 
			
		||||
constexpr auto double_max = std::numeric_limits<double>::max();
 | 
			
		||||
 | 
			
		||||
const std::vector<int> ints {
 | 
			
		||||
  int_min,
 | 
			
		||||
 | 
			
		||||
@ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() {
 | 
			
		||||
 | 
			
		||||
c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
 | 
			
		||||
  // The RNG state comprises the seed, and an offset used for Philox.
 | 
			
		||||
  static const size_t seed_size = sizeof(uint64_t);
 | 
			
		||||
  static const size_t offset_size = sizeof(uint64_t);
 | 
			
		||||
  static const size_t total_size = seed_size + offset_size;
 | 
			
		||||
  constexpr size_t seed_size = sizeof(uint64_t);
 | 
			
		||||
  constexpr size_t offset_size = sizeof(uint64_t);
 | 
			
		||||
  constexpr size_t total_size = seed_size + offset_size;
 | 
			
		||||
 | 
			
		||||
  // The internal state is returned as a CPU byte tensor.
 | 
			
		||||
  auto state_tensor = at::detail::empty_cpu(
 | 
			
		||||
@ -170,9 +170,9 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
 | 
			
		||||
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
 | 
			
		||||
  at::xpu::assertNotCapturing(
 | 
			
		||||
      "Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
 | 
			
		||||
  static const size_t seed_size = sizeof(uint64_t);
 | 
			
		||||
  static const size_t offset_size = sizeof(uint64_t);
 | 
			
		||||
  static const size_t total_size = seed_size + offset_size;
 | 
			
		||||
  constexpr size_t seed_size = sizeof(uint64_t);
 | 
			
		||||
  constexpr size_t offset_size = sizeof(uint64_t);
 | 
			
		||||
  constexpr size_t total_size = seed_size + offset_size;
 | 
			
		||||
 | 
			
		||||
  at::detail::check_rng_state(new_state);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ import os
 | 
			
		||||
import subprocess
 | 
			
		||||
import sys
 | 
			
		||||
import tempfile
 | 
			
		||||
from typing import Callable
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
 | 
			
		||||
from torch._inductor.utils import fresh_cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2284,9 +2284,11 @@ class BenchmarkRunner:
 | 
			
		||||
                    )
 | 
			
		||||
                ):
 | 
			
		||||
                    is_same = False
 | 
			
		||||
            except Exception:
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                # Sometimes torch.allclose may throw RuntimeError
 | 
			
		||||
                is_same = False
 | 
			
		||||
                exception_string = str(e)
 | 
			
		||||
                accuracy_status = f"fail_exception: {exception_string}"
 | 
			
		||||
                return record_status(accuracy_status, dynamo_start_stats=start_stats)
 | 
			
		||||
 | 
			
		||||
            if not is_same:
 | 
			
		||||
                accuracy_status = "eager_two_runs_differ"
 | 
			
		||||
@ -2403,9 +2405,11 @@ class BenchmarkRunner:
 | 
			
		||||
                    force_max_multiplier=force_max_multiplier,
 | 
			
		||||
                ):
 | 
			
		||||
                    is_same = False
 | 
			
		||||
            except Exception:
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                # Sometimes torch.allclose may throw RuntimeError
 | 
			
		||||
                is_same = False
 | 
			
		||||
                exception_string = str(e)
 | 
			
		||||
                accuracy_status = f"fail_exception: {exception_string}"
 | 
			
		||||
                return record_status(accuracy_status, dynamo_start_stats=start_stats)
 | 
			
		||||
 | 
			
		||||
            if not is_same:
 | 
			
		||||
                if self.args.skip_accuracy_check:
 | 
			
		||||
@ -4060,7 +4064,7 @@ def run(runner, args, original_dir=None):
 | 
			
		||||
        else:
 | 
			
		||||
            optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
 | 
			
		||||
        experiment = (
 | 
			
		||||
            speedup_experiment if not args.backend == "torchao" else latency_experiment
 | 
			
		||||
            speedup_experiment if args.backend != "torchao" else latency_experiment
 | 
			
		||||
        )
 | 
			
		||||
        if args.accuracy:
 | 
			
		||||
            output_filename = f"accuracy_{args.backend}.csv"
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,8 @@
 | 
			
		||||
import os
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, Callable, Optional
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,5 @@
 | 
			
		||||
from typing import Any, Callable
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,8 @@
 | 
			
		||||
import time
 | 
			
		||||
from argparse import ArgumentParser
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from typing import Any, Callable, NamedTuple
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from typing import Any, NamedTuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.autograd import functional
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,6 @@
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from typing import Callable, Optional, Union
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from typing import Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch import nn, Tensor
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,6 @@
 | 
			
		||||
import dataclasses
 | 
			
		||||
from typing import Callable, Optional
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
all_experiments: dict[str, Callable] = {}
 | 
			
		||||
 | 
			
		||||
@ -9,8 +9,9 @@ import logging
 | 
			
		||||
import time
 | 
			
		||||
from abc import abstractmethod
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
from typing import Any, Callable, Optional
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
 | 
			
		||||
from tabulate import tabulate
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -7,6 +7,7 @@ from pt import (  # noqa: F401
 | 
			
		||||
    binary_inplace_test,
 | 
			
		||||
    binary_test,
 | 
			
		||||
    bmm_test,
 | 
			
		||||
    boolean_test,
 | 
			
		||||
    cat_test,
 | 
			
		||||
    channel_shuffle_test,
 | 
			
		||||
    chunk_test,
 | 
			
		||||
 | 
			
		||||
@ -56,6 +56,9 @@ binary_ops_list = op_bench.op_list(
 | 
			
		||||
        ["sub", torch.sub],
 | 
			
		||||
        ["div", torch.div],
 | 
			
		||||
        ["mul", torch.mul],
 | 
			
		||||
        ["asr", torch.bitwise_right_shift],
 | 
			
		||||
        ["lsl", torch.bitwise_left_shift],
 | 
			
		||||
        ["xor", torch.bitwise_xor],
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										73
									
								
								benchmarks/operator_benchmark/pt/boolean_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								benchmarks/operator_benchmark/pt/boolean_test.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,73 @@
 | 
			
		||||
import operator_benchmark as op_bench
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""Microbenchmarks for boolean operators. Supports both Caffe2/PyTorch."""
 | 
			
		||||
 | 
			
		||||
# Configs for PT all operator
 | 
			
		||||
all_long_configs = op_bench.cross_product_configs(
 | 
			
		||||
    M=[8, 128], N=[32, 64], K=[256, 512], device=["cpu", "cuda"], tags=["long"]
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
all_short_configs = op_bench.config_list(
 | 
			
		||||
    attr_names=["M", "N", "K"],
 | 
			
		||||
    attrs=[
 | 
			
		||||
        [1, 1, 1],
 | 
			
		||||
        [64, 64, 64],
 | 
			
		||||
        [64, 64, 128],
 | 
			
		||||
    ],
 | 
			
		||||
    cross_product_configs={
 | 
			
		||||
        "device": ["cpu", "cuda"],
 | 
			
		||||
    },
 | 
			
		||||
    tags=["short"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AllBenchmark(op_bench.TorchBenchmarkBase):
 | 
			
		||||
    def init(self, M, N, K, device):
 | 
			
		||||
        self.inputs = {
 | 
			
		||||
            "input_one": torch.randint(0, 2, (M, N, K), device=device, dtype=torch.bool)
 | 
			
		||||
        }
 | 
			
		||||
        self.set_module_name("all")
 | 
			
		||||
 | 
			
		||||
    def forward(self, input_one):
 | 
			
		||||
        return torch.all(input_one)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# The generated test names based on all_short_configs will be in the following pattern:
 | 
			
		||||
# all_M8_N16_K32_devicecpu
 | 
			
		||||
# all_M8_N16_K32_devicecpu_bwdall
 | 
			
		||||
# all_M8_N16_K32_devicecpu_bwd1
 | 
			
		||||
# all_M8_N16_K32_devicecpu_bwd2
 | 
			
		||||
# ...
 | 
			
		||||
# Those names can be used to filter tests.
 | 
			
		||||
 | 
			
		||||
op_bench.generate_pt_test(all_long_configs + all_short_configs, AllBenchmark)
 | 
			
		||||
 | 
			
		||||
"""Mircobenchmark for any operator."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AnyBenchmark(op_bench.TorchBenchmarkBase):
 | 
			
		||||
    def init(self, M, N, device):
 | 
			
		||||
        self.inputs = {
 | 
			
		||||
            "input_one": torch.randint(0, 2, (M, N), device=device, dtype=torch.bool)
 | 
			
		||||
        }
 | 
			
		||||
        self.set_module_name("any")
 | 
			
		||||
 | 
			
		||||
    def forward(self, input_one):
 | 
			
		||||
        return torch.any(input_one)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
any_configs = op_bench.cross_product_configs(
 | 
			
		||||
    M=[8, 256],
 | 
			
		||||
    N=[256, 16],
 | 
			
		||||
    device=["cpu", "cuda"],
 | 
			
		||||
    tags=["any"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
op_bench.generate_pt_test(any_configs, AnyBenchmark)
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    op_bench.benchmark_runner.main()
 | 
			
		||||
@ -38,12 +38,16 @@ class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
 | 
			
		||||
op_bench.generate_pt_test(
 | 
			
		||||
    configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
 | 
			
		||||
)
 | 
			
		||||
op_bench.generate_pt_test(
 | 
			
		||||
    configs.convtranspose_1d_configs_short
 | 
			
		||||
    + configs.conv_1d_configs_short
 | 
			
		||||
    + configs.conv_1d_configs_long,
 | 
			
		||||
    ConvTranspose1dBenchmark,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if not torch.backends.mkldnn.is_acl_available():
 | 
			
		||||
    # convtranpose1d crashes with ACL, see https://github.com/pytorch/pytorch/issues/165654
 | 
			
		||||
    op_bench.generate_pt_test(
 | 
			
		||||
        configs.convtranspose_1d_configs_short
 | 
			
		||||
        + configs.conv_1d_configs_short
 | 
			
		||||
        + configs.conv_1d_configs_long,
 | 
			
		||||
        ConvTranspose1dBenchmark,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,8 @@
 | 
			
		||||
import itertools
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from dataclasses import asdict, dataclass
 | 
			
		||||
from functools import partial
 | 
			
		||||
from typing import Callable, Union
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from tabulate import tabulate
 | 
			
		||||
 | 
			
		||||
@ -3,10 +3,11 @@ import csv
 | 
			
		||||
import itertools
 | 
			
		||||
import random
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from contextlib import nullcontext
 | 
			
		||||
from dataclasses import asdict, dataclass
 | 
			
		||||
from functools import partial
 | 
			
		||||
from typing import Callable, Optional, Union
 | 
			
		||||
from typing import Optional, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from tabulate import tabulate
 | 
			
		||||
@ -270,7 +271,7 @@ def run_single_backend_sdpa(
 | 
			
		||||
 | 
			
		||||
        if config.calculate_bwd_time:
 | 
			
		||||
            # TODO: debug backward pass for njt
 | 
			
		||||
            if eager_sdpa and not config.attn_type == "document_mask":
 | 
			
		||||
            if eager_sdpa and config.attn_type != "document_mask":
 | 
			
		||||
                d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2)
 | 
			
		||||
                backward_eager_time = benchmark_torch_function_in_microseconds(
 | 
			
		||||
                    out_eager.backward, d_out, retain_graph=True
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,8 @@
 | 
			
		||||
import itertools
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from contextlib import nullcontext
 | 
			
		||||
from dataclasses import asdict, dataclass
 | 
			
		||||
from typing import Callable
 | 
			
		||||
 | 
			
		||||
from tabulate import tabulate
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
@ -1729,8 +1729,10 @@ def define_buck_targets(
 | 
			
		||||
            "torch/csrc/jit/backends/backend_debug_info.cpp",
 | 
			
		||||
            "torch/csrc/jit/backends/backend_interface.cpp",
 | 
			
		||||
        ],
 | 
			
		||||
        compiler_flags = get_pt_compiler_flags(),
 | 
			
		||||
        fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
 | 
			
		||||
        compiler_flags = get_pt_compiler_flags() + select({
 | 
			
		||||
            "DEFAULT": [],
 | 
			
		||||
            "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
 | 
			
		||||
        }),
 | 
			
		||||
        # @lint-ignore BUCKLINT link_whole
 | 
			
		||||
        link_whole = True,
 | 
			
		||||
        linker_flags = get_no_as_needed_linker_flag(),
 | 
			
		||||
@ -2023,6 +2025,9 @@ def define_buck_targets(
 | 
			
		||||
                "ovr_config//os:android-x86_64": [
 | 
			
		||||
                    "-mssse3",
 | 
			
		||||
                ],
 | 
			
		||||
            }) + select({
 | 
			
		||||
                "DEFAULT": [],
 | 
			
		||||
                "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
 | 
			
		||||
            }),
 | 
			
		||||
            exported_preprocessor_flags = get_aten_preprocessor_flags(),
 | 
			
		||||
            exported_deps = [
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,4 @@
 | 
			
		||||
#include <c10/core/AllocatorConfig.h>
 | 
			
		||||
#include <c10/core/DeviceType.h>
 | 
			
		||||
#include <c10/util/env.h>
 | 
			
		||||
 | 
			
		||||
namespace c10::CachingAllocator {
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,5 @@
 | 
			
		||||
#include <c10/core/SymBool.h>
 | 
			
		||||
#include <c10/core/SymInt.h>
 | 
			
		||||
#include <c10/core/SymNodeImpl.h>
 | 
			
		||||
 | 
			
		||||
namespace c10 {
 | 
			
		||||
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
 | 
			
		||||
  return toSymNodeImpl()->has_hint();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
SymInt SymBool::toSymInt() const {
 | 
			
		||||
  // If concrete bool, return concrete SymInt
 | 
			
		||||
  if (auto ma = maybe_as_bool()) {
 | 
			
		||||
    return SymInt(*ma ? 1 : 0);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Symbolic case: use sym_ite to convert bool to int (0 or 1)
 | 
			
		||||
  auto node = toSymNodeImpl();
 | 
			
		||||
  auto one_node = node->wrap_int(1);
 | 
			
		||||
  auto zero_node = node->wrap_int(0);
 | 
			
		||||
  return SymInt(node->sym_ite(one_node, zero_node));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace c10
 | 
			
		||||
 | 
			
		||||
@ -12,6 +12,8 @@
 | 
			
		||||
 | 
			
		||||
namespace c10 {
 | 
			
		||||
 | 
			
		||||
class SymInt;
 | 
			
		||||
 | 
			
		||||
class C10_API SymBool {
 | 
			
		||||
 public:
 | 
			
		||||
  /*implicit*/ SymBool(bool b) : data_(b) {}
 | 
			
		||||
@ -80,6 +82,10 @@ class C10_API SymBool {
 | 
			
		||||
    return toSymNodeImplUnowned()->constant_bool();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Convert SymBool to SymInt (0 or 1)
 | 
			
		||||
  // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
 | 
			
		||||
  SymInt toSymInt() const;
 | 
			
		||||
 | 
			
		||||
  bool is_heap_allocated() const {
 | 
			
		||||
    return ptr_;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,6 @@
 | 
			
		||||
#include <c10/core/SymNodeImpl.h>
 | 
			
		||||
#include <c10/util/intrusive_ptr.h>
 | 
			
		||||
#include <c10/util/safe_numerics.h>
 | 
			
		||||
#include <functional>
 | 
			
		||||
 | 
			
		||||
namespace c10 {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,6 @@
 | 
			
		||||
#include <c10/core/impl/TorchDispatchModeTLS.h>
 | 
			
		||||
#include <c10/util/Logging.h>
 | 
			
		||||
#include <c10/util/accumulate.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <optional>
 | 
			
		||||
 | 
			
		||||
#include <utility>
 | 
			
		||||
 | 
			
		||||
@ -1,9 +1,5 @@
 | 
			
		||||
#include <c10/core/TensorOptions.h>
 | 
			
		||||
 | 
			
		||||
#include <c10/core/Device.h>
 | 
			
		||||
#include <c10/core/Layout.h>
 | 
			
		||||
#include <c10/util/Optional.h>
 | 
			
		||||
 | 
			
		||||
#include <iostream>
 | 
			
		||||
 | 
			
		||||
namespace c10 {
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,6 @@
 | 
			
		||||
 | 
			
		||||
#include <c10/core/Allocator.h>
 | 
			
		||||
#include <c10/core/StorageImpl.h>
 | 
			
		||||
#include <c10/core/alignment.h>
 | 
			
		||||
#include <c10/core/impl/COWDeleter.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/ParallelGuard.h>
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,4 @@
 | 
			
		||||
#include <c10/core/DispatchKey.h>
 | 
			
		||||
#include <c10/core/SafePyObject.h>
 | 
			
		||||
#include <c10/core/impl/LocalDispatchKeySet.h>
 | 
			
		||||
#include <c10/core/impl/TorchDispatchModeTLS.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
 | 
			
		||||
@ -1260,6 +1260,9 @@ class DeviceCachingAllocator {
 | 
			
		||||
  // thread local compile context for each device
 | 
			
		||||
  static thread_local std::stack<std::string> compile_context;
 | 
			
		||||
 | 
			
		||||
  // thread local user metadata for annotating allocations
 | 
			
		||||
  static thread_local std::string user_metadata;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 | 
			
		||||
  explicit DeviceCachingAllocator(c10::DeviceIndex id)
 | 
			
		||||
@ -1302,6 +1305,14 @@ class DeviceCachingAllocator {
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void setUserMetadata(const std::string& metadata) {
 | 
			
		||||
    user_metadata = metadata;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::string getUserMetadata() {
 | 
			
		||||
    return user_metadata;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool checkPoolLiveAllocations(
 | 
			
		||||
      MempoolId_t mempool_id,
 | 
			
		||||
      const std::unordered_set<void*>& expected_live_allocations) const {
 | 
			
		||||
@ -3682,7 +3693,8 @@ class DeviceCachingAllocator {
 | 
			
		||||
        mempool_id,
 | 
			
		||||
        getApproximateTime(),
 | 
			
		||||
        record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr,
 | 
			
		||||
        compile_string);
 | 
			
		||||
        compile_string,
 | 
			
		||||
        user_metadata);
 | 
			
		||||
 | 
			
		||||
    // Callbacks should not include any Pytorch call
 | 
			
		||||
    for (const auto& cb : trace_trackers_) {
 | 
			
		||||
@ -3737,6 +3749,7 @@ static void uncached_delete(void* ptr) {
 | 
			
		||||
 | 
			
		||||
static void local_raw_delete(void* ptr);
 | 
			
		||||
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
 | 
			
		||||
thread_local std::string DeviceCachingAllocator::user_metadata;
 | 
			
		||||
#ifdef __cpp_lib_hardware_interference_size
 | 
			
		||||
using std::hardware_destructive_interference_size;
 | 
			
		||||
#else
 | 
			
		||||
@ -3934,6 +3947,18 @@ class NativeCachingAllocator : public CUDAAllocator {
 | 
			
		||||
    device_allocator[device]->popCompileContext();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void setUserMetadata(const std::string& metadata) override {
 | 
			
		||||
    c10::DeviceIndex device = 0;
 | 
			
		||||
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 | 
			
		||||
    device_allocator[device]->setUserMetadata(metadata);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::string getUserMetadata() override {
 | 
			
		||||
    c10::DeviceIndex device = 0;
 | 
			
		||||
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 | 
			
		||||
    return device_allocator[device]->getUserMetadata();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool isHistoryEnabled() override {
 | 
			
		||||
    c10::DeviceIndex device = 0;
 | 
			
		||||
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 | 
			
		||||
 | 
			
		||||
@ -118,7 +118,8 @@ struct TraceEntry {
 | 
			
		||||
      MempoolId_t mempool,
 | 
			
		||||
      approx_time_t time,
 | 
			
		||||
      std::shared_ptr<GatheredContext> context = nullptr,
 | 
			
		||||
      std::string compile_context = "")
 | 
			
		||||
      std::string compile_context = "",
 | 
			
		||||
      std::string user_metadata = "")
 | 
			
		||||
      : action_(action),
 | 
			
		||||
        device_(device),
 | 
			
		||||
        addr_(addr),
 | 
			
		||||
@ -126,7 +127,8 @@ struct TraceEntry {
 | 
			
		||||
        stream_(stream),
 | 
			
		||||
        size_(size),
 | 
			
		||||
        mempool_(std::move(mempool)),
 | 
			
		||||
        compile_context_(std::move(compile_context)) {
 | 
			
		||||
        compile_context_(std::move(compile_context)),
 | 
			
		||||
        user_metadata_(std::move(user_metadata)) {
 | 
			
		||||
    time_.approx_t_ = time;
 | 
			
		||||
  }
 | 
			
		||||
  Action action_;
 | 
			
		||||
@ -138,6 +140,7 @@ struct TraceEntry {
 | 
			
		||||
  MempoolId_t mempool_;
 | 
			
		||||
  trace_time_ time_{};
 | 
			
		||||
  std::string compile_context_;
 | 
			
		||||
  std::string user_metadata_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Calls made by record_function will save annotations
 | 
			
		||||
@ -297,6 +300,10 @@ class CUDAAllocator : public DeviceAllocator {
 | 
			
		||||
      const std::vector<std::pair<std::string, std::string>>& /*md*/) {}
 | 
			
		||||
  virtual void pushCompileContext(std::string& md) {}
 | 
			
		||||
  virtual void popCompileContext() {}
 | 
			
		||||
  virtual void setUserMetadata(const std::string& metadata) {}
 | 
			
		||||
  virtual std::string getUserMetadata() {
 | 
			
		||||
    return "";
 | 
			
		||||
  }
 | 
			
		||||
  virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
 | 
			
		||||
 | 
			
		||||
  // Attached AllocatorTraceTracker callbacks will be called while the
 | 
			
		||||
@ -536,6 +543,14 @@ inline void enablePeerAccess(
 | 
			
		||||
  get()->enablePeerAccess(dev, dev_to_access);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline void setUserMetadata(const std::string& metadata) {
 | 
			
		||||
  get()->setUserMetadata(metadata);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline std::string getUserMetadata() {
 | 
			
		||||
  return get()->getUserMetadata();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace c10::cuda::CUDACachingAllocator
 | 
			
		||||
 | 
			
		||||
namespace c10::cuda {
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,6 @@
 | 
			
		||||
#include <c10/cuda/CUDADeviceAssertionHost.h>
 | 
			
		||||
#include <c10/cuda/CUDAException.h>
 | 
			
		||||
#include <c10/cuda/CUDAFunctions.h>
 | 
			
		||||
#include <c10/util/Backtrace.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/env.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <cuda_runtime.h>
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,6 @@
 | 
			
		||||
#include <c10/cuda/CUDAGuard.h>
 | 
			
		||||
#include <c10/util/UniqueVoidPtr.h>
 | 
			
		||||
#include <c10/util/flat_hash_map.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
 | 
			
		||||
#include <unordered_set>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
#include <c10/cuda/CUDAMiscFunctions.h>
 | 
			
		||||
#include <c10/util/env.h>
 | 
			
		||||
#include <cuda_runtime.h>
 | 
			
		||||
#include <cstring>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
namespace c10::cuda {
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 | 
			
		||||
#include <c10/cuda/CUDAException.h>
 | 
			
		||||
#include <c10/cuda/driver_api.h>
 | 
			
		||||
#include <c10/util/CallOnce.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/Logging.h>
 | 
			
		||||
#include <cuda_runtime.h>
 | 
			
		||||
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user