Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)

Compare commits

18 Commits

codegen_tr ... gh/karthic
| SHA1 |
|---|
| a08a36fd61 |
| 0512b080e6 |
| 45b4521b83 |
| a5b1ef2d1d |
| 702e868c31 |
| 36f4550411 |
| 1c02006361 |
| b33759a6c5 |
| d3e3e504cf |
| ca1160f112 |
| dd749f54c9 |
| 2bf5728496 |
| 452575e225 |
| 970c40c3c0 |
| fd4b78b142 |
| e6772939b0 |
| a2d2747597 |
| 8704f71a52 |
@@ -113,7 +113,6 @@ case "$tag" in
     UCX_COMMIT=${_UCX_COMMIT}
     UCC_COMMIT=${_UCC_COMMIT}
     TRITON=yes
-    INSTALL_MINGW=yes
     ;;
   pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
     CUDA_VERSION=13.0.0
@@ -362,7 +361,6 @@ docker build \
        --build-arg "OPENBLAS=${OPENBLAS:-}" \
        --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
        --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
-       --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \
        -f $(dirname ${DOCKERFILE})/Dockerfile \
        -t "$tmp_tag" \
        "$@" \

@@ -83,6 +83,10 @@ function build_cpython {
         py_suffix=${py_ver::-1}
         py_folder=$py_suffix
     fi
+    # Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4
+    if [ "$py_suffix" == "3.14.0" ]; then
+        py_suffix="3.14.0rc2"
+    fi
     wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz
     do_cpython_build $py_ver Python-$py_suffix

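The four `+` lines pin the 3.14.0 source tarball to rc2. As an illustrative aside (not part of the diff), the same suffix handling in Python, assuming `py_ver` may carry a trailing `t` for free-threaded builds:

```python
def cpython_tarball_suffix(py_ver: str) -> str:
    # Strip a trailing "t" (free-threaded marker), mirroring ${py_ver::-1} above.
    py_suffix = py_ver[:-1] if py_ver.endswith("t") else py_ver
    # Remap 3.14.0 to the rc2 tarball, as the added block does.
    if py_suffix == "3.14.0":
        py_suffix = "3.14.0rc2"
    return py_suffix

assert cpython_tarball_suffix("3.14.0t") == "3.14.0rc2"
assert cpython_tarball_suffix("3.13.0") == "3.13.0"
```
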
@@ -1,10 +0,0 @@
-#!/bin/bash
-
-set -ex
-
-# Install MinGW-w64 for Windows cross-compilation
-apt-get update
-apt-get install -y g++-mingw-w64-x86-64-posix
-
-echo "MinGW-w64 installed successfully"
-x86_64-w64-mingw32-g++ --version

@@ -20,7 +20,7 @@ pip_install \

 pip_install coloredlogs packaging
 pip_install onnxruntime==1.23.0
-pip_install onnxscript==0.5.4
+pip_install onnxscript==0.5.3

 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/

@@ -39,13 +39,9 @@ case ${DOCKER_TAG_PREFIX} in
         DOCKER_GPU_BUILD_ARG=""
         ;;
     rocm*)
-        # we want the patch version of 7.0 instead
-        if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
-        fi
         # we want the patch version of 6.4 instead
         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
+            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
         fi
         BASE_TARGET=rocm
         GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete

@@ -75,13 +75,9 @@ case ${image} in
         DOCKERFILE_SUFFIX="_cuda_aarch64"
         ;;
     manylinux2_28-builder:rocm*)
-        # we want the patch version of 7.0 instead
-        if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
-        fi
         # we want the patch version of 6.4 instead
         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
+            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
         fi
         TARGET=rocm_final
         MANY_LINUX_VERSION="2_28"

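Both hunks apply the same major.minor-to-patch pinning of `GPU_ARCH_VERSION` by appending a suffix. A rough table-driven Python model of the two sides (helper name and structure are mine, not from the scripts):

```python
# "-" side pins 7.0 -> 7.0.2 and 6.4 -> 6.4.4; "+" side keeps only 6.4 -> 6.4.2.
def pin_rocm_patch(gpu_arch_version: str, pins: dict) -> str:
    for minor, suffix in pins.items():
        if minor in gpu_arch_version:
            return gpu_arch_version + suffix  # e.g. "6.4" + ".2" -> "6.4.2"
    return gpu_arch_version

old_pins = {"7.0": ".2", "6.4": ".4"}
new_pins = {"6.4": ".2"}
assert pin_rocm_patch("6.4", new_pins) == "6.4.2"
assert pin_rocm_patch("7.0", old_pins) == "7.0.2"
```
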
@@ -103,11 +103,6 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt
 RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
 RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt

-ARG INSTALL_MINGW
-COPY ./common/install_mingw.sh install_mingw.sh
-RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi
-RUN rm install_mingw.sh
-
 ARG TRITON
 ARG TRITON_CPU

@@ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules
         logger.info("Successfully cloned %s", target)
         return r, commit

-    except GitCommandError:
-        logger.exception("Git operation failed")
+    except GitCommandError as e:
+        logger.error("Git operation failed: %s", e)
         raise

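The two variants report the same failure differently: `logger.exception(...)` logs at ERROR level and appends the full traceback, while `logger.error("...: %s", e)` logs only the exception's string form. A self-contained comparison:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise ValueError("boom")
except ValueError as e:
    logger.error("Git operation failed: %s", e)  # message only
    logger.exception("Git operation failed")     # message plus traceback
```
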
@@ -6,7 +6,7 @@ dependencies = [
     "GitPython==3.1.45",
     "docker==7.1.0",
     "pytest==7.3.2",
-    "uv==0.9.5"
+    "uv==0.8.6"
 ]

 [tool.setuptools]

@@ -187,22 +187,19 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
             export USE_CUFILE=0
         else
             DEPS_LIST+=(
+                "/usr/local/cuda/lib64/libnvToolsExt.so.1"
                 "/usr/local/cuda/lib64/libcublas.so.12"
                 "/usr/local/cuda/lib64/libcublasLt.so.12"
                 "/usr/local/cuda/lib64/libcudart.so.12"
                 "/usr/local/cuda/lib64/libnvrtc.so.12"
                 "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
             DEPS_SONAME+=(
+                "libnvToolsExt.so.1"
                 "libcublas.so.12"
                 "libcublasLt.so.12"
                 "libcudart.so.12"
                 "libnvrtc.so.12"
                 "libcupti.so.12")
-
-            if [[ $CUDA_VERSION != 12.9* ]]; then
-                DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
-                DEPS_SONAME+=("libnvToolsExt.so.1")
-            fi
         fi
     else
         echo "Using nvidia libs from pypi."

@@ -485,22 +485,6 @@ test_inductor_aoti() {
   /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
 }

-test_inductor_aoti_cross_compile_for_windows() {
-
-  TEST_REPORTS_DIR=$(pwd)/test/test-reports
-  mkdir -p "$TEST_REPORTS_DIR"
-
-  # Set WINDOWS_CUDA_HOME environment variable
-  WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted"
-  export WINDOWS_CUDA_HOME
-
-  echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"
-  echo "Contents:"
-  ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true
-
-  python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
-}
-
 test_inductor_cpp_wrapper_shard() {
   if [[ -z "$NUM_TEST_SHARDS" ]]; then
     echo "NUM_TEST_SHARDS must be defined to run a Python test shard"

@@ -916,7 +900,7 @@ test_inductor_set_cpu_affinity(){
   export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
   export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"

-  if [[ "$(uname -m)" != "aarch64" ]]; then
+  if [[ "${TEST_CONFIG}" != *aarch64* ]]; then
     # Use Intel OpenMP for x86
     IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so"
     export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD"
@@ -930,7 +914,7 @@ test_inductor_set_cpu_affinity(){
   cores=$((cpus / thread_per_core))

   # Set number of cores to 16 on aarch64 for performance runs
-  if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then
+  if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
     cores=16
   fi
   export OMP_NUM_THREADS=$cores

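Both hunks above swap hardware detection for job-configuration detection: `uname -m` reports what the runner actually is, while `TEST_CONFIG` says what the job was scheduled as. An illustrative Python contrast (the env var name comes from the script; everything else is a sketch):

```python
import os
import platform

# "-" side: keyed off the hardware the job happens to land on.
is_aarch64_by_hw = platform.machine() == "aarch64"

# "+" side: keyed off the configured test flavor, so an aarch64 host
# running a non-aarch64 config is not treated as aarch64.
is_aarch64_by_config = "aarch64" in os.environ.get("TEST_CONFIG", "")

print(is_aarch64_by_hw, is_aarch64_by_config)
```
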
@@ -1631,7 +1615,6 @@ test_operator_benchmark() {
   TEST_REPORTS_DIR=$(pwd)/test/test-reports
   mkdir -p "$TEST_REPORTS_DIR"
   TEST_DIR=$(pwd)
-  ARCH=$(uname -m)

   test_inductor_set_cpu_affinity

@@ -1646,7 +1629,7 @@ test_operator_benchmark() {
   pip_install pandas
   python check_perf_csv.py \
       --actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \
-      --expected "${ARCH}_expected_ci_operator_benchmark_eager_float32_cpu.csv"
+      --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
 }

 test_operator_microbenchmark() {

@@ -1683,7 +1666,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then
     python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0
   fi
   python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py
-elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then
+elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then
   test_linux_aarch64
 elif [[ "${TEST_CONFIG}" == *backward* ]]; then
   test_forward_backward_compatibility
@@ -1734,8 +1717,6 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
   test_inductor_triton_cpu
 elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
   test_inductor_micro_benchmark
-elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
-  test_inductor_aoti_cross_compile_for_windows
 elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
   install_torchvision
   id=$((SHARD_NUMBER-1))

@@ -163,13 +163,8 @@ if [[ "$(uname)" != Darwin ]]; then
   MEMORY_LIMIT_MAX_JOBS=12
   NUM_CPUS=$(( $(nproc) - 2 ))

-  if [[ "$(uname)" == Linux ]]; then
-    # Defaults here for **binary** linux builds so they can be changed in one place
-    export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
-  else
-    # For other builds
-    export MAX_JOBS=${NUM_CPUS}
-  fi
+  # Defaults here for **binary** linux builds so they can be changed in one place
+  export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}

   cat >>"$envfile" <<EOL
   export MAX_JOBS="${MAX_JOBS}"

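The surviving `+` lines apply the memory-based clamp unconditionally instead of only on Linux. The bash arithmetic ternary `$(( a > b ? b : a ))` is just a min(); the same default in Python, for clarity (a sketch, not project code):

```python
import os

MEMORY_LIMIT_MAX_JOBS = 12
num_cpus = (os.cpu_count() or 3) - 2

# ${MAX_JOBS:-...} semantics: honor a pre-set MAX_JOBS, otherwise
# default to min(num_cpus, memory-based limit).
max_jobs = int(os.environ.get("MAX_JOBS", min(num_cpus, MEMORY_LIMIT_MAX_JOBS)))
print(max_jobs)
```
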
							
								
								
									
.flake8 (6 changes)

@@ -7,12 +7,16 @@ max-line-length = 120
 # C408 ignored because we like the dict keyword argument syntax
 # E501 is not flexible enough, we're using B950 instead
 ignore =
-    E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
+    E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
     # to line this up with executable bit
     EXE001,
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
     # these ignores are from flake8-comprehensions; please fix!
     C407,
     # these ignores are from flake8-logging-format; please fix!
     G100,G101,G200
     # these ignores are from flake8-simplify. please fix or ignore with commented reason
     SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
     # SIM104 is already covered by pyupgrade ruff

.github/ISSUE_TEMPLATE/ci-sev.md (1 change)

@@ -8,7 +8,6 @@ assignees: ''
 ---

 > NOTE: Remember to label this issue with "`ci: sev`"
->       If you want autorevert to be disabled, keep the ci: disable-autorevert label

  <!-- Add the `merge blocking` label to this PR to prevent PRs from being merged while this issue is open -->

.github/ISSUE_TEMPLATE/disable-autorevert.md (4 changes)

@@ -1,7 +1,7 @@
 ---
-name: "D❌\U0001F519 ISABLE AUTOREVERT"
+name: DISABLE AUTOREVERT
 about: Disables autorevert when open
-title: "[DISABLE AUTOREVERT]"
+title: "❌\U0001F519 [DISABLE AUTOREVERT]"
 labels: 'ci: disable-autorevert'
 assignees: ''

@@ -65,7 +65,7 @@ runs:
           cd .ci/lumen_cli
           python3 -m pip install -e .
         )
-        MAX_JOBS="$(nproc --ignore=10)"
+        MAX_JOBS="$(nproc --ignore=6)"
         export MAX_JOBS

         # Split the comma-separated list and build each target

.github/ci_commit_pins/audio.txt (2 changes)

@@ -1 +1 @@
-69bbe7363897764f9e758d851cd0340147d27f94
+8ad2aa5d354d1bf432339113860185d5a5d1abbd

.github/ci_commit_pins/vision.txt (2 changes)

@@ -1 +1 @@
-faffd5cf673615583da6517275e361cb3dbc77e6
+f5c6c2ec6490455e86f67b2a25c10390d60a27f7

.github/labeler.yml (29 changes)

@@ -133,32 +133,3 @@

 "ciflow/vllm":
 - .github/ci_commit_pins/vllm.txt
-
-"ciflow/b200":
-- test/test_matmul_cuda.py
-- test/test_scaled_matmul_cuda.py
-- test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/Blas.cpp
-- torch/**/*cublas*
-- torch/_inductor/kernel/mm.py
-- test/inductor/test_max_autotune.py
-- third_party/fbgemm
-
-"ciflow/h100":
-- test/test_matmul_cuda.py
-- test/test_scaled_matmul_cuda.py
-- test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/Blas.cpp
-- torch/**/*cublas*
-- torch/_inductor/kernel/mm.py
-- test/inductor/test_max_autotune.py
-- third_party/fbgemm
-
-"ciflow/rocm":
-- test/test_matmul_cuda.py
-- test/test_scaled_matmul_cuda.py
-- test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/Blas.cpp
-- torch/_inductor/kernel/mm.py
-- test/inductor/test_max_autotune.py
-- third_party/fbgemm

.github/pytorch-probot.yml (5 changes)

@@ -3,7 +3,6 @@ ciflow_tracking_issue: 64124
 ciflow_push_tags:
 - ciflow/b200
 - ciflow/b200-symm-mem
-- ciflow/b200-distributed
 - ciflow/binaries
 - ciflow/binaries_libtorch
 - ciflow/binaries_wheel
@@ -16,8 +15,7 @@ ciflow_push_tags:
 - ciflow/inductor-micro-benchmark
 - ciflow/inductor-micro-benchmark-cpu-x86
 - ciflow/inductor-perf-compare
-- ciflow/inductor-perf-test-nightly-rocm-mi300
-- ciflow/inductor-perf-test-nightly-rocm-mi355
+- ciflow/inductor-perf-test-nightly-rocm
 - ciflow/inductor-perf-test-nightly-x86-zen
 - ciflow/inductor-periodic
 - ciflow/inductor-rocm
@@ -33,7 +31,6 @@ ciflow_push_tags:
 - ciflow/rocm
 - ciflow/rocm-mi300
 - ciflow/rocm-mi355
-- ciflow/rocm-navi31
 - ciflow/s390
 - ciflow/slow
 - ciflow/torchbench

.github/scripts/generate_binary_build_matrix.py (42 changes)

@@ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
         "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'"
     ),
     "12.9": (
-        "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | "
-        "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | "
-        "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | "
-        "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | "
-        "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | "
-        "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | "
-        "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | "
-        "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | "
-        "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | "
-        "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | "
-        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | "
-        "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | "
-        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | "
-        "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | "
-        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'"
+        "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'"
     ),
     "13.0": (
         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
@@ -241,11 +241,7 @@ def generate_libtorch_matrix(
             arches += CUDA_ARCHES
             arches += ROCM_ARCHES
         elif os == "windows":
-            # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
-            # in 2.10
-            windows_cuda_arches = CUDA_ARCHES.copy()
-            windows_cuda_arches.remove("12.9")
-            arches += windows_cuda_arches
+            arches += CUDA_ARCHES
     if libtorch_variants is None:
         libtorch_variants = [
             "shared-with-deps",
@@ -309,11 +305,7 @@ def generate_wheels_matrix(
         if os == "linux":
             arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES
         elif os == "windows":
-            # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
-            # in 2.10
-            windows_cuda_arches = CUDA_ARCHES.copy()
-            windows_cuda_arches.remove("12.9")
-            arches += windows_cuda_arches + XPU_ARCHES
+            arches += CUDA_ARCHES + XPU_ARCHES
         elif os == "linux-aarch64":
             # Separate new if as the CPU type is different and
             # uses different build/test scripts

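The `+` side appends `and platform_machine == 'x86_64'` to every PEP 508 environment marker, so on an aarch64 host none of the `nvidia-*` extras would be selected. This can be checked with the `packaging` library (the environment dicts below are illustrative):

```python
from packaging.markers import Marker

m = Marker("platform_system == 'Linux' and platform_machine == 'x86_64'")

# Evaluate against explicit environments rather than the current interpreter.
print(m.evaluate({"platform_system": "Linux", "platform_machine": "x86_64"}))   # True
print(m.evaluate({"platform_system": "Linux", "platform_machine": "aarch64"}))  # False
```
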
							
								
								
									
.github/scripts/trymerge.py (6 changes)

@@ -1092,7 +1092,7 @@ class GitHubPR:
         editor = node["editor"]
         return GitHubComment(
             body_text=node["bodyText"],
-            created_at=node.get("createdAt", ""),
+            created_at=node["createdAt"] if "createdAt" in node else "",
             author_login=node["author"]["login"],
             author_url=node["author"].get("url", None),
             author_association=node["authorAssociation"],
@@ -2042,6 +2042,10 @@ def validate_revert(
             f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
         )

+    # Raises exception if matching rule is not found, but ignores all status checks
+    find_matching_merge_rule(
+        pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
+    )
     commit_sha = get_pr_commit_sha(repo, pr)
     return (author_login, commit_sha)

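For the `created_at` change: `dict.get` and the conditional-expression form are equivalent for a missing key; `get` just avoids the double lookup. A one-line check:

```python
node = {"bodyText": "LGTM"}  # hypothetical GraphQL node with no "createdAt"
assert node.get("createdAt", "") == (node["createdAt"] if "createdAt" in node else "")
```
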
@@ -26,8 +26,9 @@ name: !{{ build_environment }}
       - name: Setup Python
         uses: actions/setup-python@v6
         with:
+          # TODO: Removeme once 3.14 is out
           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
-          python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}"
+          python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}"
           freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }}
 {%- endmacro %}

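The template change regroups the conditional: the old expression always concatenated and only chose the suffix, while the new one chooses between a concatenation and the fixed `3.14.0-rc.2` literal. Evaluated as plain Python with sample `py_ver` values (a sketch of the Jinja expression, not repo code):

```python
def old_version(py_ver: str) -> str:
    return py_ver.strip("t") + (".4" if "3.14" not in py_ver else ".0")

def new_version(py_ver: str) -> str:
    return (py_ver.strip("t") + ".4") if "3.14" not in py_ver else "3.14.0-rc.2"

assert old_version("3.13t") == "3.13.4" and new_version("3.13t") == "3.13.4"
assert old_version("3.14") == "3.14.0"       # old: pinned to the .0 final
assert new_version("3.14") == "3.14.0-rc.2"  # new: pinned to rc2 instead
```
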
@@ -79,9 +79,9 @@ jobs:
     runs-on: "windows-11-arm64-preview"
     {%- else %}
     {%- if branches == "nightly" %}
-    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
+    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
     {%- else %}
-    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
+    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
     {%- endif %}
     {%- endif %}
     timeout-minutes: !{{ common.timeout_minutes_windows_binary }}

.github/workflows/_linux-build.yml (2 changes)

@@ -37,7 +37,7 @@ on:
       runner:
         required: false
         type: string
-        default: "linux.c7i.2xlarge"
+        default: "linux.2xlarge"
         description: |
           Label of the runner this job should run on.
       test-matrix:

.github/workflows/_linux-test.yml (40 changes)

@@ -224,46 +224,6 @@ jobs:
         continue-on-error: true
         uses: ./.github/actions/download-td-artifacts

-      - name: Download Windows torch wheel for cross-compilation
-        if: matrix.win_torch_wheel_artifact != ''
-        uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
-        with:
-          name: ${{ matrix.win_torch_wheel_artifact }}
-          path: win-torch-wheel
-
-      - name: Extract Windows wheel and setup CUDA libraries
-        if: matrix.win_torch_wheel_artifact != ''
-        shell: bash
-        run: |
-          set -x
-
-          # Find the wheel file
-          WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
-          if [ -z "$WHEEL_FILE" ]; then
-            echo "Error: No wheel file found in win-torch-wheel directory"
-            exit 1
-          fi
-          echo "Found wheel file: $WHEEL_FILE"
-
-          # Unzip the wheel file
-          unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
-          echo "Extracted wheel contents"
-
-          # Setup CUDA libraries (cuda.lib and cudart.lib) directory
-          mkdir -p win-torch-wheel-extracted/lib/x64
-          if [ -f "win-torch-wheel/cuda.lib" ]; then
-            mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/
-            echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/"
-          fi
-          if [ -f "win-torch-wheel/cudart.lib" ]; then
-            mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/
-            echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/"
-          fi
-
-          # Verify CUDA libraries are present
-          echo "CUDA libraries:"
-          ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found"
-
       - name: Parse ref
         id: parse-ref
         run: .github/scripts/parse_ref.py

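The removed step locates and unzips the Windows torch wheel; since a wheel is an ordinary zip archive, the equivalent extraction in Python would look roughly like this (paths taken from the removed step; the code is a sketch):

```python
import glob
import zipfile

# Find a wheel in the downloaded artifact directory.
wheels = glob.glob("win-torch-wheel/**/*.whl", recursive=True)
if not wheels:
    raise SystemExit("Error: No wheel file found in win-torch-wheel directory")

# Wheels are zip archives, so `unzip -q` is just an extractall().
with zipfile.ZipFile(wheels[0]) as whl:
    whl.extractall("win-torch-wheel-extracted")
```
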
							
								
								
									
.github/workflows/_win-build.yml (25 changes)

@@ -168,31 +168,6 @@ jobs:
         run: |
           .ci/pytorch/win-build.sh

-      # Collect Windows torch libs and CUDA libs for cross-compilation
-      - name: Collect Windows CUDA libs for cross-compilation
-        if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu'
-        shell: bash
-        run: |
-          set -ex
-
-          # Create directory structure if does not exist
-          mkdir -p /c/${{ github.run_id }}/build-results
-
-          # Copy CUDA libs
-          CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
-
-          if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
-            cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/
-          fi
-
-          if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
-            cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/
-          fi
-
-          # List collected files
-          echo "Collected CUDA libs:"
-          ls -lah /c/${{ github.run_id }}/build-results/*.lib
-
       # Upload to github so that people can click and download artifacts
       - name: Upload artifacts to s3
         if: steps.build.outcome != 'skipped'

.github/workflows/b200-distributed.yml (62 changes)

@@ -1,62 +0,0 @@
-name: CI for distributed tests on B200
-
-on:
-  pull_request:
-    paths:
-      - .github/workflows/b200-distributed.yml
-  workflow_dispatch:
-  push:
-    tags:
-      - ciflow/b200-distributed/*
-  schedule:
-    - cron: 46 8 * * *  # about 1:46am PDT
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
-  cancel-in-progress: true
-
-permissions:
-  id-token: write
-  contents: read
-
-jobs:
-
-  get-label-type:
-    if: github.repository_owner == 'pytorch'
-    name: get-label-type
-    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
-    with:
-      triggering_actor: ${{ github.triggering_actor }}
-      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
-      curr_branch: ${{ github.head_ref || github.ref_name }}
-      curr_ref_type: ${{ github.ref_type }}
-
-  linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
-    name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
-    uses: ./.github/workflows/_linux-build.yml
-    needs: get-label-type
-    with:
-      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: linux.12xlarge.memory
-      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
-      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
-      cuda-arch-list: '10.0'
-      test-matrix: |
-        { include: [
-          { config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
-          { config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
-        ]}
-    secrets: inherit
-
-  linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
-    name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
-    uses: ./.github/workflows/_linux-test.yml
-    needs:
-      - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
-    with:
-      timeout-minutes: 1200
-      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
-      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
-      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
-      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
-    secrets: inherit

.github/workflows/build-vllm-wheel.yml (19 changes)

@@ -27,8 +27,9 @@ jobs:
       fail-fast: false
       matrix:
         python-version: [ '3.12' ]
+        # TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
         platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
-        device: [ 'cu128', 'cu129', 'cu130' ]
+        device: [ 'cu128', 'cu129' ]
         include:
           - platform: manylinux_2_28_x86_64
             device: cu128
@@ -38,10 +39,6 @@ jobs:
             device: cu129
             manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
             runner: linux.12xlarge.memory
-          - platform: manylinux_2_28_x86_64
-            device: cu130
-            manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
-            runner: linux.12xlarge.memory
           - platform: manylinux_2_28_aarch64
             device: cu128
             manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
@@ -50,11 +47,6 @@ jobs:
             device: cu129
             manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
             runner: linux.arm64.r7g.12xlarge.memory
-        exclude:
-          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
-          # xformers is update to support 13.0
-          - platform: manylinux_2_28_aarch64
-            device: cu130
     name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
     runs-on: ${{ matrix.runner }}
     timeout-minutes: 480
@@ -177,12 +169,7 @@ jobs:
       fail-fast: false
       matrix:
         platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
-        device: [ 'cu128', 'cu129', 'cu130' ]
-        exclude:
-          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
-          # xformers is update to support 13.0
-          - platform: manylinux_2_28_aarch64
-            device: cu130
+        device: [ 'cu128', 'cu129' ]
     env:
       PLATFORM: ${{ matrix.platform }}
       BUILD_DEVICE: ${{ matrix.device }}

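For context on how the removed entries worked: GitHub Actions builds the job matrix as the cross product of the list values, then `exclude` prunes matching combinations (and `include` augments them). A rough Python model of the `-` side's final build matrix:

```python
from itertools import product

platforms = ["manylinux_2_28_x86_64", "manylinux_2_28_aarch64"]
devices = ["cu128", "cu129", "cu130"]
excludes = [{"platform": "manylinux_2_28_aarch64", "device": "cu130"}]

combos = [
    {"platform": p, "device": d}
    for p, d in product(platforms, devices)
    if {"platform": p, "device": d} not in excludes
]
assert len(combos) == 5  # cu130 was built for x86_64 only
```
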
							
								
								
									
.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml (14 changes, generated)

@@ -224,7 +224,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_10-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -473,7 +473,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_11-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -722,7 +722,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_12-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -971,7 +971,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_13-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -1220,7 +1220,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_13t-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -1469,7 +1469,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_14-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -1718,7 +1718,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_14t-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
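
Note: each hunk above tightens the PEP 508 environment marker from "platform_system == 'Linux'" to "platform_system == 'Linux' and platform_machine == 'x86_64'", so the x86_64-only CUDA wheels are no longer pulled in on other machines such as the aarch64 hosts these jobs target. A minimal sketch of how such markers evaluate, using the packaging library (the same marker machinery pip relies on; the environment dict is supplied by hand purely for illustration):

    from packaging.markers import Marker

    old = Marker("platform_system == 'Linux'")
    new = Marker("platform_system == 'Linux' and platform_machine == 'x86_64'")

    # Pretend we are on an aarch64 Linux host by overriding the marker environment.
    aarch64 = {"platform_system": "Linux", "platform_machine": "aarch64"}

    print(old.evaluate(aarch64))  # True  -> the cu12 wheel would be installed
    print(new.evaluate(aarch64))  # False -> the cu12 wheel is skipped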

14  .github/workflows/generated-linux-binary-manywheel-nightly.yml  (generated, vendored)
@@ -259,7 +259,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_10-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_10-cuda12_9-test:  # Testing
@@ -925,7 +925,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_11-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_11-cuda12_9-test:  # Testing
@@ -1591,7 +1591,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_12-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_12-cuda12_9-test:  # Testing
@@ -2257,7 +2257,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_13-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_13-cuda12_9-test:  # Testing
@@ -2923,7 +2923,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_13t-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_13t-cuda12_9-test:  # Testing
@@ -3589,7 +3589,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_14-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_14-cuda12_9-test:  # Testing
@@ -4255,7 +4255,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_14t-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_14t-cuda12_9-test:  # Testing
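
Note: PYTORCH_EXTRA_INSTALL_REQUIREMENTS packs many requirements into one string separated by " | "; the pipe appears to be an internal convention of the binary-build tooling rather than pip syntax (an assumption here, not something this diff states). A hypothetical sketch of the obvious expansion into a single pip invocation; pip evaluates the per-requirement markers itself:

    import shlex
    import sys

    # Shortened stand-in for the real value; same "req; marker | req; marker" shape.
    extra = (
        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64'"
        " | "
        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64'"
    )

    # Split on the pipe; each piece is an ordinary PEP 508 requirement string.
    requirements = [r.strip() for r in extra.split("|") if r.strip()]

    # Print the command for inspection instead of running it.
    print(shlex.join([sys.executable, "-m", "pip", "install", *requirements]))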

1  .github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml  (generated, vendored)
@@ -63,6 +63,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.10.4"
          freethreaded: false
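
Note: the freethreaded input added in these hunks selects between the regular (GIL) and free-threaded CPython builds that actions/setup-python can install. A hedged sketch for telling the two apart at runtime, using only the standard library (the sys check exists on CPython 3.13+):

    import sys
    import sysconfig

    # Build-time flag: truthy only on free-threaded ("t") interpreter builds.
    freethreaded_build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    print("free-threaded build:", freethreaded_build)

    # Runtime state: even a free-threaded build may re-enable the GIL,
    # e.g. when an incompatible extension module is imported.
    if hasattr(sys, "_is_gil_enabled"):
        print("GIL currently enabled:", sys._is_gil_enabled())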

11  .github/workflows/generated-macos-arm64-binary-wheel-nightly.yml  (generated, vendored)
@@ -59,6 +59,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.10.4"
          freethreaded: false
@@ -168,6 +169,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.11.4"
          freethreaded: false
@@ -277,6 +279,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.12.4"
          freethreaded: false
@@ -386,6 +389,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.13.4"
          freethreaded: false
@@ -495,6 +499,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.13.4"
          freethreaded: true
@@ -604,8 +609,9 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.14.0"
          python-version: "3.14.0-rc.2"
          freethreaded: false
      - name: Checkout PyTorch
        uses: actions/checkout@v4
@@ -713,8 +719,9 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.14.0"
          python-version: "3.14.0-rc.2"
          freethreaded: true
      - name: Checkout PyTorch
        uses: actions/checkout@v4
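
Note: the last two hunks pin the 3.14 jobs to 3.14.0-rc.2 instead of 3.14.0. Under PEP 440 a release candidate orders strictly before the final release, which is easy to confirm with the packaging library (the "-rc.2" spelling normalizes to "rc2"):

    from packaging.version import Version

    rc = Version("3.14.0-rc.2")   # normalized form: "3.14.0rc2"
    final = Version("3.14.0")

    print(str(rc))            # 3.14.0rc2
    print(rc.is_prerelease)   # True
    print(rc < final)         # True: the rc sorts before the final release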

258  .github/workflows/generated-windows-binary-libtorch-debug-nightly.yml  (generated, vendored)
@@ -44,7 +44,7 @@ jobs:
  libtorch-cpu-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -291,7 +291,7 @@ jobs:
  libtorch-cuda12_6-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -541,7 +541,7 @@ jobs:
  libtorch-cuda12_8-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -788,10 +788,260 @@ jobs:
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml
  libtorch-cuda12_9-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: debug
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Build PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
      - uses: actions/upload-artifact@v4.4.0
        if: always()
        with:
          name: libtorch-cuda12_9-shared-with-deps-debug
          retention-days: 14
          if-no-files-found: error
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1

  libtorch-cuda12_9-shared-with-deps-debug-test:  # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - libtorch-cuda12_9-shared-with-deps-debug-build
      - get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: debug
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - uses: actions/download-artifact@v4.1.7
        name: Download Build Artifacts
        with:
          name: libtorch-cuda12_9-shared-with-deps-debug
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Test PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1
  libtorch-cuda12_9-shared-with-deps-debug-upload:  # Uploading
    if: ${{ github.repository_owner == 'pytorch' }}
    permissions:
      id-token: write
      contents: read
    needs: libtorch-cuda12_9-shared-with-deps-debug-test
    with:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      LIBTORCH_CONFIG: debug
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
      build_name: libtorch-cuda12_9-shared-with-deps-debug
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml
  libtorch-cuda13_0-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
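
Note: the "Display EC2 information" step above uses the IMDSv2 flow: first PUT a request for a short-lived session token, then read metadata categories with that token. The same two-step flow as a standalone Python sketch, standard library only (it is only meaningful when run on an EC2 host, where the 169.254.169.254 endpoint exists):

    import urllib.request

    IMDS = "http://169.254.169.254/latest"

    def get_ec2_metadata(category: str) -> str:
        # Step 1: obtain a short-lived IMDSv2 session token.
        tok_req = urllib.request.Request(
            f"{IMDS}/api/token",
            method="PUT",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "30"},
        )
        token = urllib.request.urlopen(tok_req, timeout=2).read().decode()
        # Step 2: read the metadata category using that token.
        meta_req = urllib.request.Request(
            f"{IMDS}/meta-data/{category}",
            headers={"X-aws-ec2-metadata-token": token},
        )
        return urllib.request.urlopen(meta_req, timeout=2).read().decode()

    if __name__ == "__main__":
        for category in ("ami-id", "instance-id", "instance-type"):
            print(f"{category}: {get_ec2_metadata(category)}")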

258  .github/workflows/generated-windows-binary-libtorch-release-nightly.yml  (generated, vendored)
@@ -44,7 +44,7 @@ jobs:
  libtorch-cpu-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -291,7 +291,7 @@ jobs:
  libtorch-cuda12_6-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -541,7 +541,7 @@ jobs:
  libtorch-cuda12_8-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -788,10 +788,260 @@ jobs:
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml
  libtorch-cuda12_9-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Build PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
      - uses: actions/upload-artifact@v4.4.0
        if: always()
        with:
          name: libtorch-cuda12_9-shared-with-deps-release
          retention-days: 14
          if-no-files-found: error
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1

  libtorch-cuda12_9-shared-with-deps-release-test:  # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - libtorch-cuda12_9-shared-with-deps-release-build
      - get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
 | 
			
		||||
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
 | 
			
		||||
      # removed once Windows Defender is removed from the AMI
 | 
			
		||||
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
 | 
			
		||||
          # Let's both exclude the path and disable Windows Defender completely just to be sure
 | 
			
		||||
          # that it doesn't interfere
 | 
			
		||||
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: actions/checkout@v4
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          show-progress: false
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      # NOTE: These environment variables are put here so that they can be applied on every job equally
 | 
			
		||||
      #       They are also here because setting them at a workflow level doesn't give us access to the
 | 
			
		||||
      #       runner.temp variable, which we need.
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
 | 
			
		||||
      - uses: actions/download-artifact@v4.1.7
 | 
			
		||||
        name: Download Build Artifacts
 | 
			
		||||
        with:
 | 
			
		||||
          name: libtorch-cuda12_9-shared-with-deps-release
 | 
			
		||||
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
 | 
			
		||||
      - name: Test PyTorch binary
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
 | 
			
		||||
      - name: Wait until all sessions have drained
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        timeout-minutes: 120
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\wait_for_ssh_to_drain.ps1
 | 
			
		||||
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\kill_active_ssh_sessions.ps1
 | 
			
		||||
  libtorch-cuda12_9-shared-with-deps-release-upload:  # Uploading
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    needs: libtorch-cuda12_9-shared-with-deps-release-test
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      PACKAGE_TYPE: libtorch
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu129
 | 
			
		||||
      GPU_ARCH_VERSION: "12.9"
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      LIBTORCH_CONFIG: release
 | 
			
		||||
      LIBTORCH_VARIANT: shared-with-deps
 | 
			
		||||
      # This is a dummy value for libtorch to work correctly with our batch scripts
 | 
			
		||||
      # without this value pip does not get installed for some reason
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      build_name: libtorch-cuda12_9-shared-with-deps-release
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
  libtorch-cuda13_0-shared-with-deps-release-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
 | 
			
		||||
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
 | 
			
		||||
    timeout-minutes: 360
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
 | 
			
		||||
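
Note: the "Display EC2 information" step above uses AWS's IMDSv2 flow: a PUT request mints a short-lived session token, which then authorizes reads from the instance metadata service. A minimal standalone sketch of the same pattern (the endpoint and header names are AWS's documented ones; the 30-second TTL is taken from the step above):

    #!/bin/bash
    # Sketch of the IMDSv2 flow from the "Display EC2 information" step.
    set -euo pipefail

    # Step 1: mint a short-lived session token (30-second TTL, as in the workflow).
    token="$(curl -s -X PUT "http://169.254.169.254/latest/api/token" \
      -H "X-aws-ec2-metadata-token-ttl-seconds: 30")"

    # Step 2: present the token when reading any metadata category.
    get_ec2_metadata() {
      curl -fsSL -H "X-aws-ec2-metadata-token: ${token}" \
        "http://169.254.169.254/latest/meta-data/$1"
    }

    echo "instance-type: $(get_ec2_metadata instance-type)"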

.github/workflows/generated-windows-binary-wheel-nightly.yml (generated, vendored): 1736 lines changed; file diff suppressed because it is too large

@ -1,132 +0,0 @@
name: inductor-perf-nightly-rocm-mi300

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm-mi300/*
  schedule:
    - cron: 15 0 * * *
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let's try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
    inputs:
      training:
        description: Run training (on by default)?
        required: false
        type: boolean
        default: true
      inference:
        description: Run inference (on by default)?
        required: false
        type: boolean
        default: true
      default:
        description: Run inductor_default?
        required: false
        type: boolean
        default: false
      dynamic:
        description: Run inductor_dynamic_shapes?
        required: false
        type: boolean
        default: false
      cppwrapper:
        description: Run inductor_cpp_wrapper?
        required: false
        type: boolean
        default: false
      cudagraphs:
        description: Run inductor_cudagraphs?
        required: false
        type: boolean
        default: true
      freezing_cudagraphs:
        description: Run inductor_cudagraphs with freezing for inference?
        required: false
        type: boolean
        default: false
      aotinductor:
        description: Run aot_inductor for inference?
        required: false
        type: boolean
        default: false
      maxautotune:
        description: Run inductor_max_autotune?
        required: false
        type: boolean
        default: false
      benchmark_configs:
        description: The list of configs used by the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm_mi300,inductor_timm_perf_rocm_mi300,inductor_torchbench_perf_rocm_mi300

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
      opt_out_experiments: lf

  linux-jammy-rocm-py3_10-inductor-benchmark-build:
    if: github.repository_owner == 'pytorch'
    name: rocm-py3_10-inductor-benchmark-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-rocm-py3_10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-inductor-benchmark-test:
    permissions:
      id-token: write
      contents: read
    name: rocm-py3_10-inductor-benchmark-test
    uses: ./.github/workflows/_rocm-test.yml
    needs: linux-jammy-rocm-py3_10-inductor-benchmark-build
    with:
      build-environment: linux-jammy-rocm-py3_10
      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }}
      timeout-minutes: 720
      # Disable monitor in perf tests for more investigation
      disable-monitor: true
      monitor-log-interval: 10
      monitor-data-collect-interval: 2
    secrets: inherit
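
Note: the hard-coded dashboard-tag above encodes each benchmark toggle as name-value pairs joined by hyphens. As an illustrative sketch only (not the repository's actual generator; the shell variables stand in for ${{ inputs.training }} and friends), the same string shape can be assembled like this:

    #!/bin/bash
    # Hypothetical helper: join toggle names and boolean values into a dashboard tag.
    set -euo pipefail

    training=true      # stand-in for ${{ inputs.training }}
    inference=true     # stand-in for ${{ inputs.inference }}
    cudagraphs=true    # stand-in for ${{ inputs.cudagraphs }}

    tag=""
    for pair in "training-${training}" "inference-${inference}" "cudagraphs-${cudagraphs}"; do
      tag="${tag:+${tag}-}${pair}"   # hyphen-join, skipping the leading separator
    done
    echo "dashboard-tag: ${tag}"     # dashboard-tag: training-true-inference-true-cudagraphs-true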
@ -1,11 +1,11 @@
name: inductor-perf-nightly-rocm-mi355
name: inductor-perf-nightly-rocm

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm-mi355/*
      - ciflow/inductor-perf-test-nightly-rocm/*
  schedule:
    - cron: 15 0 * * *
    - cron: 0 7 * * 0,3
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let's try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
@ -59,7 +59,7 @@ on:
        description: The list of configs used by the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm_mi355,inductor_timm_perf_rocm_mi355,inductor_torchbench_perf_rocm_mi355
        default: inductor_huggingface_perf_rocm,inductor_timm_perf_rocm,inductor_torchbench_perf_rocm

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
@ -88,27 +88,23 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

.github/workflows/lint.yml (vendored): 4 lines changed

@ -118,9 +118,9 @@ jobs:
        CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
        echo "Running all other linters"
        if [ "$CHANGED_FILES" = '*' ]; then
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
        else
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
        fi

  quick-checks:

.github/workflows/operator_benchmark.yml (vendored): 49 lines changed

@ -7,11 +7,9 @@ on:
  workflow_dispatch:
    inputs:
      test_mode:
        type: choice
        options:
          - 'short'
          - 'long'
          - 'all'
        required: false
        type: string
        default: 'short'
        description: tag filter for operator benchmarks, options from long, short, all
  schedule:
    # Run at 07:00 UTC every Sunday
@ -30,49 +28,38 @@ permissions:
  contents: read

jobs:
  x86-opbenchmark-build:
  opbenchmark-build:
    if: github.repository_owner == 'pytorch'
    name: x86-opbenchmark-build
    name: opbenchmark-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  x86-opbenchmark-test:
    name: x86-opbenchmark-test
    uses: ./.github/workflows/_linux-test.yml
    needs: x86-opbenchmark-build
    with:
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
    secrets: inherit

  aarch64-opbenchmark-build:
    if: github.repository_owner == 'pytorch'
    name: aarch64-opbenchmark-build
  opbenchmark-on-demand-build:
    if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
    name: opbenchmark-on-demand-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-aarch64-py3.10
      runner: linux.arm64.m7g.4xlarge
      docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
          { config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  aarch64-opbenchmark-test:
    name: aarch64-opbenchmark-test
  opbenchmark-test:
    name: opbenchmark-test
    uses: ./.github/workflows/_linux-test.yml
    needs: aarch64-opbenchmark-build
    needs: opbenchmark-build
    with:
      build-environment: linux-jammy-aarch64-py3.10
      docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image: ${{ needs.opbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.opbenchmark-build.outputs.test-matrix }}
    secrets: inherit
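
Note: one side of the hunk above parameterizes the config as cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}; the || operator in the GitHub expression makes scheduled runs, where dispatch inputs are empty, fall back to the short tag filter. The bash analogue of that fallback is parameter expansion with a default:

    #!/bin/bash
    # Bash analogue of ${{ inputs.test_mode || 'short' }}:
    # an unset or empty TEST_MODE falls back to "short".
    test_mode="${TEST_MODE:-short}"
    echo "config: cpu_operator_benchmark_${test_mode}"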

.github/workflows/periodic.yml (vendored): 15 lines changed

@ -147,16 +147,15 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
      cuda-arch-list: 8.9
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
        ]}
    secrets: inherit

.github/workflows/rocm-mi355.yml (vendored): 12 lines changed

@ -45,12 +45,12 @@ jobs:
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
        ]}
    secrets: inherit

.github/workflows/rocm-navi31.yml (vendored): 63 lines changed

@ -1,63 +0,0 @@
name: rocm-navi31

on:
  push:
    tags:
      - ciflow/rocm-navi31/*
  workflow_dispatch:
  schedule:
    # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
    # Also run less frequently on weekends.
    - cron: 45 */2 * * 1-5
    - cron: 45 4,12 * * 0,6

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  target-determination:
    if: github.repository_owner == 'pytorch'
    name: before-test
    uses: ./.github/workflows/target_determination.yml
    permissions:
      id-token: write
      contents: read

  linux-jammy-rocm-py3_10-build:
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
          { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-test:
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3_10
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
      tests-to-include: >-
         ${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
         test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
         inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
         inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
         inductor/test_flex_attention inductor/test_max_autotune' || '' }}
    secrets: inherit
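
Note: the tests-to-include value above uses the expression idiom ${{ cond && 'list' || '' }} to pass an explicit test list only on scheduled runs and an empty string otherwise. A bash analogue of that select-or-empty idiom, with event_name standing in for ${{ github.event_name }}:

    #!/bin/bash
    # Select-or-empty: emit a non-empty test list only for scheduled runs.
    event_name="schedule"  # stand-in for ${{ github.event_name }}
    tests=""
    if [ "${event_name}" = "schedule" ]; then
      tests="test_nn test_torch test_cuda"
    fi
    echo "tests-to-include: ${tests}"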

.github/workflows/rocm.yml (vendored): 26 lines changed

@ -59,3 +59,29 @@ jobs:
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
    secrets: inherit

  linux-jammy-rocm-py3_10-gfx1100-test:
    if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3_10-gfx1100
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
          { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
        ]}
      tests-to-include: >
         test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
         test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
         inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
         inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
         inductor/test_flex_attention inductor/test_max_autotune
    secrets: inherit

.github/workflows/trunk-tagging.yml (vendored): 149 lines changed

@ -58,10 +58,8 @@ jobs:
          else
            COMMIT_SHA="${{ github.sha }}"
          fi
          {
            echo "sha=${COMMIT_SHA}"
            echo "tag_name=trunk/${COMMIT_SHA}"
          } >> "${GITHUB_OUTPUT}"
          echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
          echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"

      - name: Validate commit SHA
        run: |
@ -89,7 +87,7 @@ jobs:
            echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
          fi

      - name: Create and push tag(s) with retry
      - name: Create and push tag with retry
        id: check_tag
        env:
          TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@ -114,23 +112,14 @@ jobs:
            return 1
          }

          # Counters for summary reporting
          created_count=0
          skipped_count=0
          failed_count=0
          # Exit early if tag already exists
          if check_tag_exists; then
            echo "✅ Tag already exists - no action needed"
            echo "exists=true" >> "${GITHUB_OUTPUT}"
            exit 0
          fi

          # Always write outputs once on exit
          finish() {
            set +e
            if [ -n "${GITHUB_OUTPUT:-}" ]; then
              {
                echo "created_count=${created_count}"
                echo "skipped_count=${skipped_count}"
                echo "failed_count=${failed_count}"
              } >> "${GITHUB_OUTPUT}"
            fi
          }
          trap finish EXIT
          echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

          # Retry configuration
          MAX_RETRIES=5
@ -205,111 +194,31 @@ jobs:
            }
          }

          # New behavior for push events: enumerate commits in the push and tag each one.
          # For workflow_dispatch, retain existing single-SHA behavior.

          # Always fetch tags once up front to improve idempotency in loops
          git fetch origin --tags --quiet || true

          if [ "${{ github.event_name }}" = "push" ]; then
            BEFORE_SHA="${{ github.event.before }}"
            AFTER_SHA="${{ github.sha }}"  # same as event.after

            # List commits introduced by this push (old..new), oldest first for stable ordering
            commits_file="$(mktemp)"
            git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"

            if [ ! -s "${commits_file}" ]; then
              echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
              rm -f "${commits_file}"
              exit 0
            fi

            commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
            echo "Found ${commit_count} commit(s) to tag for push:"
            while IFS= read -r sha; do
              printf '  %s\n' "${sha}"
            done < "${commits_file}"

            while IFS= read -r sha; do
              TAG_NAME="trunk/${sha}"
              COMMIT_SHA="${sha}"

              # If tag already exists locally or remotely, skip (idempotent)
              if check_tag_exists; then
                echo "✅ Tag ${TAG_NAME} already exists - skipping"
                skipped_count=$((skipped_count + 1))
                continue
              fi

              echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

              if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
                created_count=$((created_count + 1))
              else
                echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
                failed_count=$((failed_count + 1))
              fi
            done < "${commits_file}"

            rm -f "${commits_file}"

            if [ "${failed_count}" -gt 0 ]; then
              exit 1
            fi
          # Execute with retry
          if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
            echo "exists=false" >> "${GITHUB_OUTPUT}"
            exit 0
          else
            # workflow_dispatch path (single SHA tagging preserved)

            # Exit early if tag already exists
            if check_tag_exists; then
              echo "✅ Tag already exists - no action needed"
              skipped_count=1
              exit 0
            fi

            echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

            if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
              created_count=1
              exit 0
            else
              echo "Tag creation failed after all retry attempts"
              failed_count=1
              exit 1
            fi
            echo "Tag creation failed after all retry attempts"
            exit 1
          fi

      - name: Tag creation summary
        if: always()
        run: |
          if [ "${{ github.event_name }}" = "push" ]; then
            echo "Trigger: push on main"
            echo "Created: ${{ steps.check_tag.outputs.created_count }}"
            echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
            echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
              echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
            else
              echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
            fi
          if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
            echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
          elif [ "${{ job.status }}" = "success" ]; then
            echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
          else
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
              if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
                echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
              else
                echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
              fi
            else
              echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
            fi

            echo ""
            echo "Tag details:"
            echo "  Name: ${{ steps.commit.outputs.tag_name }}"
            echo "  Commit: ${{ steps.commit.outputs.sha }}"
            echo "  Trigger: ${{ github.event_name }}"
            if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
              echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
            fi
            echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
          fi

          echo ""
          echo "Tag details:"
          echo "  Name: ${{ steps.commit.outputs.tag_name }}"
          echo "  Commit: ${{ steps.commit.outputs.sha }}"
          echo "  Trigger: ${{ github.event_name }}"
          if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
            echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
          fi
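
Note: the hunks above call a retry_with_backoff helper whose definition is elided from this diff. For orientation only, a generic sketch of that pattern (the repository's actual helper may differ; MAX_RETRIES=5 comes from the visible context, while the delay schedule here is assumed):

    #!/bin/bash
    # Generic retry-with-backoff sketch; not the repository's exact helper.
    retry_with_backoff() {
      local cmd=$1 description=$2
      local max_retries=5 delay=2 attempt
      for attempt in $(seq 1 "${max_retries}"); do
        echo "${description} (attempt ${attempt}/${max_retries})"
        if "${cmd}"; then
          return 0
        fi
        sleep "${delay}"
        delay=$((delay * 2))  # double the wait between attempts
      done
      return 1
    }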

.github/workflows/trunk.yml (vendored): 59 lines changed

@ -180,50 +180,16 @@ jobs:
      disable-monitor: false
    secrets: inherit

  win-vs2022-cuda12_8-py3-build:
    name: win-vs2022-cuda12.8-py3
  win-vs2022-cuda12_6-py3-build:
    name: win-vs2022-cuda12.6-py3
    uses: ./.github/workflows/_win-build.yml
    needs: get-label-type
    with:
      build-environment: win-vs2022-cuda12.8-py3
      cuda-version: "12.8"
      build-environment: win-vs2022-cuda12.6-py3
      cuda-version: "12.6"
      runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    secrets: inherit

  linux-jammy-rocm-py3_10-build:
    if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-test:
    if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
      tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
    secrets: inherit

  inductor-build:
    name: inductor-build
    uses: ./.github/workflows/_linux-build.yml
@@ -234,23 +200,6 @@ jobs:
      cuda-arch-list: '8.0'
    secrets: inherit

  # Test cross-compiled models with Windows libs extracted from wheel
  cross-compile-linux-test:
    name: cross-compile-linux-test
    uses: ./.github/workflows/_linux-test.yml
    needs:
      - linux-jammy-cuda12_8-py3_10-gcc11-build
      - get-label-type
      - win-vs2022-cuda12_8-py3-build
    with:
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
      test-matrix: |
        { include: [
          { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
        ]}
    secrets: inherit

  verify-cachebench-cpu-build:
    name: verify-cachebench-cpu-build
    uses: ./.github/workflows/_linux-build.yml

.gitignore (2 changes, vendored)
@@ -374,7 +374,6 @@ third_party/ruy/
third_party/glog/

# Virtualenv
.venv/
venv/

# Log files
@@ -396,4 +395,3 @@ android/pytorch_android_torchvision/.cxx
CLAUDE.local.md
/test_*.py
/debug_*.py
CLAUDE_CONTEXT/

@@ -209,46 +209,6 @@ command = [
    '@{{PATHSFILE}}'
]


[[linter]]
code = 'PYREFLY'
include_patterns = [
    'torch/**/*.py',
    'torch/**/*.pyi',
    'torchgen/**/*.py',
    'torchgen/**/*.pyi',
    'functorch/**/*.py',
    'functorch/**/*.pyi',
]
exclude_patterns = []
command = [
    'python3',
    'tools/linter/adapters/pyrefly_linter.py',
    '--config=pyrefly.toml',
]
init_command = [
    'python3',
    'tools/linter/adapters/pip_init.py',
    '--dry-run={{DRYRUN}}',
    'numpy==2.1.0 ; python_version >= "3.12"',
    'expecttest==0.3.0',
    'pyrefly==0.36.2',
    'sympy==1.13.3',
    'types-requests==2.27.25',
    'types-pyyaml==6.0.2',
    'types-tabulate==0.8.8',
    'types-protobuf==5.29.1.20250403',
    'types-setuptools==79.0.0.20250422',
    'types-jinja2==2.11.9',
    'types-colorama==0.4.6',
    'filelock==3.18.0',
    'junitparser==2.1.1',
    'rich==14.1.0',
    'optree==0.17.0',
    'types-openpyxl==3.1.5.20250919',
    'types-python-dateutil==2.9.0.20251008'
]

[[linter]]
code = 'CLANGTIDY'
include_patterns = [

CODEOWNERS (14 changes)
@@ -201,17 +201,3 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99

# FlexAttention
/torch/nn/attention/flex_attention.py @drisspg
/torch/_higher_order_ops/flex_attention.py @drisspg
/torch/_inductor/kernel/flex/ @drisspg
/torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg

# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

@@ -256,7 +256,6 @@ endif()
IF(USE_FBGEMM_GENAI)
  set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
  set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)

  if(USE_CUDA)
    # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
    # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@@ -293,65 +292,58 @@ IF(USE_FBGEMM_GENAI)
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
    )

    target_include_directories(fbgemm_genai PRIVATE
    target_include_directories(fbgemm_genai PUBLIC
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${fbgemm_genai_mx8mx8bf16_grouped}
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )
  else()
    if(USE_ROCM)
      # Only include the kernels we want to build to avoid increasing binary size.
      file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
      set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  elseif(USE_ROCM)
    # Only include the kernels we want to build to avoid increasing binary size.
    file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
    set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
      # Add additional HIPCC compiler flags for performance
      set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
        -mllvm
        -amdgpu-coerce-illegal-types=1
        -mllvm
        -enable-post-misched=0
        -mllvm
        -greedy-reverse-local-assignment=1
        -fhip-new-launch-api)

    # Add additional HIPCC compiler flags for performance
    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
      -mllvm
      -enable-post-misched=0
      -mllvm
      -greedy-reverse-local-assignment=1
      -fhip-new-launch-api)
    if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0")
        list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
      # Only compile for gfx942 for now.
      # This is rather hacky, I could not figure out a clean solution :(
      set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
      string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
      if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
        list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
      endif()
      set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    # Only compile for gfx942 for now.
    # This is rather hacky, I could not figure out a clean solution :(
    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
    if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
      list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
      hip_add_library(
        fbgemm_genai STATIC
        ${fbgemm_genai_native_rocm_hip}
        HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
      set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
      set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
      target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

      target_include_directories(fbgemm_genai PUBLIC
        # FBGEMM version of Composable Kernel is used due to some customizations
        ${FBGEMM_THIRD_PARTY}/composable_kernel/include
        ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
        ${FBGEMM_THIRD_PARTY}/cutlass/include
        ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
        ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
        ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
      )
    endif()
    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    hip_add_library(
      fbgemm_genai STATIC
      ${fbgemm_genai_native_rocm_hip}
      HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

    target_include_directories(fbgemm_genai PRIVATE
      # FBGEMM version of Composable Kernel is used due to some customizations
      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )

    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()
endif()

@@ -700,6 +692,12 @@ if(USE_CUDA AND NOT USE_ROCM)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)

  # Add FBGEMM_GENAI include directories for torch_ops.h
  if(USE_FBGEMM_GENAI)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()

  if($ENV{ATEN_STATIC_CUDA})
    if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
      list(APPEND ATen_CUDA_DEPENDENCY_LIBS

@@ -389,16 +389,37 @@ void fillVersion<DLManagedTensorVersioned>(
// constructed out of ATen tensor
template <class T>
T* toDLPackImpl(const Tensor& src) {
  auto view = src;

  // Detect whether the strides need to be normalized.
  // Background: gh-83069
  //
  // Normalizing strides can come at a high cost (slowing
  // toDLPack conversion down by about 3x), so we only
  // normalize when needed.
  //
  // The following code detects whether src follows a
  // contiguous pattern. If it does (the common case),
  // we do not need to normalize the strides.
  bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
  // less common case, try normalizing the strides
  if (need_normalize_strides) {
    // create a new tensor with possibly normalized strides
    // gh-83069
    auto shape = src.sizes();
    view = src.as_strided(shape, {1}, src.storage_offset());
  }

  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
  atDLMTensor->handle = src;
  atDLMTensor->handle = view;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter<T>;
  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
  atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  fillVersion(&atDLMTensor->tensor);

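For context on the hunk above: gh-83069 concerns 1-element tensors whose single stride is not 1, a layout some DLPack consumers reject even though only one element is addressable. A minimal standalone sketch of the same normalization, assuming only <ATen/ATen.h>; normalize_for_dlpack is a hypothetical helper name:

// Hypothetical illustration (not part of the diff): a 1-element tensor can
// carry a non-unit stride, e.g. slicing every other element of a length-2
// tensor yields shape [1], stride [2]. Since only one element is reachable,
// stride {1} is an equivalent layout, so it is safe to rewrite.
#include <ATen/ATen.h>

at::Tensor normalize_for_dlpack(const at::Tensor& src) {
  if (src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1) {
    // Same trick as toDLPackImpl: keep shape and storage offset, force stride {1}.
    return src.as_strided(src.sizes(), {1}, src.storage_offset());
  }
  return src;
}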
@@ -52,16 +52,16 @@ struct DLPackTraits {};

template <>
struct DLPackTraits<DLManagedTensor> {
  inline static constexpr const char* capsule = "dltensor";
  inline static constexpr const char* used = "used_dltensor";
  inline static const char* capsule = "dltensor";
  inline static const char* used = "used_dltensor";
  inline static auto toDLPack = at::toDLPack;
  inline static auto fromDLPack = at::fromDLPack;
};

template <>
struct DLPackTraits<DLManagedTensorVersioned> {
  inline static constexpr const char* capsule = "dltensor_versioned";
  inline static constexpr const char* used = "used_dltensor_versioned";
  inline static const char* capsule = "dltensor_versioned";
  inline static const char* used = "used_dltensor_versioned";
  inline static auto toDLPack = at::toDLPackVersioned;
  inline static auto fromDLPack = at::fromDLPackVersioned;
};

@@ -42,14 +42,8 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
}

bool torch_function_mode_enabled() {
  // Manually flatten because gcc is refusing to inline here.  Note
  // that we are still calling __tls_get_addr twice here with GCC,
  // presumably because of
  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81501 (which says
  // the fix ships in GCC 16), but forcing inlining still improves
  // performance.
  const auto& ptfs = pythonTorchFunctionState;
  return ptfs.disabled_state_ != TorchFunctionDisabledState::ALL_DISABLED && !ptfs.stack_.empty();
  return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
         PythonTorchFunctionTLS::stack_len() > 0;
}

// This is needed to disambiguate the ternary torch function disabled states

@@ -27,7 +27,6 @@ struct TORCH_API PythonTorchFunctionTLS {
  TorchFunctionDisabledState disabled_state_ =
      TorchFunctionDisabledState::ENABLED;
  std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
  friend TORCH_API bool torch_function_mode_enabled();
};

TORCH_API bool torch_function_mode_enabled();

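The flattened torch_function_mode_enabled above reads the thread-local state once and answers both questions from the cached reference, instead of going through two accessors that each touch TLS. A minimal sketch of that pattern with hypothetical names (Tls, mode_enabled):

// Sketch (hypothetical names): caching a reference to the thread_local object
// lets the compiler emit one TLS address computation (one __tls_get_addr-style
// lookup) instead of one per accessor call.
#include <vector>

enum class Disabled { ENABLED, ALL_DISABLED };

struct Tls {
  Disabled disabled_state_ = Disabled::ENABLED;
  std::vector<int> stack_;
};
thread_local Tls tls_state;

bool mode_enabled() {
  const auto& s = tls_state;  // one TLS lookup for both reads below
  return s.disabled_state_ != Disabled::ALL_DISABLED && !s.stack_.empty();
}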
@@ -39,7 +39,7 @@ struct HostBlock {
};

template <typename B>
struct alignas(hardware_destructive_interference_size) FreeBlockList {
struct alignas(64) FreeBlockList {
  std::mutex mutex_;
  std::deque<B*> list_;
};
@@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
struct alignas(64) HostStatsStaged {
  std::mutex timing_mutex_;
  // COUNT: total allocations (active + free)
  // LOCK: access to this stat is protected by the allocator's blocks_mutex_
@@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
  }

  alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
  alignas(64) std::mutex blocks_mutex_;
  ska::flat_hash_set<B*> blocks_; // block list
  ska::flat_hash_map<void*, B*> ptr_to_block_;

@@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
  // size. This allows us to quickly find a free block of the right size.
  // We use deque to store per size free list and guard the list with its own
  // mutex.
  alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
  alignas(64) std::vector<FreeBlockList<B>> free_list_ =
      std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);

  alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
  alignas(64) std::mutex events_mutex_;
  std::deque<std::pair<E, B*>> events_; // event queue paired with block

  // Indicates whether the object is active.
  // Set to false in the destructor to signal background threads to stop.
  std::atomic<bool> active_{true};
protected:
  alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
  alignas(64) HostStatsStaged stats_;
};

struct TORCH_API HostAllocator : public at::Allocator {

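These hunks swap between alignment based on the named constant hardware_destructive_interference_size and a literal 64; both forms aim to keep independently-locked fields on separate cache lines so contention on one does not evict the line holding another (false sharing). A self-contained sketch of the idea:

// Minimal sketch (not from the diff): padding two mutexes to separate cache
// lines. 64 is a common x86 cache-line size; std::hardware_destructive_interference_size
// (C++17) names the same quantity portably but is not available on every
// toolchain, which is one reason code bases fall back to a literal.
#include <mutex>

struct Counters {
  alignas(64) std::mutex a_mutex;  // occupies its own cache line
  long a = 0;
  alignas(64) std::mutex b_mutex;  // starts on a fresh cache line
  long b = 0;
};

static_assert(alignof(Counters) == 64, "padding applied as intended");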
@@ -229,10 +229,10 @@ private:
  }


  static constexpr uint32_t kPhilox10A = 0x9E3779B9;
  static constexpr uint32_t kPhilox10B = 0xBB67AE85;
  static constexpr uint32_t kPhiloxSA = 0xD2511F53;
  static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
  static const uint32_t kPhilox10A = 0x9E3779B9;
  static const uint32_t kPhilox10B = 0xBB67AE85;
  static const uint32_t kPhiloxSA = 0xD2511F53;
  static const uint32_t kPhiloxSB = 0xCD9E8D57;
};

typedef philox_engine Philox4_32;

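Several hunks in this compare (here and in DLPackTraits above) move between static constexpr and static const members. A generic illustration of the difference for integral constants, not taken from the diff:

// Generic sketch: for integral types, both forms can be used in constant
// expressions, but only constexpr guarantees a compile-time constant, and a
// constexpr static data member is implicitly inline since C++17, so it needs
// no out-of-line definition even when odr-used.
#include <cstdint>

struct Engine {
  static const uint32_t kA = 0x9E3779B9;      // const integral with in-class initializer
  static constexpr uint32_t kB = 0xBB67AE85;  // guaranteed constant expression
};

static_assert(Engine::kA == 0x9E3779B9, "const integral still works in constant expressions");
static_assert(Engine::kB == 0xBB67AE85, "constexpr member");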
@@ -624,14 +624,7 @@ struct TORCH_API IValue final {
  IValue(const c10::SymBool& i) {
    if (auto mi = i.maybe_as_bool()) {
      tag = Tag::Bool;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
      payload.u.as_int = *mi;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
      /* due to byte order, assigning the value via as_int would leave as_bool set incorrectly */
      payload.u.as_bool = *mi;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
    } else {
      tag = Tag::SymBool;
      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
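The big-endian branch above exists because payload.u is a union: a small value written through the wider as_int member lands in the low-address byte only on little-endian targets, so reading it back through a one-byte member like as_bool would be wrong on big-endian machines. A standalone sketch of the hazard:

// Standalone sketch of the endianness hazard the hunk guards against.
#include <cstdint>
#include <cstring>
#include <cstdio>

union Payload {
  int64_t as_int;
  bool as_bool;  // aliases the first byte of the union
};

int main() {
  Payload p;
  p.as_int = 1;  // on big-endian, the 0x01 byte sits at the *highest* address
  unsigned char first_byte;
  std::memcpy(&first_byte, &p, 1);
  // 1 on little-endian, 0 on big-endian -- hence the separate as_bool branch.
  std::printf("first byte: %d\n", first_byte);
  return 0;
}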
@@ -8,8 +8,6 @@
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
#endif

#include <ATen/cpu/vec/vec128/vec128_convert.h>

@@ -1,794 +0,0 @@
#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>

namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#define VEC_INT_NEON_TEMPLATE(vl, bit)                                        \
  template <>                                                                 \
  struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {};  \
                                                                              \
  template <>                                                                 \
  class Vectorized<int##bit##_t> {                                            \
    using neon_type = int##bit##x##vl##_t;                                    \
                                                                              \
   private:                                                                   \
    neon_type values;                                                         \
                                                                              \
   public:                                                                    \
    using value_type = int##bit##_t;                                          \
    using size_type = int;                                                    \
    static constexpr size_type size() {                                       \
      return vl;                                                              \
    }                                                                         \
    Vectorized() {                                                            \
      values = vdupq_n_s##bit(0);                                             \
    }                                                                         \
    Vectorized(neon_type v) : values(v) {}                                    \
    Vectorized(int##bit##_t val);                                             \
    template <                                                                \
        typename... Args,                                                     \
        typename = std::enable_if_t<(sizeof...(Args) == size())>>             \
    Vectorized(Args... vals) {                                                \
      __at_align__ int##bit##_t buffer[size()] = {vals...};                   \
      values = vld1q_s##bit(buffer);                                          \
    }                                                                         \
    operator neon_type() const {                                              \
      return values;                                                          \
    }                                                                         \
    static Vectorized<int##bit##_t> loadu(                                    \
        const void* ptr,                                                      \
        int64_t count = size());                                              \
    void store(void* ptr, int64_t count = size()) const;                      \
    template <int64_t mask>                                                   \
    static Vectorized<int##bit##_t> blend(                                    \
        const Vectorized<int##bit##_t>& a,                                    \
        const Vectorized<int##bit##_t>& b);                                   \
    static Vectorized<int##bit##_t> blendv(                                   \
        const Vectorized<int##bit##_t>& a,                                    \
        const Vectorized<int##bit##_t>& b,                                    \
        const Vectorized<int##bit##_t>& mask_) {                              \
      return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \
    }                                                                         \
    template <typename step_t>                                                \
    static Vectorized<int##bit##_t> arange(                                   \
        value_type base = 0,                                                  \
        step_t step = static_cast<step_t>(1));                                \
    static Vectorized<int##bit##_t> set(                                      \
        const Vectorized<int##bit##_t>& a,                                    \
        const Vectorized<int##bit##_t>& b,                                    \
        int64_t count = size());                                              \
    const int##bit##_t& operator[](int idx) const = delete;                   \
    int##bit##_t& operator[](int idx) = delete;                               \
    Vectorized<int##bit##_t> abs() const {                                    \
      return vabsq_s##bit(values);                                            \
    }                                                                         \
    Vectorized<int##bit##_t> real() const {                                   \
      return values;                                                          \
    }                                                                         \
    Vectorized<int##bit##_t> imag() const {                                   \
      return vdupq_n_s##bit(0);                                               \
    }                                                                         \
    Vectorized<int##bit##_t> conj() const {                                   \
      return values;                                                          \
    }                                                                         \
    Vectorized<int##bit##_t> neg() const {                                    \
      return vnegq_s##bit(values);                                            \
    }                                                                         \
    int##bit##_t reduce_add() const {                                         \
      return vaddvq_s##bit(values);                                           \
    }                                                                         \
    int##bit##_t reduce_max() const;                                          \
    Vectorized<int##bit##_t> operator==(                                      \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> operator!=(                                      \
        const Vectorized<int##bit##_t>& other) const;                         \
    Vectorized<int##bit##_t> operator<(                                       \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> operator<=(                                      \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> operator>(                                       \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> operator>=(                                      \
        const Vectorized<int##bit##_t>& other) const {                        \
      return Vectorized<value_type>(                                          \
          vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \
    }                                                                         \
    Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \
    Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \
  };                                                                          \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator+(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return vaddq_s##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator-(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return vsubq_s##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator&(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return vandq_s##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator|(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return vorrq_s##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<int##bit##_t> inline operator^(                                  \
      const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
    return veorq_s##bit(a, b);                                                \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this == other) & Vectorized<int##bit##_t>(1);                    \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this != other) & Vectorized<int##bit##_t>(1);                    \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this > other) & Vectorized<int##bit##_t>(1);                     \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this >= other) & Vectorized<int##bit##_t>(1);                    \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this < other) & Vectorized<int##bit##_t>(1);                     \
  }                                                                           \
  Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le(               \
      const Vectorized<int##bit##_t>& other) const {                          \
    return (*this <= other) & Vectorized<int##bit##_t>(1);                    \
  }

VEC_INT_NEON_TEMPLATE(2, 64)
VEC_INT_NEON_TEMPLATE(4, 32)
VEC_INT_NEON_TEMPLATE(8, 16)
VEC_INT_NEON_TEMPLATE(16, 8)
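The 794-line header removed above is largely one token-pasting macro stamped out once per integer width. A minimal sketch of that pattern, with hypothetical names (DEFINE_LANES, LanesN):

// Minimal sketch of the token-pasting technique: one macro body parameterized
// on lane count and bit width expands to a full type per integer width, so
// the int8/16/32/64 variants stay in lockstep.
#include <cstdint>

#define DEFINE_LANES(vl, bit)                  \
  struct Lanes##bit {                          \
    using value_type = int##bit##_t;           \
    static constexpr int size() { return vl; } \
  };

DEFINE_LANES(2, 64)   // Lanes64: 2 x int64_t in a 128-bit register
DEFINE_LANES(4, 32)   // Lanes32: 4 x int32_t
DEFINE_LANES(8, 16)   // Lanes16: 8 x int16_t
DEFINE_LANES(16, 8)   // Lanes8: 16 x int8_t

static_assert(Lanes32::size() == 4, "widths and counts stay consistent");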
inline int32_t Vectorized<int32_t>::reduce_max() const {
  return vmaxvq_s32(values);
}

inline int16_t Vectorized<int16_t>::reduce_max() const {
  return vmaxvq_s16(values);
}

inline int8_t Vectorized<int8_t>::reduce_max() const {
  return vmaxvq_s8(values);
}

template <>
Vectorized<int32_t> inline operator*(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  return vmulq_s32(a, b);
}

template <>
Vectorized<int16_t> inline operator*(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  return vmulq_s16(a, b);
}

template <>
Vectorized<int8_t> inline operator*(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  return vmulq_s8(a, b);
}

template <>
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
  int64x2_t val = a;
  return ~val;
}

template <>
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
  return vmvnq_s32(a);
}

template <>
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
  return vmvnq_s16(a);
}

template <>
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
  return vmvnq_s8(a);
}

inline Vectorized<int64_t> Vectorized<int64_t>::operator!=(
    const Vectorized<int64_t>& other) const {
  return ~(*this == other);
}

inline Vectorized<int32_t> Vectorized<int32_t>::operator!=(
    const Vectorized<int32_t>& other) const {
  return ~(*this == other);
}

inline Vectorized<int16_t> Vectorized<int16_t>::operator!=(
    const Vectorized<int16_t>& other) const {
  return ~(*this == other);
}

inline Vectorized<int8_t> Vectorized<int8_t>::operator!=(
    const Vectorized<int8_t>& other) const {
  return ~(*this == other);
}

template <>
Vectorized<int32_t> inline minimum(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  return vminq_s32(a, b);
}

template <>
Vectorized<int16_t> inline minimum(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  return vminq_s16(a, b);
}

template <>
Vectorized<int8_t> inline minimum(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  return vminq_s8(a, b);
}

template <>
Vectorized<int32_t> inline maximum(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  return vmaxq_s32(a, b);
}

template <>
Vectorized<int16_t> inline maximum(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  return vmaxq_s16(a, b);
}

template <>
Vectorized<int8_t> inline maximum(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  return vmaxq_s8(a, b);
}
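A note on the operator!= definitions above: NEON has no compare-not-equal instruction, so inequality is built as the bitwise complement of the vceqq equality mask, and scalar reductions use vmaxvq/vaddvq across lanes. A standalone sketch, assuming an AArch64 target and <arm_neon.h>:

// Sketch: build "not equal" from "equal". vceqq_s32 yields an all-ones lane
// where the operands match; complementing it yields all-ones where they differ.
#include <arm_neon.h>

int32x4_t not_equal_s32(int32x4_t a, int32x4_t b) {
  uint32x4_t eq = vceqq_s32(a, b);              // 0xFFFFFFFF where equal
  return vreinterpretq_s32_u32(vmvnq_u32(eq));  // complement -> 0xFFFFFFFF where different
}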
template <int64_t mask>
Vectorized<int64_t> Vectorized<int64_t>::blend(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  // Build a lane mask: lane i is all-ones if bit i of 'mask' is set,
  // 0 otherwise.
  uint64x2_t maskArray = {
      (mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0,
      (mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0};
  // Use BSL to select elements from b where the mask is 1, else from a
  return vbslq_s64(maskArray, b.values, a.values);
}

template <int64_t mask>
Vectorized<int32_t> Vectorized<int32_t>::blend(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  // Build a lane mask: lane i is all-ones if bit i of 'mask' is set,
  // 0 otherwise.
  uint32x4_t maskArray = {
      (mask & 1LL) ? 0xFFFFFFFF : 0,
      (mask & 2LL) ? 0xFFFFFFFF : 0,
      (mask & 4LL) ? 0xFFFFFFFF : 0,
      (mask & 8LL) ? 0xFFFFFFFF : 0};
  // Use BSL to select elements from b where the mask is 1, else from a
  return vbslq_s32(maskArray, b.values, a.values);
}

template <int64_t mask>
Vectorized<int16_t> Vectorized<int16_t>::blend(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  // Build a lane mask: lane i is all-ones if bit i of 'mask' is set,
  // 0 otherwise.
  uint16x8_t maskArray = {
      (mask & 1LL) ? 0xFFFF : 0,
      (mask & 2LL) ? 0xFFFF : 0,
      (mask & 4LL) ? 0xFFFF : 0,
      (mask & 8LL) ? 0xFFFF : 0,
      (mask & 16LL) ? 0xFFFF : 0,
      (mask & 32LL) ? 0xFFFF : 0,
      (mask & 64LL) ? 0xFFFF : 0,
      (mask & 128LL) ? 0xFFFF : 0};
  // Use BSL to select elements from b where the mask is 1, else from a
  return vbslq_s16(maskArray, b.values, a.values);
}

template <int64_t mask>
Vectorized<int8_t> Vectorized<int8_t>::blend(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  // Build a lane mask: lane i is all-ones if bit i of 'mask' is set,
  // 0 otherwise.
  uint8x16_t maskArray = {
      (mask & 1LL) ? 0xFF : 0,
      (mask & 2LL) ? 0xFF : 0,
      (mask & 4LL) ? 0xFF : 0,
      (mask & 8LL) ? 0xFF : 0,
      (mask & 16LL) ? 0xFF : 0,
      (mask & 32LL) ? 0xFF : 0,
      (mask & 64LL) ? 0xFF : 0,
      (mask & 128LL) ? 0xFF : 0,
      (mask & 256LL) ? 0xFF : 0,
      (mask & 512LL) ? 0xFF : 0,
      (mask & 1024LL) ? 0xFF : 0,
      (mask & 2048LL) ? 0xFF : 0,
      (mask & 4096LL) ? 0xFF : 0,
      (mask & 8192LL) ? 0xFF : 0,
      (mask & 16384LL) ? 0xFF : 0,
      (mask & 32768LL) ? 0xFF : 0};
  // Use BSL to select elements from b where the mask is 1, else from a
  return vbslq_s8(maskArray, b.values, a.values);
}
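Usage sketch for the blend() specializations above (assumes <arm_neon.h> and GCC/Clang vector literals): bit i of the compile-time mask selects lane i of b over lane i of a.

// With mask 0b0101, lanes 0 and 2 come from b and lanes 1 and 3 from a --
// the mask bit index matches the lane index.
#include <arm_neon.h>

int32x4_t demo_blend() {
  int32x4_t a = {10, 11, 12, 13};
  int32x4_t b = {20, 21, 22, 23};
  uint32x4_t m = {0xFFFFFFFFu, 0u, 0xFFFFFFFFu, 0u};  // mask 0b0101 expanded per lane
  return vbslq_s32(m, b, a);                          // -> {20, 11, 22, 13}
}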
#define VEC_INT_NEON_OPS(vl, bit)                                             \
  inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) {             \
    values = vdupq_n_s##bit(val);                                             \
  }                                                                           \
  inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu(            \
      const void* ptr, int64_t count) {                                       \
    if (count == size()) {                                                    \
      return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr));        \
    } else {                                                                  \
      __at_align__ int##bit##_t tmp_values[size()];                           \
      for (const auto i : c10::irange(size())) {                              \
        tmp_values[i] = 0;                                                    \
      }                                                                       \
      std::memcpy(                                                            \
          tmp_values,                                                         \
          reinterpret_cast<const int##bit##_t*>(ptr),                         \
          count * sizeof(int##bit##_t));                                      \
      return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \
    }                                                                         \
  }                                                                           \
  inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count)       \
      const {                                                                 \
    if (count == size()) {                                                    \
      vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values);             \
    } else {                                                                  \
      int##bit##_t tmp_values[size()];                                        \
      vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values);      \
      std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t));             \
    }                                                                         \
  }

VEC_INT_NEON_OPS(2, 64)
VEC_INT_NEON_OPS(4, 32)
VEC_INT_NEON_OPS(8, 16)
VEC_INT_NEON_OPS(16, 8)
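The loadu()/store() bodies above handle partial vectors by round-tripping through a zeroed scratch buffer, so a tail of fewer than size() elements never triggers an out-of-bounds vector access. A standalone sketch of that tail-handling idea, assuming <arm_neon.h>:

// Copy the remaining `count` elements into a zeroed, aligned buffer and do
// one full-width load; the padding lanes come out as 0.
#include <arm_neon.h>
#include <cstring>

int32x4_t load_partial_s32(const int32_t* ptr, int count) {
  alignas(16) int32_t tmp[4] = {0, 0, 0, 0};
  std::memcpy(tmp, ptr, count * sizeof(int32_t));  // count must be in [0, 4]
  return vld1q_s32(tmp);
}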
template <>
Vectorized<int64_t> inline operator*(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  return x * y;
}

template <>
Vectorized<int64_t> inline operator/(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  return x / y;
}

template <>
Vectorized<int32_t> inline operator/(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  int32x4_t x = a;
  int32x4_t y = b;
  return x / y;
}

inline int64_t Vectorized<int64_t>::reduce_max() const {
  return std::max(values[0], values[1]);
}

template <>
Vectorized<int64_t> inline minimum(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  return {std::min(x[0], y[0]), std::min(x[1], y[1])};
}

template <>
Vectorized<int64_t> inline maximum(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  return {std::max(x[0], y[0]), std::max(x[1], y[1])};
}

template <typename step_t>
inline Vectorized<int64_t> Vectorized<int64_t>::arange(
    int64_t base,
    step_t step) {
  const Vectorized<int64_t> base_vec(base);
  const Vectorized<int64_t> step_vec(step);
  const int64x2_t step_sizes = {0, 1};
  return base_vec.values + step_sizes * step_vec.values;
}

template <typename step_t>
inline Vectorized<int32_t> Vectorized<int32_t>::arange(
    int32_t base,
    step_t step) {
  const Vectorized<int32_t> base_vec(base);
  const Vectorized<int32_t> step_vec(step);
  const int32x4_t step_sizes = {0, 1, 2, 3};
  return vmlaq_s32(base_vec, step_sizes, step_vec);
}

template <typename step_t>
inline Vectorized<int16_t> Vectorized<int16_t>::arange(
    int16_t base,
    step_t step) {
  const Vectorized<int16_t> base_vec(base);
  const Vectorized<int16_t> step_vec(step);
  const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7};
  return vmlaq_s16(base_vec, step_sizes, step_vec);
}

template <typename step_t>
inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) {
  const Vectorized<int8_t> base_vec(base);
  const Vectorized<int8_t> step_vec(step);
  const int8x16_t step_sizes = {
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
  return vmlaq_s8(base_vec, step_sizes, step_vec);
}
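The arange() overloads above compute base + index * step for every lane in one fused multiply-accumulate. A standalone sketch, assuming <arm_neon.h>:

// vmlaq_s32(a, b, c) computes a + b * c lane-wise, so arange becomes a
// single instruction over the constant lane-index vector {0, 1, 2, 3}.
#include <arm_neon.h>

int32x4_t arange_s32(int32_t base, int32_t step) {
  const int32x4_t lane_index = {0, 1, 2, 3};
  return vmlaq_s32(vdupq_n_s32(base), lane_index, vdupq_n_s32(step));
}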
template <>
Vectorized<int64_t> inline operator>>(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t x = a;
  int64x2_t y = b;
  uint64x2_t u = vreinterpretq_u64_s64(y);
  uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)};
  return x >> vreinterpretq_s64_u64(z);
}

template <>
Vectorized<int32_t> inline operator>>(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  int32x4_t x = a;
  int32x4_t y = b;
  uint32x4_t bound = vdupq_n_u32(31);
  uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
  return x >> vreinterpretq_s32_u32(z);
}

template <>
Vectorized<int16_t> inline operator>>(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  int16x8_t x = a;
  int16x8_t y = b;
  uint16x8_t bound = vdupq_n_u16(15);
  uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
  return x >> vreinterpretq_s16_u16(z);
}

template <>
Vectorized<int8_t> inline operator>>(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  int8x16_t x = a;
  int8x16_t y = b;
  uint8x16_t bound = vdupq_n_u8(7);
  int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
  return x >> z;
}

template <>
Vectorized<int64_t> inline operator<<(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  int64x2_t y = b;
  uint64x2_t u = vreinterpretq_u64_s64(y);
  uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)};
  return vshlq_s64(a, vreinterpretq_s64_u64(z));
}

template <>
Vectorized<int32_t> inline operator<<(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b) {
  int32x4_t y = b;
  uint32x4_t bound = vdupq_n_u32(32);
  uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
  return vshlq_s32(a, vreinterpretq_s32_u32(z));
}

template <>
Vectorized<int16_t> inline operator<<(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  int16x8_t y = b;
  uint16x8_t bound = vdupq_n_u16(16);
  uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
  return vshlq_s16(a, vreinterpretq_s16_u16(z));
}

template <>
Vectorized<int8_t> inline operator<<(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  int8x16_t y = b;
  uint8x16_t bound = vdupq_n_u8(8);
  int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
  return vshlq_s8(a, z);
}
inline Vectorized<int64_t> Vectorized<int64_t>::set(
 | 
			
		||||
    const Vectorized<int64_t>& a,
 | 
			
		||||
    const Vectorized<int64_t>& b,
 | 
			
		||||
    int64_t count) {
 | 
			
		||||
  if (count == 0) {
 | 
			
		||||
    return a;
 | 
			
		||||
  } else if (count >= 2) {
 | 
			
		||||
    return b;
 | 
			
		||||
  } else {
 | 
			
		||||
    int64x2_t c = {b.values[0], a.values[1]};
 | 
			
		||||
    return c;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<int32_t> Vectorized<int32_t>::set(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b,
    int64_t count) {
  if (count == 0) {
    return a;
  } else if (count >= 4) {
    return b;
  } else {
    // Build an array of flags: an element is all 1s if its lane index is
    // less than 'count', 0 otherwise.
    uint32x4_t maskArray = {
        (count >= 1LL) ? 0xFFFFFFFF : 0,
        (count >= 2LL) ? 0xFFFFFFFF : 0,
        (count >= 3LL) ? 0xFFFFFFFF : 0,
        0};
    // Use BSL to select elements from b where the mask is 1, else from a
    return vbslq_s32(maskArray, b.values, a.values);
  }
}

inline Vectorized<int16_t> Vectorized<int16_t>::set(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b,
    int64_t count) {
  if (count == 0) {
    return a;
  } else if (count >= 8) {
    return b;
  } else {
    // Build an array of flags: an element is all 1s if its lane index is
    // less than 'count', 0 otherwise.
    uint16x8_t maskArray = {
        static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0),
        static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0),
        static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0),
        static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0),
        static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0),
        static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0),
        static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0),
        0};
    // Use BSL to select elements from b where the mask is 1, else from a
    return vbslq_s16(maskArray, b.values, a.values);
  }
}

inline Vectorized<int8_t> Vectorized<int8_t>::set(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b,
    int64_t count) {
  if (count == 0) {
    return a;
  } else if (count >= 16) {
    return b;
  } else {
    // Build an array of flags: an element is all 1s if its lane index is
    // less than 'count', 0 otherwise.
    uint8x16_t maskArray = {
        static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
        0};

    // Use BSL to select elements from b where the mask is 1, else from a
    return vbslq_s8(maskArray, b.values, a.values);
  }
}

template <>
Vectorized<int16_t> inline operator/(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b) {
  Vectorized<int32_t> highBitsA = vmovl_high_s16(a);
  Vectorized<int32_t> highBitsB = vmovl_high_s16(b);
  Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a));
  Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b));
  int32x4_t highBitsResult = highBitsA / highBitsB;
  int32x4_t lowBitsResult = lowBitsA / lowBitsB;
  return vuzp1q_s16(
      vreinterpretq_s16_s32(lowBitsResult),
      vreinterpretq_s16_s32(highBitsResult));
}

template <>
Vectorized<int8_t> inline operator/(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& b) {
  Vectorized<int16_t> highBitsA = vmovl_high_s8(a);
  Vectorized<int16_t> highBitsB = vmovl_high_s8(b);
  Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a));
  Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b));
  int16x8_t highBitsResult = highBitsA / highBitsB;
  int16x8_t lowBitsResult = lowBitsA / lowBitsB;
  return vuzp1q_s8(
      vreinterpretq_s8_s16(lowBitsResult),
      vreinterpretq_s8_s16(highBitsResult));
}

template <>
Vectorized<int64_t> inline clamp(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& min,
    const Vectorized<int64_t>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<int32_t> inline clamp(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& min,
    const Vectorized<int32_t>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<int16_t> inline clamp(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& min,
    const Vectorized<int16_t>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<int8_t> inline clamp(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& min,
    const Vectorized<int8_t>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<int64_t> inline clamp_max(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& max) {
  return minimum(max, a);
}

template <>
Vectorized<int32_t> inline clamp_max(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& max) {
  return minimum(max, a);
}

template <>
Vectorized<int16_t> inline clamp_max(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& max) {
  return minimum(max, a);
}

template <>
Vectorized<int8_t> inline clamp_max(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& max) {
  return minimum(max, a);
}

template <>
Vectorized<int64_t> inline clamp_min(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& min) {
  return maximum(min, a);
}

template <>
Vectorized<int32_t> inline clamp_min(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& min) {
  return maximum(min, a);
}

template <>
Vectorized<int16_t> inline clamp_min(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& min) {
  return maximum(min, a);
}

template <>
Vectorized<int8_t> inline clamp_min(
    const Vectorized<int8_t>& a,
    const Vectorized<int8_t>& min) {
  return maximum(min, a);
}

} // namespace CPU_CAPABILITY
} // namespace at::vec
@ -1,378 +0,0 @@
#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>

namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#define VEC_UINT_NEON_TEMPLATE(vl, bit)                                       \
  template <>                                                                 \
  struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
                                                                              \
  template <>                                                                 \
  class Vectorized<uint##bit##_t> {                                           \
    using neon_type = uint##bit##x##vl##_t;                                   \
                                                                              \
   private:                                                                   \
    neon_type values;                                                         \
                                                                              \
   public:                                                                    \
    using value_type = uint##bit##_t;                                         \
    using size_type = int;                                                    \
    static constexpr size_type size() {                                       \
      return vl;                                                              \
    }                                                                         \
    Vectorized() {                                                            \
      values = vdupq_n_u##bit(0);                                             \
    }                                                                         \
    Vectorized(neon_type v) : values(v) {}                                    \
    Vectorized(uint##bit##_t val);                                            \
    template <                                                                \
        typename... Args,                                                     \
        typename = std::enable_if_t<(sizeof...(Args) == size())>>             \
    Vectorized(Args... vals) {                                                \
      __at_align__ uint##bit##_t buffer[size()] = {vals...};                  \
      values = vld1q_u##bit(buffer);                                          \
    }                                                                         \
    operator neon_type() const {                                              \
      return values;                                                          \
    }                                                                         \
    static Vectorized<uint##bit##_t> loadu(                                   \
        const void* ptr,                                                      \
        uint64_t count = size());                                             \
    void store(void* ptr, uint64_t count = size()) const;                     \
    template <uint64_t mask>                                                  \
    static Vectorized<uint##bit##_t> blend(                                   \
        const Vectorized<uint##bit##_t>& a,                                   \
        const Vectorized<uint##bit##_t>& b);                                  \
    static Vectorized<uint##bit##_t> blendv(                                  \
        const Vectorized<uint##bit##_t>& a,                                   \
        const Vectorized<uint##bit##_t>& b,                                   \
        const Vectorized<uint##bit##_t>& mask_) {                             \
      return vbslq_u##bit(mask_.values, b, a);                                \
    }                                                                         \
    template <typename step_t>                                                \
    static Vectorized<uint##bit##_t> arange(                                  \
        value_type base = 0,                                                  \
        step_t step = static_cast<step_t>(1));                                \
    static Vectorized<uint##bit##_t> set(                                     \
        const Vectorized<uint##bit##_t>& a,                                   \
        const Vectorized<uint##bit##_t>& b,                                   \
        uint64_t count = size());                                             \
    const uint##bit##_t& operator[](uint idx) const = delete;                 \
    uint##bit##_t& operator[](uint idx) = delete;                             \
    Vectorized<uint##bit##_t> abs() const {                                   \
      return values;                                                          \
    }                                                                         \
    Vectorized<uint##bit##_t> real() const {                                  \
      return values;                                                          \
    }                                                                         \
    Vectorized<uint##bit##_t> imag() const {                                  \
      return vdupq_n_u##bit(0);                                               \
    }                                                                         \
    Vectorized<uint##bit##_t> conj() const {                                  \
      return values;                                                          \
    }                                                                         \
    Vectorized<uint##bit##_t> neg() const {                                   \
      return vreinterpretq_u##bit##_s##bit(                                   \
          vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values)));               \
    }                                                                         \
    uint##bit##_t reduce_add() const {                                        \
      return vaddvq_u##bit(values);                                           \
    }                                                                         \
    uint##bit##_t reduce_max() const;                                         \
    Vectorized<uint##bit##_t> operator==(                                     \
        const Vectorized<uint##bit##_t>& other) const {                       \
      return Vectorized<value_type>(vceqq_u##bit(values, other.values));      \
    }                                                                         \
    Vectorized<uint##bit##_t> operator!=(                                     \
        const Vectorized<uint##bit##_t>& other) const;                        \
    Vectorized<uint##bit##_t> operator<(                                      \
        const Vectorized<uint##bit##_t>& other) const {                       \
      return Vectorized<value_type>(vcltq_u##bit(values, other.values));      \
    }                                                                         \
    Vectorized<uint##bit##_t> operator<=(                                     \
        const Vectorized<uint##bit##_t>& other) const {                       \
      return Vectorized<value_type>(vcleq_u##bit(values, other.values));      \
    }                                                                         \
    Vectorized<uint##bit##_t> operator>(                                      \
        const Vectorized<uint##bit##_t>& other) const {                       \
      return Vectorized<value_type>(vcgtq_u##bit(values, other.values));      \
    }                                                                         \
    Vectorized<uint##bit##_t> operator>=(                                     \
        const Vectorized<uint##bit##_t>& other) const {                       \
      return Vectorized<value_type>(vcgeq_u##bit(values, other.values));      \
    }                                                                         \
    Vectorized<uint##bit##_t> eq(                                             \
        const Vectorized<uint##bit##_t>& other) const;                        \
    Vectorized<uint##bit##_t> ne(                                             \
        const Vectorized<uint##bit##_t>& other) const;                        \
    Vectorized<uint##bit##_t> gt(                                             \
        const Vectorized<uint##bit##_t>& other) const;                        \
    Vectorized<uint##bit##_t> ge(                                             \
        const Vectorized<uint##bit##_t>& other) const;                        \
    Vectorized<uint##bit##_t> lt(                                             \
        const Vectorized<uint##bit##_t>& other) const;                        \
    Vectorized<uint##bit##_t> le(                                             \
        const Vectorized<uint##bit##_t>& other) const;                        \
  };                                                                          \
  template <>                                                                 \
  Vectorized<uint##bit##_t> inline operator+(                                 \
      const Vectorized<uint##bit##_t>& a,                                     \
      const Vectorized<uint##bit##_t>& b) {                                   \
    return vaddq_u##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<uint##bit##_t> inline operator-(                                 \
      const Vectorized<uint##bit##_t>& a,                                     \
      const Vectorized<uint##bit##_t>& b) {                                   \
    return vsubq_u##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<uint##bit##_t> inline operator&(                                 \
      const Vectorized<uint##bit##_t>& a,                                     \
      const Vectorized<uint##bit##_t>& b) {                                   \
    return vandq_u##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<uint##bit##_t> inline operator|(                                 \
      const Vectorized<uint##bit##_t>& a,                                     \
      const Vectorized<uint##bit##_t>& b) {                                   \
    return vorrq_u##bit(a, b);                                                \
  }                                                                           \
  template <>                                                                 \
  Vectorized<uint##bit##_t> inline operator^(                                 \
      const Vectorized<uint##bit##_t>& a,                                     \
      const Vectorized<uint##bit##_t>& b) {                                   \
    return veorq_u##bit(a, b);                                                \
  }                                                                           \
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq(             \
      const Vectorized<uint##bit##_t>& other) const {                         \
    return (*this == other) & Vectorized<uint##bit##_t>(1);                   \
  }                                                                           \
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne(             \
      const Vectorized<uint##bit##_t>& other) const {                         \
    return (*this != other) & Vectorized<uint##bit##_t>(1);                   \
  }                                                                           \
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt(             \
      const Vectorized<uint##bit##_t>& other) const {                         \
    return (*this > other) & Vectorized<uint##bit##_t>(1);                    \
  }                                                                           \
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge(             \
      const Vectorized<uint##bit##_t>& other) const {                         \
    return (*this >= other) & Vectorized<uint##bit##_t>(1);                   \
  }                                                                           \
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt(             \
      const Vectorized<uint##bit##_t>& other) const {                         \
    return (*this < other) & Vectorized<uint##bit##_t>(1);                    \
  }                                                                           \
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le(             \
      const Vectorized<uint##bit##_t>& other) const {                         \
    return (*this <= other) & Vectorized<uint##bit##_t>(1);                   \
  }

VEC_UINT_NEON_TEMPLATE(16, 8)

inline uint8_t Vectorized<uint8_t>::reduce_max() const {
  return vmaxvq_u8(values);
}

template <>
Vectorized<uint8_t> inline operator*(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& b) {
  return vmulq_u8(a, b);
}

template <>
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
  return vmvnq_u8(a);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
    const Vectorized<uint8_t>& other) const {
  return ~(*this == other);
}

template <>
Vectorized<uint8_t> inline minimum(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& b) {
  return vminq_u8(a, b);
}

template <>
Vectorized<uint8_t> inline maximum(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& b) {
  return vmaxq_u8(a, b);
}

template <uint64_t mask>
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& b) {
  // Build an array of flags: an element is all 1s if the corresponding bit
  // in 'mask' is set, 0 otherwise.
  uint8x16_t maskArray = {
      (mask & 1LL) ? 0xFF : 0,
      (mask & 2LL) ? 0xFF : 0,
      (mask & 4LL) ? 0xFF : 0,
      (mask & 8LL) ? 0xFF : 0,
      (mask & 16LL) ? 0xFF : 0,
      (mask & 32LL) ? 0xFF : 0,
      (mask & 64LL) ? 0xFF : 0,
      (mask & 128LL) ? 0xFF : 0,
      (mask & 256LL) ? 0xFF : 0,
      (mask & 512LL) ? 0xFF : 0,
      (mask & 1024LL) ? 0xFF : 0,
      (mask & 2048LL) ? 0xFF : 0,
      (mask & 4096LL) ? 0xFF : 0,
      (mask & 8192LL) ? 0xFF : 0,
      (mask & 16384LL) ? 0xFF : 0,
      (mask & 32768LL) ? 0xFF : 0};
  // Use BSL to select elements from b where the mask is 1, else from a
  return vbslq_u8(maskArray, b.values, a.values);
}

#define VEC_UINT_NEON_OPS(vl, bit)                                             \
  inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) {            \
    values = vdupq_n_u##bit(val);                                              \
  }                                                                            \
  inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu(           \
      const void* ptr, uint64_t count) {                                       \
    if (count == size()) {                                                     \
      return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr));        \
    } else {                                                                   \
      __at_align__ uint##bit##_t tmp_values[size()];                           \
      for (const auto i : c10::irange(size())) {                               \
        tmp_values[i] = 0;                                                     \
      }                                                                        \
      std::memcpy(                                                             \
          tmp_values,                                                          \
          reinterpret_cast<const uint##bit##_t*>(ptr),                         \
          count * sizeof(uint##bit##_t));                                      \
      return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
    }                                                                          \
  }                                                                            \
  inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count)      \
      const {                                                                  \
    if (count == size()) {                                                     \
      vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values);             \
    } else {                                                                   \
      uint##bit##_t tmp_values[size()];                                        \
      vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values);      \
      std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t));             \
    }                                                                          \
  }

VEC_UINT_NEON_OPS(16, 8)

template <typename step_t>
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
    uint8_t base,
    step_t step) {
  const Vectorized<uint8_t> base_vec(base);
  const Vectorized<uint8_t> step_vec(step);
  const uint8x16_t step_sizes = {
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
  return vmlaq_u8(base_vec, step_sizes, step_vec);
}

template <>
Vectorized<uint8_t> inline operator>>(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& b) {
  uint8x16_t x = a;
  uint8x16_t bound = vdupq_n_u8(8);
  uint8x16_t z = vminq_u8(b, bound);
  return x >> z;
}

template <>
Vectorized<uint8_t> inline operator<<(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& b) {
  uint8x16_t bound = vdupq_n_u8(8);
  uint8x16_t z = vminq_u8(b, bound);
  return vshlq_u8(a, vreinterpretq_s8_u8(z));
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& b,
    uint64_t count) {
  if (count == 0) {
    return a;
  } else if (count >= 16) {
    return b;
  } else {
    // Build an array of flags: an element is all 1s if its lane index is
    // less than 'count', 0 otherwise.
    uint8x16_t maskArray = {
        static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
        static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
        0};

    // Use BSL to select elements from b where the mask is 1, else from a
    return vbslq_u8(maskArray, b.values, a.values);
  }
}

template <>
Vectorized<uint8_t> inline operator/(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& b) {
  uint8x16_t x = a;
  uint8x16_t y = b;
  return x / y;
}

template <>
Vectorized<uint8_t> inline clamp(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& min,
    const Vectorized<uint8_t>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<uint8_t> inline clamp_max(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& max) {
  return minimum(max, a);
}

template <>
Vectorized<uint8_t> inline clamp_min(
    const Vectorized<uint8_t>& a,
    const Vectorized<uint8_t>& min) {
  return maximum(min, a);
}

} // namespace CPU_CAPABILITY
} // namespace at::vec
@ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum(
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
    at::vec::Vectorized<int8_t> src) {
  auto s8x8 = vget_low_s8(src);
  auto s8x8 = vld1_s8(src.operator const int8_t*());
  auto s16x8 = vmovl_s8(s8x8);

  auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(

std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
    at::vec::Vectorized<uint8_t> src) {
  auto u8x8 = vget_low_u8(src);
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
  auto u16x8 = vmovl_u8(u8x8);
  auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
@ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(

Vectorized<float> inline convert_int8_half_register_to_float(
    at::vec::Vectorized<int8_t> src) {
  auto s8x8 = vget_low_s8(src);
  auto s8x8 = vld1_s8(src.operator const int8_t*());
  auto s16x8 = vmovl_s8(s8x8);

  auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(

Vectorized<float> inline convert_int8_half_register_to_float(
    at::vec::Vectorized<uint8_t> src) {
  auto u8x8 = vget_low_u8(src);
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
  auto u16x8 = vmovl_u8(u8x8);
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));

@ -16,8 +16,6 @@
#include <c10/util/irange.h>
#include <c10/core/ScalarType.h>

#include <ATen/cuda/detail/BLASConstants.h>

#ifdef USE_ROCM
#include <c10/cuda/CUDAStream.h>
#include <hipblaslt/hipblaslt-ext.hpp>
@ -1956,15 +1954,13 @@ void scaled_gemm(
    const void *result_scale_ptr,
    int64_t result_ld,
    ScalarType result_dtype,
    bool use_fast_accum,
    const std::optional<Tensor>& alpha) {
    bool use_fast_accum) {
  // Note: see `cublasCommonArgs` for various non-intuitive manipulations
  // of input arguments to this function.
  const auto computeType = CUBLAS_COMPUTE_32F;
  const auto scaleType = CUDA_R_32F;
  // Note: alpha_val may change later depending on user-passed argument
  float alpha_val = 1.0;
  float beta_val = 0.0;
  const float alpha_val = 1.0;
  const float beta_val = 0.0;
  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
@ -2035,33 +2031,6 @@ void scaled_gemm(
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
  }

  // Handle user-passed alpha
  float *alpha_ptr = &alpha_val;
  float *beta_ptr = &beta_val;

  if (alpha.has_value()) {
    auto& a = alpha.value();

    // if device-tensor
    if (a.is_cuda()) {
      // NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be
      //       valid & correct until the cublas call finishes (not just scheduled, as host-side values are). Thus
      //       we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically
      //       managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory
      //       for beta respectively.
      float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr();
      at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat));
      user_alpha.copy_(a);
      // Tell cublasLt we're using device-side pointers for alpha/beta
      auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
      computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode);
      alpha_ptr = user_alpha.data_ptr<float>();
      beta_ptr = at::cuda::detail::get_cublas_device_zero();
    } else {
      alpha_val = a.item<float>();
    }
  }
    // For other data types, use the get_scale_mode function based on scaling type
    // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
    // but we must invoke get_scale_mode anyways to trigger the version checks.
@ -2079,7 +2048,6 @@ void scaled_gemm(
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();

  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
@ -2120,10 +2088,10 @@ void scaled_gemm(
        auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
                ltHandle,
                computeDesc.descriptor(),
                alpha_ptr,
                &alpha_val,
                Adesc.descriptor(),
                Bdesc.descriptor(),
                beta_ptr,
                &beta_val,
                Cdesc.descriptor(),
                Ddesc.descriptor(),
                all_algos[i].algo,
@ -2142,14 +2110,17 @@ void scaled_gemm(
  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      alpha_ptr,
      &alpha_val,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      beta_ptr,
      // NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either
      &beta_val,
#ifdef USE_ROCM
      result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
#else
      nullptr,
#endif // ifdef USE_ROCM
      Cdesc.descriptor(),
      result_ptr,
      Ddesc.descriptor(),

@ -161,8 +161,7 @@ void scaled_gemm(
    const void* result_scale_ptr,
    int64_t result_ld,
    ScalarType result_dtype,
    bool use_fast_accum,
    const std::optional<Tensor>& alpha);
    bool use_fast_accum);

#define CUDABLAS_BGEMM_ARGTYPES(Dtype)  CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)

@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
 */
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
  // The RNG state comprises the seed, and an offset used for Philox.
  constexpr size_t seed_size = sizeof(uint64_t);
  constexpr size_t offset_size = sizeof(int64_t);
  constexpr size_t total_size = seed_size + offset_size;
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(int64_t);
  static const size_t total_size = seed_size + offset_size;

  auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
  auto rng_state = state_tensor.data_ptr<uint8_t>();
@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
 * and size of the internal state.
 */
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  constexpr size_t seed_size = sizeof(uint64_t);
  constexpr size_t offset_size = sizeof(int64_t);
  constexpr size_t total_size = seed_size + offset_size;
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(int64_t);
  static const size_t total_size = seed_size + offset_size;

  detail::check_rng_state(new_state);

@ -183,6 +183,11 @@ struct CUDACachingHostAllocatorImpl
    return true;
  }

  bool pinned_use_background_threads() override {
    return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
        pinned_use_background_threads();
  }

  EventPool::Event create_event_internal(DeviceIndex idx) {
    // Leak the event pool to avoid shutdown issue.
    static auto* event_pool = new EventPool();

@ -177,6 +177,7 @@ inline void segmented_sort_pairs(
  }
}

#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
  KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
@ -192,6 +193,7 @@ inline void unique_by_key(
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
    keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif

namespace impl {

@ -577,6 +579,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
#endif
}

#if CUB_SUPPORTS_SCAN_BY_KEY()

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
@ -604,6 +607,7 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT
#endif
}

#endif

template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,

@ -28,6 +28,22 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif

// cub support for UniqueByKey is added to cub 1.16 in:
// https://github.com/NVIDIA/cub/pull/405
#if CUB_VERSION >= 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#endif

// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif

// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500

@ -1,54 +0,0 @@
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/cuda/Exceptions.h>

#include <mutex>

namespace at {
namespace cuda {
namespace detail {

__device__ __constant__ float cublas_one_device;
__device__ __constant__ float cublas_zero_device;

float *get_cublas_device_one() {
  static c10::once_flag init_flag;

  c10::call_once(init_flag, []() {
    const float one = 1.f;
    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
  });

  float *ptr;
  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
  return ptr;
}

float *get_cublas_device_zero() {
  static c10::once_flag init_flag;

  c10::call_once(init_flag, []() {
    const float zero = 0.f;
    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
  });

  float *ptr;
  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
  return ptr;
}

float *get_user_alpha_ptr() {
  static float *alpha_ptr;

  static c10::once_flag init_flag;

  c10::call_once(init_flag, []() {
    AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
  });

  return alpha_ptr;
}

} // namespace detail
} // namespace cuda
} // namespace at
@ -1,11 +0,0 @@
#pragma once

#include <ATen/core/TensorBase.h>

namespace at::cuda::detail {

float *get_cublas_device_one();
float *get_cublas_device_zero();
float *get_user_alpha_ptr();

} // namespace at::cuda::detail
@ -13,7 +13,6 @@
#include <c10/core/ScalarType.h>

#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@ -151,7 +150,6 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
      BLASType = "unknown";
  }
  return BLASType;

}

// Similar to Compute Type in GemmRocblas.h
@ -246,25 +244,33 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio

namespace detail {

static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {

  if (!config.enabled) {
    return true; // skip when disabled
  }

static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
  // comparison done as 1D tensor
  at::Tensor ref = at::from_blob(c,       {size}, options);
  at::Tensor oth = at::from_blob(other_c, {size}, options);
  at::Tensor ref_float = ref.to(at::kFloat);
  at::Tensor oth_float = oth.to(at::kFloat);

  const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
  if (ok) {
    TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
  } else {
    TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  double last_succeed_atol = 1;
  double last_succeed_rtol = 1;
  for (auto& atol : atols) {
    for (auto& rtol : rtols) {
      if (at::allclose(ref_float, oth_float, rtol, atol)) {
        last_succeed_atol = atol;
        last_succeed_rtol = rtol;
      }
    }
  }
  return ok;
  if (last_succeed_atol == 1) {
    return false;
  }
  else {
    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
  }

  return true;
}

}
@ -349,10 +355,8 @@ struct GemmParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
  }

  char transa{};
@ -445,10 +449,8 @@ struct GemmAndBiasParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
  }

  char transa{};
@ -544,10 +546,8 @@ struct GemmStridedBatchedParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
  }

  char transa{};
@ -663,9 +663,7 @@ struct ScaledGemmParams : OpParams {
  }

  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
  }

  char transa{};

@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
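
A short sketch of the new environment-variable format described above (the tolerance values are illustrative):

```python
import os

# Enable TunableOp's numerical check with atol=1e-4 and rtol=1e-5.
# The accepted format is '<atol>_<rtol>'; "0" disables the check, and a
# malformed value (no underscore, non-positive tolerance) is rejected.
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-4_1e-5"
```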
@ -173,9 +173,10 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| write_file_on_exit(val: bool) -> None | Default is True. |
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | Read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None | Read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
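
A usage sketch of the Python API rows above; the filename is illustrative, and `enable` is taken from the wider `torch.cuda.tunable` surface rather than this excerpt:

```python
import torch.cuda.tunable as tunable

# Programmatic equivalent of PYTORCH_TUNABLEOP_NUMERICAL_CHECK="1e-4_1e-5":
tunable.set_numerical_check_tolerances(True, 1e-4, 1e-5)

tunable.enable(True)
tunable.set_filename("my_tunableop_results.csv")  # illustrative filename
# ... run GEMM workloads; results stream to the file as they are found ...
tunable.write_file_on_exit(True)
```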
@ -107,30 +107,14 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
}

void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
  bool is_new = false;
  ResultEntry inserted = ResultEntry::Null();
  std::scoped_lock l{lock_};

  // ---- mutate maps under results lock ----
  {
    std::scoped_lock l{lock_};
    auto& km = results_[op_signature];  // creates if missing
    is_new = (km.find(params_signature) == km.end());
    AddImpl(op_signature, params_signature, std::move(best), km);
    if (is_new) {
      inserted = km.at(params_signature);  // snapshot for I/O after unlocking
    }
  }
  if (!is_new) return;  // only write once per unique (op, params)

  TuningContext* ctx = getTuningContext();
  if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
    InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());

    if (is_new && realtime_out_ && realtime_out_->good()) {
      AppendResultLine(op_signature, params_signature, inserted);
    }
  auto it = results_.find(op_signature);
  if (it == results_.end()) {
    it = results_.insert({op_signature, {}}).first;
  }

  AddImpl(op_signature, params_signature, std::move(best), it->second);
}
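
The rewritten `Add` above mutates the results map under `lock_`, snapshots the inserted entry, and performs file I/O only after the lock is released. A minimal Python sketch of that pattern, illustrative only and not the PyTorch implementation:

```python
import threading

class ResultsStore:
    """Sketch of 'mutate under lock, do I/O after unlocking' as in Add above."""

    def __init__(self, path):
        self._lock = threading.Lock()
        self._results = {}   # op_signature -> {params_signature: entry}
        self._path = path

    def add(self, op_sig, params_sig, entry):
        with self._lock:                       # mutate the map under the lock
            km = self._results.setdefault(op_sig, {})
            is_new = params_sig not in km
            km[params_sig] = entry
            snapshot = km[params_sig]          # snapshot for I/O after unlocking
        if not is_new:
            return                             # only write once per unique pair
        with open(self._path, "a") as f:       # I/O happens outside the lock
            f.write(f"{op_sig},{params_sig},{snapshot}\n")
```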
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
@ -166,77 +150,6 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
  }
}

void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
  std::scoped_lock fl{realtime_file_mutex_};

  if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
    return;
  }

  if (realtime_out_ && realtime_filename_ != filename) {
    realtime_out_->flush();
    realtime_out_->close();
    realtime_out_.reset();
    validators_written_ = false;
  }

  bool file_exists = false;
  bool file_empty = true;

  {
    std::ifstream check_file(filename);
    if (check_file.good()) {
      file_exists = true;
      file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
    }
  }

  realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);

  if (!realtime_out_->good()) {
    TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
    realtime_out_.reset();
    return;
  }

  if(!file_exists || file_empty) {
    for(const auto& [key, val] : validators) {
      (*realtime_out_) << "Validator," << key << "," << val << std::endl;
      realtime_out_->flush();
    }
    validators_written_ = true;

    TUNABLE_LOG2("Wrote validators to realtime output file");
  }

  realtime_filename_ = filename;
}
void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
  std::scoped_lock fl{realtime_file_mutex_};

  if(!realtime_out_ || !realtime_out_->good()) {
    return;
  }

  (*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
  realtime_out_->flush(); //ensure immediate write to disk

  TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
}

void TuningResultsManager::CloseRealtimeAppend() {
  std::scoped_lock fl{realtime_file_mutex_};

  if(realtime_out_) {
    realtime_out_->flush();
    realtime_out_->close();
    realtime_out_.reset();
    TUNABLE_LOG2("Closed realtime output file");
  }
}

void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
  std::scoped_lock l{lock_};

@ -483,6 +396,7 @@ TuningContext::TuningContext() :
    tuning_enable_{true},
    record_untuned_enable_{false},
    manager_initialized_{false},
    write_file_on_exit_{true},
    numerics_check_enable_{false},
    max_tuning_duration_ms_{30},
    max_tuning_iterations_{100},
@ -503,8 +417,20 @@ TuningContext::~TuningContext() {
    // but doesn't do any computation itself.
    return;
  }
  TUNABLE_LOG1("Closing File");
  GetTuningResultsManager().CloseRealtimeAppend(); // since we do instant logging by default now
  auto filename = GetFilename();
  if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
    if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
      if (results_count_from_input_file_ > 0) {
        TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
      }
      else {
        TUNABLE_LOG1("writing file ", filename);
      }
      if (!WriteFile(filename)) {
        TUNABLE_LOG1("failed to write file ", filename);
      }
    }
  }

  if (untuned_file_.good()) {
    untuned_file_.close();
@ -585,54 +511,20 @@ std::ofstream& TuningContext::GetUntunedFile(){
  return untuned_file_;
}

void TuningContext::WriteFileOnExit(bool value) {
  write_file_on_exit_ = value;
}

void TuningContext::EnableNumericsCheck(bool value) {
  numerics_check_enable_ = value;
}

NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
  const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");

  if (!env_opt.has_value()) {
    return numerics_cfg_;
  }

  const std::string& env = env_opt.value();

  if (env == "0") {
    return NumericalCheckConfig(false, 1e-5, 1e-5);
  }

  const size_t underscore = env.find('_');

  TORCH_CHECK(
      underscore != std::string::npos,
      "Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
      "Expected 'atol_rtol', got: ",
      env);

  double atol = 0.0;
  double rtol = 0.0;

  try {
    atol = std::stod(env.substr(0, underscore));
    rtol = std::stod(env.substr(underscore + 1));
  } catch (const std::exception& e) {
    TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
  }

  TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
  return NumericalCheckConfig(true, atol, rtol);
}
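
A hedged Python mirror of the parsing logic in `GetNumericalCheckConfig` above, handy for sanity-checking a value before exporting it (a sketch, not the PyTorch code):

```python
def parse_numerical_check(env: str):
    """Parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK the way the C++ above does.

    Returns (enabled, atol, rtol).
    """
    if env == "0":
        return (False, 1e-5, 1e-5)
    atol_s, sep, rtol_s = env.partition("_")   # split on the first underscore
    if not sep:
        raise ValueError(f"Expected 'atol_rtol', got: {env}")
    atol, rtol = float(atol_s), float(rtol_s)
    if atol <= 0.0 or rtol <= 0.0:
        raise ValueError(f"Tolerance values must be positive. atol={atol}, rtol={rtol}")
    return (True, atol, rtol)

assert parse_numerical_check("1e-4_1e-5") == (True, 1e-4, 1e-5)
```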
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
  TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
  numerics_cfg_ = {enabled, atol, rtol};
}

bool TuningContext::IsNumericsCheckEnabled() const {
  const auto cfg = GetNumericalCheckConfig();
  return cfg.enabled || numerics_check_enable_;
  const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
  if (env == "1") {
    return true;
  }
  return numerics_check_enable_;
}

void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
@ -742,6 +634,11 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
    auto filename = GetFilename();
    if (!filename.empty() && !IsRecordUntunedEnabled()) {
      ReadFile(filename);
      // attempt immediately to open file for writing to catch errors early
      std::ofstream file(filename, std::ios::out | std::ios::app);
      if (!file.good()) {
        TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
      }
    }
  });
  return manager_;
@ -847,6 +744,27 @@ bool TuningContext::ReadFile(const std::string& filename_) {
  return true;
}

bool TuningContext::WriteFile(const std::string& filename_) {
  std::string filename = filename_.empty() ? GetFilename() : filename_;
  std::ofstream file(filename, std::ios::out | std::ios::trunc);
  if (!file.good()) {
    TUNABLE_LOG1("error opening tuning results file for writing ", filename);
    return false;
  }
  auto validators = GetTuningResultsValidator().GetAllValidators();
  for (const auto& [key, val] : validators) {
    file << "Validator," << key << "," << val << std::endl;
  }
  auto results = GetTuningResultsManager().Dump();
  for (const auto& [op_sig, kernelmap] : results) {
    for (const auto& [param_sig, result] : kernelmap) {
      file << op_sig << "," << param_sig << "," << result << std::endl;
    }
  }
  file.close();
  return true;
}

namespace {

struct MaybeDelete {
@ -103,24 +103,10 @@ class TORCH_CUDA_CPP_API TuningResultsManager {

    void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
      const std::string& params_signature, const std::string& blas_signature);

    void InitRealtimeAppend(
        const std::string& filename,
        const std::unordered_map<std::string, std::string>& validators);

    void AppendResultLine(const std::string& op_sig,
                         const std::string& param_sig,
                         const ResultEntry& result);

    void CloseRealtimeAppend();  // For clean shutdown
  private:
    std::mutex lock_;
    std::mutex realtime_file_mutex_;
    std::unique_ptr<std::ofstream> realtime_out_;
    std::string realtime_filename_;
    ResultsMap results_;
    UntunedMap untuned_results_;
    bool validators_written_ = false;

};

@ -148,16 +134,6 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
    GetValidateFuncs validators_;
};

struct NumericalCheckConfig {
  bool   enabled{false};
  double atol{1e-5};
  double rtol{1e-5};

  NumericalCheckConfig() = default;
  NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
};

class TORCH_CUDA_CPP_API TuningContext {
  public:
    TuningContext();
@ -179,8 +155,6 @@ class TORCH_CUDA_CPP_API TuningContext {

    void EnableNumericsCheck(bool value);
    bool IsNumericsCheckEnabled() const;
    void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
    NumericalCheckConfig GetNumericalCheckConfig() const;

    void SetMaxTuningDurationMs(int max_duration_ms);
    int GetMaxTuningDurationMs() const;
@ -211,7 +185,10 @@ class TORCH_CUDA_CPP_API TuningContext {
    void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
    std::string GetFilename() const;

    void WriteFileOnExit(bool value);

    bool ReadFile(const std::string& filename={});
    bool WriteFile(const std::string& filename={});

    template<class... Types>
    void Log(int level, Types... args) {
@ -230,6 +207,7 @@ class TORCH_CUDA_CPP_API TuningContext {
    bool tuning_enable_;
    bool record_untuned_enable_;
    bool manager_initialized_;
    bool write_file_on_exit_;
    bool numerics_check_enable_;
    int max_tuning_duration_ms_;
    int max_tuning_iterations_;
@ -244,8 +222,6 @@ class TORCH_CUDA_CPP_API TuningContext {
    std::ofstream untuned_file_;
    size_t results_count_from_input_file_;
    bool is_shutting_down_;

    NumericalCheckConfig numerics_cfg_{};
};

TORCH_CUDA_CPP_API TuningContext* getTuningContext();
@ -109,8 +109,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
          params->c_scale_ptr,
          params->ldc,
          params->c_dtype,
          params->use_fast_accum,
          std::nullopt /* alpha */);
          params->use_fast_accum);
      return OK;
    }
};

@ -267,10 +267,27 @@ class TunableOp {
      for (size_t i = 0; i < op_names_.size(); i++) {
        auto* candidate = ops_[op_names_[i]].get(); // borrow pointer

        auto status = candidate->Call(reusable_params[0]);
        if (status != OK) {
          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
          continue;
        if (do_numerics_check) {
          ParamsT* numerical_params = params->DeepCopy(false);
          auto status = candidate->Call(numerical_params);
          if (status != OK) {
            numerical_params->Delete();
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
          status = reference_params->NumericalCheck(numerical_params);
          numerical_params->Delete();
          if (status != OK) {
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }
        else {
          auto status = candidate->Call(reusable_params[0]);
          if (status != OK) {
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }

        // collect a small profile
@ -293,22 +310,6 @@ class TunableOp {
          continue;
        }

        if (do_numerics_check) {
          ParamsT* numerical_params = params->DeepCopy(false);
          auto status = candidate->Call(numerical_params);
          if (status != OK) {
            numerical_params->Delete();
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
          status = reference_params->NumericalCheck(numerical_params);
          numerical_params->Delete();
          if (status != OK) {
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }

        // for warmup does user set max duration, max iters, or both?
        // warmup is skipped by default, i.e. warmup_iter = 0
        // warmup will be set to the non-zero value of max_warmup_duration
@ -213,22 +213,40 @@ static cudnn_grid_sample_backward_batch_rule(
  return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
}

// uses functional formulation for one_hot under vmap to be compatible with
// fakeTensor/dynamic shapes and compiled functorch transforms.
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
// but requires explicit positive num_classes under vmap to avoid
// data-dependent output shapes.
// TODO: replace with targetable functionalization
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
    TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
    auto shape = self.sym_sizes().vec();

    // empty tensor could be converted to one hot representation,
    // but shape inference is not possible.
    if (self.sym_numel() == 0) {
        if (num_classes <= 0) {
            TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
        } else {
            shape.emplace_back(num_classes);
            return at::empty_symint(shape, self.options());
        }
    }

    // disallow implicit inference under vmap; this would be data-dependent
    // and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
    TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
        "provide an explicit positive num_classes argument.");

    const auto options = self.options();
    at::Tensor index = at::arange(num_classes, options);
    return at::eq(self.unsqueeze(-1), index).to(at::kLong);
    // Disabling all of the following checks. This is OK because scatter has checks too.
    // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
    // // non-empty tensor
    // if (self.device().type() != at::kCUDA) {
    //   //for cuda, rely on device assert thrown by scatter
    //   TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
    // }
    // if (self.device().type() != at::kCUDA) {
    //   //rely on device asserts from scatter to avoid sync here
    //   TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
    // }

    shape.emplace_back(num_classes);
    Tensor ret = at::zeros_symint(shape, self.options());
    return ret.scatter(-1, self.unsqueeze(-1), 1);
}
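
For valid indices, the eq-based formulation introduced above produces the same result as the scatter-based one; a small Python check (shapes and values are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0, 2, 1, 2])
num_classes = 3

# Functional formulation used under vmap: compare against a class-index ramp.
index = torch.arange(num_classes)
eq_based = (x.unsqueeze(-1) == index).to(torch.long)

assert torch.equal(eq_based, F.one_hot(x, num_classes))
```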
template <typename A, A a, typename C>

@ -160,10 +160,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::CUDA,
  DispatchKey::CPU,
  DispatchKey::PrivateUse1,
  DispatchKey::SparseCPU,
  DispatchKey::SparseCUDA,
  DispatchKey::SparseCsrCPU,
  DispatchKey::SparseCsrCUDA,
});

inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {

@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (

namespace at::native {

static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;

DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);
@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
  auto constexpr intmax = std::numeric_limits<int>::max();
  auto intmax = std::numeric_limits<int>::max();
  return n <= intmax && incx <= intmax;
}

@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
    int64_t incx,
    [[maybe_unused]] float beta,
    int64_t incy) {
  auto constexpr intmax = std::numeric_limits<int>::max();
  auto intmax = std::numeric_limits<int>::max();
  return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
         (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

@ -658,7 +658,6 @@ static void check_shape_forward(const at::Tensor& input,
  TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
  TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
  TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
  TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups);

  TORCH_CHECK(weight_dim == k,
           "Expected ", weight_dim, "-dimensional input for ", weight_dim,
@ -1,6 +1,5 @@
#pragma once

#include <array>
#include <ATen/native/Math.h>
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>
@ -128,7 +127,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor

template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
  constexpr static scalar_t kTailValues[] = {
  const static scalar_t kTailValues[] = {
    0.0810614667953272,
    0.0413406959554092,
    0.0276779256849983,
@ -140,7 +139,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
    0.00925546218271273,
    0.00833056343336287
  };
  if (k < std::size(kTailValues)) {
  if (k <= 9) {
    return kTailValues[static_cast<size_t>(k)];
  }
  scalar_t kp1sq = (k + 1) * (k + 1);
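
For context, the tabulated constants appear to be Stirling-series remainders: entry $k$ matches $\delta(k+1)$, where

$$\delta(n) = \log n! - \log\!\left(\sqrt{2\pi n}\,(n/e)^n\right) \approx \frac{1}{12n} - \frac{1}{360 n^3} + \frac{1}{1260 n^5},$$

which is consistent with the `kp1sq = (k + 1) * (k + 1)` fallback that follows (e.g. $\delta(1) = 1 - \tfrac{1}{2}\log 2\pi \approx 0.0810614668$, the first table entry). The `constexpr`/`std::size` change above only replaces the hard-coded table length with the array's actual size.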
@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
    try {
      mkldnn_matmul_i8i8i32(self, mat2, result);
      dispatched = true;
    } catch ([[maybe_unused]] const std::exception& e) {
    } catch (const std::exception& e) {
      TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
    }
  }

@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
template <typename scalar_t>
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
  // lanczos approximation
  static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
  static const scalar_t lanczos_sum_expg_scaled_num[13] = {
    0.006061842346248906525783753964555936883222,
    0.5098416655656676188125178644804694509993,
    19.51992788247617482847860966235652136208,
@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
    103794043.1163445451906271053616070238554,
    56906521.91347156388090791033559122686859
  };
  static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
  static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
    1.,
    66.,
    1925.,
@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
template <typename scalar_t>
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
  // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
  static constexpr scalar_t d[25][25] =
  static const scalar_t d[25][25] =
    {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
      1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
      3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,
@ -62,7 +62,7 @@
#include <utility>
#include <vector>

static constexpr int MIOPEN_DIM_MAX = 5;
static const int MIOPEN_DIM_MAX = 5;

namespace at::meta {

@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
        }
    }

    auto shape = self.sym_sizes().vec();
    auto shape = self.sizes().vec();

    // empty tensor could be converted to one hot representation,
    // but shape inference is not possible.
    if (self.sym_numel() == 0) {
    if (self.numel() == 0) {
        if (num_classes <= 0) {
            TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
        } else {
            shape.emplace_back(num_classes);
            return at::empty_symint(shape, self.options());
            shape.push_back(num_classes);
            return at::empty(shape, self.options());
        }
    }

@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
        }
    }

    shape.emplace_back(num_classes);
    Tensor ret = at::zeros_symint(shape, self.options());
    shape.push_back(num_classes);
    Tensor ret = at::zeros(shape, self.options());
    ret.scatter_(-1, self.unsqueeze(-1), 1);
    return ret;
}
@ -1906,9 +1906,11 @@ Tensor& index_fill_(
        "This also applies to advanced indexing e.g. tensor[mask] = scalar");
  }

  TORCH_CHECK(
      self.is_complex() || !source.isComplex(),
      "index_fill_(): Converting complex Scalar to non-complex type is not supported");
  if (!self.is_complex() && source.isComplex()) {
    TORCH_CHECK(
        false,
        "index_fill_(): Converting complex Scalar to non-complex type is not supported");
  }

  // Handle the case when `self` is 0-dim
  Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;

@ -77,7 +77,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
  // next broadcast all index tensors together
  try {
    indices = expand_outplace(indices);
  } catch (std::exception&) {
  } catch (std::exception& e) {
    TORCH_CHECK_INDEX(
        false,
        "shape mismatch: indexing tensors could not be broadcast together"

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
  } else if (dtype == ScalarType::Half) {
    [&]() {
      using scalar_t =
          c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
          decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
      const auto exp = exp_scalar.to<scalar_t>();
      using Vec = Vectorized<scalar_t>;
      cpu_kernel_vec(iter,

@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
  // We keep this structure for BC and consider as deprecated.
  // See HelperInterpNearestExact as replacement

  static constexpr int interp_size = 1;
  static const int interp_size = 1;

  static inline void init_indices_weights(
    at::ScalarType output_type,
@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {

struct HelperInterpLinear : public HelperInterpBase {

  static constexpr int interp_size = 2;
  static const int interp_size = 2;

  // Compute indices and weights for each interpolated dimension
  // indices_weights = {
@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {

struct HelperInterpCubic : public HelperInterpBase {

  static constexpr int interp_size = 4;
  static const int interp_size = 4;

  // Compute indices and weights for each interpolated dimension
  // indices_weights = {

File diff suppressed because it is too large
@ -856,13 +856,9 @@ struct type_specialized_kernel_launcher {
      out_calc_t output_offset_calculator,
      loader_t loader,
      storer_t storer) {
    constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
    constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
    constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
    if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
      using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
      using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
      using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
    if (ret_t == rt_binary_specializations[arg_index][0] &&
        arg0_t == rt_binary_specializations[arg_index][1] &&
        arg1_t == rt_binary_specializations[arg_index][2])
      launch_vectorized_templated_kernel<
          func_t,
          array_t,
@ -870,9 +866,12 @@ struct type_specialized_kernel_launcher {
          out_calc_t,
          loader_t,
          storer_t,
          cret_t,
          carg0_t,
          carg1_t>(
          decltype(c10::impl::ScalarTypeToCPPType<
                   rt_binary_specializations[arg_index][0]>::t),
          decltype(c10::impl::ScalarTypeToCPPType<
                   rt_binary_specializations[arg_index][1]>::t),
          decltype(c10::impl::ScalarTypeToCPPType<
                   rt_binary_specializations[arg_index][2]>::t)>(
          numel,
          f,
          data,
@ -880,7 +879,6 @@ struct type_specialized_kernel_launcher {
          output_offset_calculator,
          loader,
          storer);
    }
  }
};
@ -38,41 +38,12 @@ __device__ inline int min(int a, int b) {
#define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched
#endif

template <typename index_t>
static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) {
  const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);
  return (size + pad < kernel_extent) ? index_t(0) : (size + pad - kernel_extent) / stride + 1;
static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
  return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
}

template <typename index_t>
static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) {
  return std::min((size + pad) / stride + 1, pooled_size);
}

static inline bool can_use_int32_nhwc(
    int64_t nbatch, int64_t channels,
    int64_t height, int64_t width,
    int64_t pooled_height, int64_t pooled_width,
    int64_t in_stride_n, int64_t in_stride_c,
    int64_t in_stride_h, int64_t in_stride_w)
{
  constexpr int64_t int_max = std::numeric_limits<int>::max();

  int64_t max_intra_batch =
      (height ? (height - 1) * in_stride_h : 0) +
      (width ? (width - 1) * in_stride_w : 0) +
      (channels? (channels - 1) * in_stride_c : 0);

  int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch;

  if (max_input_offset > int_max) return false;

  int64_t out_batch_stride = pooled_height * pooled_width * channels;
  if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false;

  if (height * width > int_max) return false;

  return true;
static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) {
  return min((size + pad) / stride + 1, pooled_size);
}
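
A back-of-the-envelope check of the overflow guard above (sizes are illustrative): for a contiguous NHWC tensor the largest linear input offset is (N−1)·sn + (H−1)·sh + (W−1)·sw + (C−1)·sc, and `can_use_int32_nhwc` falls back to 64-bit indexing whenever that, the output batch offset, or H·W exceeds `INT_MAX`:

```python
INT_MAX = 2**31 - 1

def max_nhwc_offset(n, h, w, c):
    # Contiguous NHWC strides: sn = h*w*c, sh = w*c, sw = c, sc = 1.
    sn, sh, sw, sc = h * w * c, w * c, c, 1
    return (n - 1) * sn + (h - 1) * sh + (w - 1) * sw + (c - 1) * sc

# 9 x 8192 x 8192 x 4 elements: the last element's offset exceeds int32,
# so the launcher below would pick the int64_t instantiation of the kernel.
print(max_nhwc_offset(9, 8192, 8192, 4) > INT_MAX)  # True
```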
// kernels borrowed from Caffe
@ -114,25 +85,21 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom
  }
}

template <typename scalar_t, typename index_t>
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_forward_nhwc(
    const scalar_t* bottom_data,
    const int nbatch,
    const index_t channels, const index_t height, const index_t width,
    const index_t pooled_height, const index_t pooled_width,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    const index_t in_stride_n, const index_t in_stride_c,
    const index_t in_stride_h, const index_t in_stride_w,
    const int kernel_stride_C, const int kernel_size_C,
    scalar_t* top_data, int64_t* top_mask) {

  extern __shared__ unsigned char smem_raw[];
  index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw);
  scalar_t *out_cached = reinterpret_cast<scalar_t*>(
      out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z);
__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch,
                                   const int64_t channels, const int64_t height,
                                   const int64_t width, const int pooled_height, const int pooled_width,
                                   const int kernel_h, const int kernel_w, const int stride_h,
                                   const int stride_w, const int pad_h, const int pad_w,
                                   const int dilation_h, const int dilation_w,
                                   const int in_stride_n, const int in_stride_c,
                                   const int in_stride_h, const int in_stride_w,
                                   const int kernel_stride_C, const int kernel_size_C,
                                   scalar_t* top_data, int64_t* top_mask) {
  extern __shared__ int smem[];
  int *out_mask_cached = smem;
  scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]);

  // flattening cta for pre-computation & smem initialization;
  int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
@ -151,26 +118,26 @@ __global__ void max_pool_forward_nhwc(
  int channel_id = blockIdx.x / nbatch;
  int channel_offset = threadIdx.x + channel_id * blockDim.x;

  top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
  top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
  bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n;
  top_data = top_data + batch_id * pooled_height * pooled_width * channels;
  top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
  bottom_data = bottom_data + batch_id * in_stride_n;

  out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
  out_mask_cached  += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
  out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
  out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];

  int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z;
  int oW = (static_cast<int>(pooled_width)  + gridDim.y - 1) / gridDim.y;
  int oH = (pooled_height + gridDim.z-1) / gridDim.z;
  int oW = (pooled_width + gridDim.y-1) / gridDim.y;
  int ostartH = threadIdx.z + blockIdx.z*oH;
  int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height));
  int oendH = ::min(ostartH+oH, pooled_height);
  int ostartW = threadIdx.y + blockIdx.y*oW;
  int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width));
  int oendW = ::min(ostartW+oW, pooled_width);

  for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
    index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h;
    index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height);
    int hstart = oh * stride_h - pad_h;
    int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
    for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
      index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w;
      index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width);
      int wstart = ow * stride_w - pad_w;
      int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
      while(hstart < 0)
        hstart += dilation_h;
      while(wstart < 0)
@ -218,12 +185,12 @@ __global__ void max_pool_forward_nhwc(
      // Else do it Non-Prefetch...
      else
#endif
      for (index_t ih = hstart; ih < hend; ih += dilation_h) {
        for (index_t iw = wstart; iw < wend; iw += dilation_w) {
      for (int ih = hstart; ih < hend; ih += dilation_h) {
        for (int iw = wstart; iw < wend; iw += dilation_w) {
          int cached_index = threadIdx.x;
          const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
          for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
            scalar_t val = ptr_input[c * in_stride_c];
          for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
            scalar_t val = ptr_input[c*in_stride_c];
            if ((val > out_cached[cached_index]) || at::_isnan(val)) {
              out_cached[cached_index] = val;
              out_mask_cached[cached_index] = ih * width + iw;
@ -233,15 +200,15 @@ __global__ void max_pool_forward_nhwc(
        }
      }

      scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
      int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
      scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels;
      int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels;

      int cached_index = threadIdx.x;
      for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
      for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
        ptr_output_data[c] = out_cached[cached_index];
        ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]);
        ptr_output_mask[c] = out_mask_cached[cached_index];
        out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
        out_mask_cached[cached_index] = index_t(0);
        out_mask_cached[cached_index] = 0;
        cached_index += blockDim.x;
      }
    }
@ -249,7 +216,7 @@ __global__ void max_pool_forward_nhwc(
}

static constexpr int BLOCK_THREADS = 256;
static const int BLOCK_THREADS = 256;

template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)
@ -495,11 +462,6 @@ const Tensor& indices) {
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
          const dim3 block(block_x, block_y, block_z);

          bool use_int32 = can_use_int32_nhwc(
              nbatch, nInputPlane, inputHeight, inputWidth,
              outputHeight, outputWidth,
              in_stride_n, in_stride_c, in_stride_h, in_stride_w);

          int kernel_stride_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
          int kernel_size_C = ceil_div(
@ -514,41 +476,18 @@ const Tensor& indices) {
              ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD));
          const dim3 grid(grid_x, grid_y, grid_z);

          size_t shmem_size;
          size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z;
          size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
          AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);

          if (use_int32) {
            shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t));
            TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
                        "shared memory too small");
            max_pool_forward_nhwc<scalar_t, int32_t>
              <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
                input_data, static_cast<int>(nbatch),
                static_cast<int32_t>(nInputPlane),
                static_cast<int32_t>(inputHeight),
                static_cast<int32_t>(inputWidth),
                static_cast<int32_t>(outputHeight),
                static_cast<int32_t>(outputWidth),
                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                static_cast<int32_t>(in_stride_n),
                static_cast<int32_t>(in_stride_c),
                static_cast<int32_t>(in_stride_h),
                static_cast<int32_t>(in_stride_w),
                kernel_stride_C, kernel_size_C,
                output_data, indices_data);
          } else {
            shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t));
            TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
                        "shared memory too small");
            max_pool_forward_nhwc<scalar_t, int64_t>
              <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
                input_data, static_cast<int>(nbatch),
                nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                in_stride_n, in_stride_c, in_stride_h, in_stride_w,
                kernel_stride_C, kernel_size_C,
                output_data, indices_data);
          }
          max_pool_forward_nhwc<scalar_t>
          <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
              input_data, nbatch,
                  nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
                  kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                  in_stride_n, in_stride_c,
                  in_stride_h, in_stride_w,
                  kernel_stride_C, kernel_size_C,
                  output_data, indices_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
@ -15,7 +15,9 @@
 | 
			
		||||
#include <ATen/native/cuda/block_reduce.cuh>
 | 
			
		||||
#include <ATen/native/cuda/thread_constants.h>
 | 
			
		||||
 | 
			
		||||
#if CUB_SUPPORTS_SCAN_BY_KEY()
 | 
			
		||||
#include <thrust/iterator/reverse_iterator.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
@ -34,9 +36,9 @@ namespace at::native {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
static constexpr int BLOCKDIMY = 16;
 | 
			
		||||
static const int BLOCKDIMY = 16;
 | 
			
		||||
#else
 | 
			
		||||
static constexpr int BLOCKDIMY = 32;
 | 
			
		||||
static const int BLOCKDIMY = 32;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template
 | 
			
		||||
@ -238,6 +240,10 @@ __global__ void renorm_kernel(
 | 
			
		||||
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
 | 
			
		||||
#if !CUB_SUPPORTS_SCAN_BY_KEY()
 | 
			
		||||
template<typename index_t>
 | 
			
		||||
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
 | 
			
		||||
                               int64_t num_weights, int64_t padding_idx,
 | 
			
		||||
@@ -300,6 +306,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
 
   if (scale_grad_by_freq) {
     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+#if CUB_SUPPORTS_SCAN_BY_KEY()
     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
@@ -326,6 +333,11 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
         num_indices
       );
     });
+#else
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
+      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
+    });
+#endif
   }
 
   return embedding_backward_cuda_kernel(grad, orig_indices,
@@ -10,7 +10,9 @@
 
 #include <c10/macros/Macros.h>
 
+#if CUB_SUPPORTS_UNIQUE_BY_KEY()
 #include <thrust/iterator/counting_iterator.h>
+#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -194,9 +196,18 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm
             partials_per_segment_offset[num_of_segments-1];
 }
 
+#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
+__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
+  *num_of_segments_ptr = num_of_segments;
+}
+#endif
 
 } // anon namespace
 
+#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
+template<typename index_t>
+int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
+#endif
 
 Tensor embedding_backward_cuda_kernel(
         const Tensor &grad,
@@ -223,12 +234,20 @@ Tensor embedding_backward_cuda_kernel(
   auto segment_offsets = at::empty({numel}, orig_indices.options());
   auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
   int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>();
+#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
+  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
+    int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
+    write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  });
+#else
   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
     cuda::cub::unique_by_key(
       sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0),
       segment_offsets.mutable_data_ptr<index_t>(),
       num_of_segments_ptr, sorted_indices.numel());
   });
+#endif
 
   int64_t max_segments = std::min<int64_t>(numel, num_weights);
 
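The legacy branch publishes a host-computed count to device memory with a one-thread kernel so later kernels can read it without a host round trip. A standalone sketch of that pattern (hypothetical names):

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// One thread writes a single host-side scalar into device memory.
__global__ void write_scalar(int64_t* out, int64_t value) {
  *out = value;
}

int main() {
  int64_t* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(int64_t));
  int64_t num_of_segments = 5;                // value computed on the host
  write_scalar<<<1, 1>>>(d_out, num_of_segments);
  int64_t h_out = 0;
  cudaMemcpy(&h_out, d_out, sizeof(int64_t), cudaMemcpyDeviceToHost);
  std::printf("%lld\n", (long long)h_out);    // prints 5
  cudaFree(d_out);
  return 0;
}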
@@ -31,10 +31,16 @@
 
 #include <c10/macros/Macros.h>
 
+#if CUB_SUPPORTS_SCAN_BY_KEY()
 #include <thrust/iterator/reverse_iterator.h>
+#endif
 
 namespace at::native {
 
+#if !CUB_SUPPORTS_SCAN_BY_KEY()
+template<typename index_t>
+void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
+#endif
 
 namespace {
 
@@ -193,6 +199,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(
 
   if (scale_grad_by_freq) {
     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+#if CUB_SUPPORTS_SCAN_BY_KEY()
     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
@@ -219,6 +226,11 @@ Tensor embedding_bag_backward_cuda_sum_avg(
         num_indices
       );
     });
+#else
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
+      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
+    });
+#endif
   }
   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
       count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,
@@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
   // lanczos approximation
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 
-  constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
+  static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
     0.006061842346248906525783753964555936883222,
     0.5098416655656676188125178644804694509993,
     19.51992788247617482847860966235652136208,
@@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
     103794043.1163445451906271053616070238554,
     56906521.91347156388090791033559122686859
   };
-  constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
+  static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
     1.,
     66.,
     1925.,
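The two 13-entry tables above hold numerator and denominator coefficients of a rational approximation; the surrounding function evaluates their ratio. A generic Horner-style evaluator in the same spirit (hypothetical helper, not the exact cephes evaluation order):

#include <cstddef>

// Evaluate p(x)/q(x) for two polynomials given by coefficient tables,
// leading coefficient first (Horner's rule).
template <typename T>
T poly_ratio(const T* num, const T* den, std::size_t n, T x) {
  T p = num[0];
  T q = den[0];
  for (std::size_t i = 1; i < n; ++i) {
    p = p * x + num[i];
    q = q * x + den[i];
  }
  return p / q;
}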
@@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
 
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t ax, fac, res, num, numfac;
-  constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
+  static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
     7.09782712893383996843E2 : 88.72283905206835;
-  constexpr accscalar_t EXP1 = 2.718281828459045;
-  constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
+  static const accscalar_t EXP1 = 2.718281828459045;
+  static const accscalar_t lanczos_g = 6.024680040776729583740234375;
 
   if (::fabs(a - x) > 0.4 * ::fabs(a)) {
     ax = a * ::log(x) - x - ::lgamma(a);
@@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
   // Compute igam using DLMF 8.11.4. [igam1]
 
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
-  constexpr int MAXITER = 2000;
+  static const int MAXITER = 2000;
 
   int i;
   accscalar_t ans, ax, c, r;
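For reference, DLMF 8.11.4 expands the regularized lower incomplete gamma function as P(a,x) = x^a e^{-x} / Γ(a+1) · Σ_{n≥0} x^n / ((a+1)⋯(a+n)). A minimal double-precision host-side sketch of that series with the same MACHEP/MAXITER stopping rule (illustrative only, not the PyTorch kernel):

#include <cmath>

// Series for the regularized lower incomplete gamma function P(a, x);
// terminates when the relative size of a term falls below MACHEP.
double igam_series(double a, double x) {
  static const double MACHEP = 1.11022302462515654042E-16;
  static const int MAXITER = 2000;

  // Leading factor x^a * e^{-x} / Gamma(a + 1), computed in log space.
  double ax = a * std::log(x) - x - std::lgamma(a + 1.0);
  if (ax < -700.0) {
    return 0.0;  // exp() would underflow
  }
  double lead = std::exp(ax);

  double sum = 1.0, term = 1.0, r = a;
  for (int i = 0; i < MAXITER; ++i) {
    r += 1.0;
    term *= x / r;   // x^n / ((a+1)...(a+n))
    sum += term;
    if (term / sum < MACHEP) {
      break;
    }
  }
  return lead * sum;
}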
@@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
   accscalar_t fac = 1;
   accscalar_t sum = 0;
   accscalar_t term, logx;
-  constexpr int MAXITER = 2000;
-  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  static const int MAXITER = 2000;
+  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
 
   for (n = 1; n < MAXITER; n++) {
@@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
 
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  constexpr accscalar_t d[25][25] =
+  static const accscalar_t d[25][25] =
     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
     {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
     {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
 
   int k, n, sgn;
   int maxpow = 0;
-  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
   accscalar_t lambda = x / a;
   accscalar_t sigma = (x - a) / a;
@@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
   int i;
   accscalar_t ans, ax, c, yc, r, t, y, z;
   accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
-  constexpr int MAXITER = 2000;
-  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  static const int MAXITER = 2000;
+  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
-  constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
+  static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
     4.503599627370496e15 : 16777216.;
-  constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
+  static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
     2.22044604925031308085e-16 : 5.9604644775390625E-8;
 
   ax = _igam_helper_fac(a, x);
@@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t absxma_a;
 
-  constexpr accscalar_t SMALL = 20.0;
-  constexpr accscalar_t LARGE = 200.0;
-  constexpr accscalar_t SMALLRATIO = 0.3;
-  constexpr accscalar_t LARGERATIO = 4.5;
+  static const accscalar_t SMALL = 20.0;
+  static const accscalar_t LARGE = 200.0;
+  static const accscalar_t SMALLRATIO = 0.3;
+  static const accscalar_t LARGERATIO = 4.5;
 
   if ((x < 0) || (a < 0)) {
     // out of defined-region of the function
@@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
 
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t absxma_a;
-  constexpr accscalar_t SMALL = 20.0;
-  constexpr accscalar_t LARGE = 200.0;
-  constexpr accscalar_t SMALLRATIO = 0.3;
-  constexpr accscalar_t LARGERATIO = 4.5;
+  static const accscalar_t SMALL = 20.0;
+  static const accscalar_t LARGE = 200.0;
+  static const accscalar_t SMALLRATIO = 0.3;
+  static const accscalar_t LARGERATIO = 4.5;
 
   // boundary values following SciPy
   if ((x < 0) || (a < 0)) {
aten/src/ATen/native/cuda/LegacyThrustHelpers.cu (new file, 90 lines)
@@ -0,0 +1,90 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/core/Tensor.h>
+#include <ATen/native/cuda/SortingCommon.cuh>
+#include <ATen/cuda/cub_definitions.cuh>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/empty_like.h>
+#endif
+
+#include <ATen/cuda/ThrustAllocator.h>
+#include <thrust/device_ptr.h>
+#include <thrust/execution_policy.h>
+#include <thrust/sort.h>
+#include <thrust/unique.h>
+#include <thrust/device_ptr.h>
+#include <thrust/iterator/constant_iterator.h>
+
+namespace at::native {
+
+#if !CUB_SUPPORTS_SCAN_BY_KEY()
+
+template<typename index_t>
+void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::ThrustAllocator allocator;
+  auto policy = thrust::cuda::par(allocator).on(stream);
+
+  auto num_indices = count.numel();
+
+  // Compute an increasing sequence per unique item in sortedIndices:
+  // sorted: 2 5 5 5 7 7 8 9 9
+  //  count: 1 1 2 3 1 2 1 1 2
+  auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
+  auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
+  thrust::inclusive_scan_by_key(
+    policy,
+    sorted_data,
+    sorted_data + num_indices,
+    thrust::make_constant_iterator(1),
+    count_data
+  );
+
+  // Take the maximum of each count per unique key in reverse:
+  // sorted: 2 5 5 5 7 7 8 9 9
+  //  count: 1 3 3 3 2 2 1 2 2
+  thrust::inclusive_scan_by_key(
+    policy,
+    thrust::make_reverse_iterator(sorted_data + num_indices),
+    thrust::make_reverse_iterator(sorted_data),
+    thrust::make_reverse_iterator(count_data + num_indices),
+    thrust::make_reverse_iterator(count_data + num_indices),
+    thrust::equal_to<index_t>(),
+    thrust::maximum<index_t>()
+  );
+}
+
+template
+void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
+template
+void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);
+
+#endif
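The two scan_by_key passes can be hard to visualize; a host-only C++ simulation of the same logic on the example from the file's own comments (illustrative, not part of the diff):

#include <cstdio>

int main() {
  const int n = 9;
  int sorted[n] = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  int count[n];
  // Pass 1: running occurrence count per key (inclusive_scan_by_key with +).
  for (int i = 0; i < n; ++i)
    count[i] = (i > 0 && sorted[i] == sorted[i - 1]) ? count[i - 1] + 1 : 1;
  // Pass 2: reverse max-scan per key, broadcasting each key's total count.
  for (int i = n - 2; i >= 0; --i)
    if (sorted[i] == sorted[i + 1] && count[i + 1] > count[i]) count[i] = count[i + 1];
  for (int i = 0; i < n; ++i) std::printf("%d ", count[i]);  // prints: 1 3 3 3 2 2 1 2 2
  return 0;
}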
+
+template<typename index_t>
+int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
+  auto stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::ThrustAllocator allocator;
+  auto policy = thrust::cuda::par(allocator).on(stream);
+  const ptrdiff_t numel = sorted_indices.numel();
+  auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
+  auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
+  auto ends = thrust::unique_by_key_copy(
+          policy,
+          sorted_indices_dev,
+          sorted_indices_dev + numel,
+          thrust::make_counting_iterator(0),
+          dummy_dev,
+          thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
+  return thrust::get<0>(ends) - dummy_dev;
+}
+
+template
+int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
+template
+int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);
+
+} // namespace at::native
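unique_by_key_copy keeps the position of the first element of each run of equal keys, which is exactly the segment-offsets array the backward kernel needs. A host-only simulation on the same example (illustrative, not part of the diff):

#include <cstdio>

int main() {
  const int n = 9;
  int sorted[n] = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  int offsets[n];
  int num_segments = 0;
  // Record the index where each new run of equal keys starts.
  for (int i = 0; i < n; ++i)
    if (i == 0 || sorted[i] != sorted[i - 1]) offsets[num_segments++] = i;
  std::printf("num_of_segments = %d; offsets:", num_segments);  // 5; 0 1 4 6 7
  for (int i = 0; i < num_segments; ++i) std::printf(" %d", offsets[i]);
  return 0;
}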
@@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
 const auto digamma_string = jiterator_stringify(
   template <typename T>
   T digamma(T x) {
-    static constexpr double PI_f64 = 3.14159265358979323846;
+    static const double PI_f64 = 3.14159265358979323846;
 
     // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
     if (x == 0) {
@@ -3072,9 +3072,9 @@ template <typename scalar_t>
 static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
   // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  static constexpr double PI_f64 = 3.14159265358979323846;
-  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
-  constexpr accscalar_t A[] = {
+  static const double PI_f64 = 3.14159265358979323846;
+  const accscalar_t PSI_10 = 2.25175258906672110764;
+  const accscalar_t A[] = {
       8.33333333333333333333E-2,
       -2.10927960927960927961E-2,
       7.57575757575757575758E-3,
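PI_f64 feeds the reflection formula ψ(x) = ψ(1−x) − π/tan(πx), which digamma implementations use for negative arguments. A minimal host-side sketch of that step (hypothetical helper):

#include <cmath>

// Reflection step for digamma at negative non-integer x, given psi(1 - x).
double digamma_reflect(double psi_one_minus_x, double x) {
  static const double PI_f64 = 3.14159265358979323846;
  return psi_one_minus_x - PI_f64 / std::tan(PI_f64 * x);
}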
@@ -146,7 +146,6 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
   int64_t batch_size = target.size(0);
   int64_t H = target.size(1);
   int64_t W = target.size(2);
-  int64_t n_classes = grad_input.size(1);
 
   CUDA_KERNEL_LOOP(index, n_threads) {
     const int64_t b = index % batch_size;
@@ -157,7 +156,6 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
     if (cur_target == ignore_index) {
       continue;
     }
-    CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
     scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
     grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
   }
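The deleted line used a device-side assertion to trap out-of-range class indices before they index grad_input. A standalone CUDA sketch of that bounds-check pattern (hypothetical kernel; CUDA_KERNEL_ASSERT lowers to a device-side assert on most builds):

#include <cassert>
#include <cstdint>
#include <cuda_runtime.h>

// Traps in device code if any target holds an out-of-range class index.
__global__ void check_targets(const int64_t* target, int64_t n_classes, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    assert(target[i] >= 0 && target[i] < n_classes);
  }
}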
@@ -23,7 +23,7 @@ namespace at::native {
 
 // The maximum number of threads in a block
 #if defined(USE_ROCM)
-constexpr int MAX_BLOCK_SIZE = 1024;
+constexpr int MAX_BLOCK_SIZE = 256;
 #else
 constexpr int MAX_BLOCK_SIZE = 512;
 #endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
 // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
 static int getNumThreads(int nElem) {
 #if defined(USE_ROCM)
-  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
+  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
 #else
   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
 #endif
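The rest of getNumThreads (outside this hunk) walks the table and returns the first bucket that covers nElem. A self-contained sketch with the new ROCm values (illustrative; assumes the loop body matches the upstream function):

#include <cstdio>

constexpr int MAX_BLOCK_SIZE = 256;  // ROCm value after this change

// Smallest bucket in the table that covers nElem, capped at MAX_BLOCK_SIZE.
static int getNumThreads(int nElem) {
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

int main() {
  std::printf("%d %d %d\n", getNumThreads(10), getNumThreads(100), getNumThreads(5000));
  // prints: 16 128 256
  return 0;
}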
Some files were not shown because too many files have changed in this diff.