Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)

Compare commits (3 commits): gh/malfet/ ... copilot/co
| Author | SHA1 | Date |
|---|---|---|
|  | 241b702918 |  |
|  | 83df2e0610 |  |
|  | 77fe8234bb |  |
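Most of the hunks below apply one mechanical normalization: squashed invocations such as `python -mpip ...` (and the analogous `-mvenv`, `-mwheel` forms) become the conventional `python -m pip ...`. A minimal sketch of how such a sweep could be located and applied from the repository root, assuming GNU grep and sed; the paths and the sed expression are illustrative and not part of this compare:

```bash
# Read-only scan for remaining squashed "-mpip"/"-mvenv"/"-mwheel" invocations.
grep -rnE -e '-m(pip|venv|wheel) ' .ci .github

# One possible bulk rewrite (illustrative; review the diff before committing).
grep -rlE -e '-m(pip|venv|wheel) ' .ci .github \
  | xargs sed -i -E 's/-m(pip|venv|wheel) /-m \1 /g'
```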
@@ -20,7 +20,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
 
 # cmake-3.18.4 from pip
 RUN yum install -y python3-pip && \
-    python3 -mpip install cmake==3.18.4 && \
+    python3 -m pip install cmake==3.18.4 && \
     ln -s /usr/local/bin/cmake /usr/bin/cmake3
 RUN rm -rf /usr/local/cuda-*
 
@@ -83,6 +83,10 @@ function build_cpython {
         py_suffix=${py_ver::-1}
         py_folder=$py_suffix
     fi
+    # Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4
+    if [ "$py_suffix" == "3.14.0" ]; then
+        py_suffix="3.14.0rc2"
+    fi
     wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz
     do_cpython_build $py_ver Python-$py_suffix
 
@@ -25,7 +25,7 @@ function install_torchbench() {
   python install.py --continue_on_fail
 
   echo "Print all dependencies after TorchBench is installed"
-  python -mpip freeze
+  python -m pip freeze
   popd
 
   chown -R jenkins torchbench
@@ -8,8 +8,8 @@ MKLROOT=/opt/intel
 mkdir -p ${MKLROOT}
 pushd /tmp
 
-python3 -mpip install wheel
-python3 -mpip download -d . mkl-static==${MKL_VERSION}
+python3 -m pip install wheel
+python3 -m pip download -d . mkl-static==${MKL_VERSION}
 python3 -m wheel unpack mkl_static-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
 python3 -m wheel unpack mkl_include-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
 mv mkl_static-${MKL_VERSION}/mkl_static-${MKL_VERSION}.data/data/lib ${MKLROOT}
@@ -19,7 +19,7 @@ pip_install \
   transformers==4.36.2
 
 pip_install coloredlogs packaging
-pip_install onnxruntime==1.23.1
+pip_install onnxruntime==1.23.0
 pip_install onnxscript==0.5.4
 
 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
@@ -11,5 +11,5 @@ ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
 python -m venv /var/lib/jenkins/ci_env
 source /var/lib/jenkins/ci_env/bin/activate
 
-python -mpip install --upgrade pip
-python -mpip install -r /opt/requirements-ci.txt
+python -m pip install --upgrade pip
+python -m pip install -r /opt/requirements-ci.txt
@@ -39,13 +39,9 @@ case ${DOCKER_TAG_PREFIX} in
         DOCKER_GPU_BUILD_ARG=""
         ;;
     rocm*)
-        # we want the patch version of 7.0 instead
-        if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
-        fi
         # we want the patch version of 6.4 instead
         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
+            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
         fi
         BASE_TARGET=rocm
         GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
@@ -14,7 +14,7 @@ ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/op
 
 # cmake-3.18.4 from pip
 RUN yum install -y python3-pip && \
-    python3 -mpip install cmake==3.18.4 && \
+    python3 -m pip install cmake==3.18.4 && \
     ln -s /usr/local/bin/cmake /usr/bin/cmake3
 
 FROM base as openssl
@@ -135,7 +135,7 @@ RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
 
 # cmake-3.18.4 from pip; force in case cmake3 already exists
 RUN yum install -y python3-pip && \
-    python3 -mpip install cmake==3.18.4 && \
+    python3 -m pip install cmake==3.18.4 && \
     ln -sf /usr/local/bin/cmake /usr/bin/cmake3
 
 FROM cpu_final as cuda_final
@@ -157,7 +157,7 @@ ENV ROCM_PATH /opt/rocm
 # cmake-3.28.4 from pip to get enable_language(HIP)
 # and avoid 3.21.0 cmake+ninja issues with ninja inserting "-Wl,--no-as-needed" in LINK_FLAGS for static linker
 RUN python3 -m pip install --upgrade pip && \
-    python3 -mpip install cmake==3.28.4
+    python3 -m pip install cmake==3.28.4
 # replace the libdrm in /opt/amdgpu with custom amdgpu.ids lookup path
 ADD ./common/install_rocm_drm.sh install_rocm_drm.sh
 RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh
@@ -174,7 +174,7 @@ FROM cpu_final as xpu_final
 ENV XPU_DRIVER_TYPE ROLLING
 # cmake-3.28.4 from pip
 RUN python3 -m pip install --upgrade pip && \
-    python3 -mpip install cmake==3.28.4
+    python3 -m pip install cmake==3.28.4
 ADD ./common/install_xpu.sh install_xpu.sh
 ENV XPU_VERSION 2025.2
 RUN bash ./install_xpu.sh && rm install_xpu.sh
@@ -113,7 +113,7 @@ RUN dnf install -y \
 RUN env GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=True pip3 install grpcio
 
 # cmake-3.28.0 from pip for onnxruntime
-RUN python3 -mpip install cmake==3.28.0
+RUN python3 -m pip install cmake==3.28.0
 
 ADD ./common/patch_libstdc.sh patch_libstdc.sh
 RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
@@ -75,13 +75,9 @@ case ${image} in
         DOCKERFILE_SUFFIX="_cuda_aarch64"
         ;;
     manylinux2_28-builder:rocm*)
-        # we want the patch version of 7.0 instead
-        if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
-        fi
         # we want the patch version of 6.4 instead
         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
+            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
         fi
         TARGET=rocm_final
         MANY_LINUX_VERSION="2_28"
@@ -334,12 +334,12 @@ sympy==1.13.3
 #Pinned versions:
 #test that import:
 
-onnx==1.19.1
+onnx==1.18.0
 #Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
 #Pinned versions:
 #test that import:
 
-onnxscript==0.5.4
+onnxscript==0.5.3
 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
 #Pinned versions:
 #test that import:
@@ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules
         logger.info("Successfully cloned %s", target)
         return r, commit
 
-    except GitCommandError:
-        logger.exception("Git operation failed")
+    except GitCommandError as e:
+        logger.error("Git operation failed: %s", e)
         raise
 
@@ -6,7 +6,7 @@ dependencies = [
     "GitPython==3.1.45",
     "docker==7.1.0",
     "pytest==7.3.2",
-    "uv==0.9.5"
+    "uv==0.8.6"
 ]
 
 [tool.setuptools]
@@ -288,7 +288,7 @@ else
     # or building non-XLA tests.
     if [[ "$BUILD_ENVIRONMENT" != *rocm*  && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then
       # Install numpy-2.0.2 for builds which are backward compatible with 1.X
-      python -mpip install numpy==2.0.2
+      python -m pip install numpy==2.0.2
 
       WERROR=1 python setup.py clean
@@ -67,13 +67,13 @@ function pip_install_whl() {
     # Loop through each path and install individually
     for path in "${paths[@]}"; do
       echo "Installing $path"
-      python3 -mpip install --no-index --no-deps "$path"
+      python3 -m pip install --no-index --no-deps "$path"
     done
   else
     # Loop through each argument and install individually
     for path in "${args[@]}"; do
       echo "Installing $path"
-      python3 -mpip install --no-index --no-deps "$path"
+      python3 -m pip install --no-index --no-deps "$path"
     done
   fi
 }
@@ -182,7 +182,7 @@ checkout_install_torchbench() {
   pip uninstall -y torchao
 
   echo "Print all dependencies after TorchBench is installed"
-  python -mpip freeze
+  python -m pip freeze
 }
 
 torchbench_setup_macos() {
@@ -211,7 +211,7 @@ torchbench_setup_macos() {
 }
 
 pip_benchmark_deps() {
-  python -mpip install --no-input requests cython scikit-learn six
+  python -m pip install --no-input requests cython scikit-learn six
 }
 
@@ -1434,7 +1434,7 @@ EOF
   # shellcheck source=./common-build.sh
   source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh"
   python -m build --wheel --no-isolation -C--build-option=--bdist-dir="base_bdist_tmp" --outdir "base_dist"
-  python -mpip install base_dist/*.whl
+  python -m pip install base_dist/*.whl
   echo "::endgroup::"
 
   pushd test/forward_backward_compatibility
@@ -173,7 +173,7 @@ esac
 PINNED_PACKAGES=(
     "numpy${NUMPY_PINNED_VERSION}"
 )
-python -mvenv ~/${desired_python}-build
+python -m venv ~/${desired_python}-build
 source ~/${desired_python}-build/bin/activate
 retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements.txt"
 retry brew install libomp
@@ -163,13 +163,8 @@ if [[ "$(uname)" != Darwin ]]; then
   MEMORY_LIMIT_MAX_JOBS=12
   NUM_CPUS=$(( $(nproc) - 2 ))
 
-  if [[ "$(uname)" == Linux ]]; then
-    # Defaults here for **binary** linux builds so they can be changed in one place
-    export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
-  else
-    # For other builds
-    export MAX_JOBS=${NUM_CPUS}
-  fi
+  # Defaults here for **binary** linux builds so they can be changed in one place
+  export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
 
   cat >>"$envfile" <<EOL
   export MAX_JOBS="${MAX_JOBS}"
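The surviving MAX_JOBS logic relies on a bash arithmetic idiom: `$(( a > b ? b : a ))` is simply min(a, b), and the `${MAX_JOBS:-...}` expansion keeps any value already set in the environment. A small standalone sketch with illustrative values:

```bash
MEMORY_LIMIT_MAX_JOBS=12
NUM_CPUS=$(( $(nproc) - 2 ))
# Default to min(NUM_CPUS, MEMORY_LIMIT_MAX_JOBS) unless MAX_JOBS is already set.
export MAX_JOBS=${MAX_JOBS:-$(( NUM_CPUS > MEMORY_LIMIT_MAX_JOBS ? MEMORY_LIMIT_MAX_JOBS : NUM_CPUS ))}
echo "MAX_JOBS=${MAX_JOBS}"   # e.g. 12 on a 32-core machine, 6 on an 8-core one
```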
							
								
								
									
.flake8 (6 changes)

@@ -7,12 +7,16 @@ max-line-length = 120
 # C408 ignored because we like the dict keyword argument syntax
 # E501 is not flexible enough, we're using B950 instead
 ignore =
-    E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
+    E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
     # to line this up with executable bit
     EXE001,
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
     # these ignores are from flake8-comprehensions; please fix!
     C407,
     # these ignores are from flake8-logging-format; please fix!
     G100,G101,G200
     # these ignores are from flake8-simplify. please fix or ignore with commented reason
     SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
     # SIM104 is already covered by pyupgrade ruff
.github/actions/setup-rocm/action.yml (7 changes, vendored)

@@ -124,10 +124,3 @@ runs:
       id: login-ecr
       continue-on-error: true
       uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
-
-    - name: Preserve github env variables for use in docker
-      shell: bash
-      run: |
-        env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
-        env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
-        env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
.github/ci_commit_pins/audio.txt (2 changes, vendored)

@@ -1 +1 @@
-69bbe7363897764f9e758d851cd0340147d27f94
+1b013f5b5a87a1882eb143c26d79d091150d6a37
.github/ci_commit_pins/vision.txt (2 changes, vendored)

@@ -1 +1 @@
-1752fe6809b74921644866275ab80244b96e80bc
+faffd5cf673615583da6517275e361cb3dbc77e6
.github/labeler.yml (29 changes, vendored)

@@ -133,32 +133,3 @@
 
 "ciflow/vllm":
 - .github/ci_commit_pins/vllm.txt
-
-"ciflow/b200":
-- test/test_matmul_cuda.py
-- test/test_scaled_matmul_cuda.py
-- test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/Blas.cpp
-- torch/**/*cublas*
-- torch/_inductor/kernel/mm.py
-- test/inductor/test_max_autotune.py
-- third_party/fbgemm
-
-"ciflow/h100":
-- test/test_matmul_cuda.py
-- test/test_scaled_matmul_cuda.py
-- test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/Blas.cpp
-- torch/**/*cublas*
-- torch/_inductor/kernel/mm.py
-- test/inductor/test_max_autotune.py
-- third_party/fbgemm
-
-"ciflow/rocm":
-- test/test_matmul_cuda.py
-- test/test_scaled_matmul_cuda.py
-- test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/Blas.cpp
-- torch/_inductor/kernel/mm.py
-- test/inductor/test_max_autotune.py
-- third_party/fbgemm
.github/pytorch-probot.yml (1 change, vendored)

@@ -33,7 +33,6 @@ ciflow_push_tags:
 - ciflow/rocm
 - ciflow/rocm-mi300
 - ciflow/rocm-mi355
-- ciflow/rocm-navi31
 - ciflow/s390
 - ciflow/slow
 - ciflow/torchbench
.github/scripts/generate_binary_build_matrix.py (30 changes, vendored)

@@ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
         "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'"
     ),
     "12.9": (
-        "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | "
-        "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | "
-        "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | "
-        "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | "
-        "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | "
-        "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | "
-        "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | "
-        "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | "
-        "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | "
-        "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | "
-        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | "
-        "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | "
-        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | "
-        "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | "
-        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'"
+        "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'"
     ),
     "13.0": (
         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
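The generate_binary_build_matrix.py change above (and the regenerated nightly workflows further below) tightens every CUDA 12.9 requirement with an `and platform_machine == 'x86_64'` environment marker, so pip skips those wheels on hosts that are not x86_64 Linux. A quick way to check how such a PEP 508 marker evaluates on the current machine, assuming the `packaging` module is installed (this snippet is illustrative, not part of the compare):

```bash
python3 -c '
from packaging.markers import Marker
m = Marker("platform_system == \"Linux\" and platform_machine == \"x86_64\"")
print(m.evaluate())  # True only on x86_64 Linux; elsewhere the extra CUDA deps are skipped
'
```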
							
								
								
									
.github/scripts/prepare_vllm_wheels.sh (6 changes, vendored)

@@ -24,7 +24,7 @@ change_wheel_version() {
   local t_version=$4
 
   # Extract the wheel
-  ${PYTHON_EXECUTABLE} -mwheel unpack $wheel
+  ${PYTHON_EXECUTABLE} -m wheel unpack $wheel
 
   mv "${package}-${f_version}" "${package}-${t_version}"
   # Change the version from f_version to t_version in the dist-info dir
@@ -47,7 +47,7 @@ change_wheel_version() {
   popd
 
   # Repack the wheel
-  ${PYTHON_EXECUTABLE} -mwheel pack "${package}-${t_version}"
+  ${PYTHON_EXECUTABLE} -m wheel pack "${package}-${t_version}"
 
   # Clean up
   rm -rf "${package}-${t_version}"
@@ -85,7 +85,7 @@ repackage_wheel() {
 }
 
 # Require to re-package the wheel
-${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1
+${PYTHON_EXECUTABLE} -m pip install wheel==0.45.1
 
 pushd externals/vllm/wheels
 for package in xformers flashinfer-python vllm; do
@@ -26,8 +26,9 @@ name: !{{ build_environment }}
       - name: Setup Python
         uses: actions/setup-python@v6
         with:
+          # TODO: Removeme once 3.14 is out
           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
-          python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}"
+          python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}"
          freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }}
 {%- endmacro %}
@@ -79,9 +79,9 @@ jobs:
     runs-on: "windows-11-arm64-preview"
     {%- else %}
     {%- if branches == "nightly" %}
-    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
+    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
     {%- else %}
-    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
+    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
     {%- endif %}
     {%- endif %}
     timeout-minutes: !{{ common.timeout_minutes_windows_binary }}
.github/workflows/_mac-test.yml (4 changes, vendored)

@@ -211,7 +211,7 @@ jobs:
             $tool --version
           done
 
-          python3 -mpip install --no-index --no-deps dist/*.whl
+          python3 -m pip install --no-index --no-deps dist/*.whl
 
           set +e
           pushd "${RUNNER_TEMP}"
@@ -222,7 +222,7 @@ jobs:
           popd
 
           if [ "${RC}" -ne 0 ]; then
-            python3 -mpip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
+            python3 -m pip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
           fi
           set -e
.github/workflows/_win-test.yml (2 changes, vendored)

@@ -204,7 +204,7 @@ jobs:
         run: |
           pushd "${PYTORCH_FINAL_PACKAGE_DIR}"
           # shellcheck disable=SC2046,SC2102
-          python3 -mpip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
+          python3 -m pip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
           popd
 
           .ci/pytorch/win-test.sh
.github/workflows/build-vllm-wheel.yml (4 changes, vendored)

@@ -126,13 +126,13 @@ jobs:
             "${MANYLINUX_IMAGE}"
           )
 
-          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip install \
+          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install \
             --pre torch torchvision torchaudio \
             --index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"
 
           # I wonder if there is a command to both download and install the wheels
           # in one go
-          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip download \
+          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip download \
             --pre torch torchvision torchaudio \
             --index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"
.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml (14 changes, generated, vendored)

@@ -224,7 +224,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_10-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -473,7 +473,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_11-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -722,7 +722,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -971,7 +971,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_13-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1220,7 +1220,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_13t-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1469,7 +1469,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_14-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
@ -1718,7 +1718,7 @@ jobs:
 | 
			
		||||
      ALPINE_IMAGE: "arm64v8/alpine"
 | 
			
		||||
      build_name: manywheel-py3_14t-cuda-aarch64-12_9
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      timeout-minutes: 420
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
.github/workflows/generated-linux-binary-manywheel-nightly.yml (14 changes, generated, vendored)

@@ -259,7 +259,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_10-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_10-cuda12_9-test:  # Testing
 | 
			
		||||
@ -925,7 +925,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_11-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_11-cuda12_9-test:  # Testing
 | 
			
		||||
@ -1591,7 +1591,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_12-cuda12_9-test:  # Testing
@ -2257,7 +2257,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_13-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_13-cuda12_9-test:  # Testing
@ -2923,7 +2923,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_13t-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_13t-cuda12_9-test:  # Testing
@ -3589,7 +3589,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_14-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_14-cuda12_9-test:  # Testing
@ -4255,7 +4255,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build_name: manywheel-py3_14t-cuda12_9
      build_environment: linux-binary-manywheel
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
  manywheel-py3_14t-cuda12_9-test:  # Testing
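The PYTORCH_EXTRA_INSTALL_REQUIREMENTS values above are "|"-separated PEP 508 requirement strings, and the change in these hunks tightens every environment marker from platform_system alone to platform_system plus platform_machine. As a minimal illustration (not part of the workflow; it assumes the third-party `packaging` library, which implements PEP 508 marker evaluation), a marker copied from the hunks above can be checked like this:

# Illustrative only: evaluate one of the markers used above with packaging's
# PEP 508 marker implementation.
from packaging.markers import Marker

marker = Marker("platform_system == 'Linux' and platform_machine == 'x86_64'")

# Against the running interpreter's real environment:
print(marker.evaluate())

# Against an explicit environment, e.g. an aarch64 Linux host, where the
# tightened marker evaluates to False and pip would skip that requirement:
print(marker.evaluate({"platform_system": "Linux", "platform_machine": "aarch64"}))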

.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml (1 changed line; generated, vendored)
@ -63,6 +63,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.10.4"
          freethreaded: false
							
								
								
									

.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml (25 changed lines; generated, vendored)
@ -59,6 +59,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.10.4"
          freethreaded: false
@ -105,7 +106,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -mvenv test_venv
          python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -168,6 +169,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.11.4"
          freethreaded: false
@ -214,7 +216,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -mvenv test_venv
          python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -277,6 +279,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.12.4"
          freethreaded: false
@ -323,7 +326,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -mvenv test_venv
          python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -386,6 +389,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.13.4"
          freethreaded: false
@ -432,7 +436,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -mvenv test_venv
          python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -495,6 +499,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.13.4"
          freethreaded: true
@ -541,7 +546,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -mvenv test_venv
          python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -604,8 +609,9 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.14.0"
          python-version: "3.14.0-rc.2"
          freethreaded: false
      - name: Checkout PyTorch
        uses: actions/checkout@v4
@ -650,7 +656,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -mvenv test_venv
          python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -713,8 +719,9 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.14.0"
          python-version: "3.14.0-rc.2"
          freethreaded: true
      - name: Checkout PyTorch
        uses: actions/checkout@v4
@ -759,7 +766,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -mvenv test_venv
          python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
							
								
								
									

.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml (8 changed lines; generated, vendored)
@ -44,7 +44,7 @@ jobs:
  libtorch-cpu-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
  libtorch-cuda12_6-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
  libtorch-cuda12_8-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
  libtorch-cuda13_0-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
							
								
								
									

.github/workflows/generated-windows-binary-libtorch-release-nightly.yml (8 changed lines; generated, vendored)
@ -44,7 +44,7 @@ jobs:
  libtorch-cpu-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
  libtorch-cuda12_6-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
  libtorch-cuda12_8-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
  libtorch-cuda13_0-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
							
								
								
									

.github/workflows/generated-windows-binary-wheel-nightly.yml (70 changed lines; generated, vendored)
@ -44,7 +44,7 @@ jobs:
  wheel-py3_10-cpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -279,7 +279,7 @@ jobs:
  wheel-py3_10-cuda12_6-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -517,7 +517,7 @@ jobs:
  wheel-py3_10-cuda12_8-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -755,7 +755,7 @@ jobs:
  wheel-py3_10-cuda13_0-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -993,7 +993,7 @@ jobs:
  wheel-py3_10-xpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1229,7 +1229,7 @@ jobs:
  wheel-py3_11-cpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1464,7 +1464,7 @@ jobs:
  wheel-py3_11-cuda12_6-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1702,7 +1702,7 @@ jobs:
  wheel-py3_11-cuda12_8-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1940,7 +1940,7 @@ jobs:
  wheel-py3_11-cuda13_0-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2178,7 +2178,7 @@ jobs:
  wheel-py3_11-xpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2414,7 +2414,7 @@ jobs:
  wheel-py3_12-cpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2649,7 +2649,7 @@ jobs:
  wheel-py3_12-cuda12_6-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2887,7 +2887,7 @@ jobs:
  wheel-py3_12-cuda12_8-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3125,7 +3125,7 @@ jobs:
  wheel-py3_12-cuda13_0-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3363,7 +3363,7 @@ jobs:
  wheel-py3_12-xpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3599,7 +3599,7 @@ jobs:
  wheel-py3_13-cpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3834,7 +3834,7 @@ jobs:
  wheel-py3_13-cuda12_6-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4072,7 +4072,7 @@ jobs:
  wheel-py3_13-cuda12_8-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4310,7 +4310,7 @@ jobs:
  wheel-py3_13-cuda13_0-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4548,7 +4548,7 @@ jobs:
  wheel-py3_13-xpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4784,7 +4784,7 @@ jobs:
  wheel-py3_13t-cpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5019,7 +5019,7 @@ jobs:
  wheel-py3_13t-cuda12_6-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5257,7 +5257,7 @@ jobs:
  wheel-py3_13t-cuda12_8-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5495,7 +5495,7 @@ jobs:
  wheel-py3_13t-cuda13_0-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5733,7 +5733,7 @@ jobs:
  wheel-py3_13t-xpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5969,7 +5969,7 @@ jobs:
  wheel-py3_14-cpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6204,7 +6204,7 @@ jobs:
  wheel-py3_14-cuda12_6-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6442,7 +6442,7 @@ jobs:
  wheel-py3_14-cuda12_8-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6680,7 +6680,7 @@ jobs:
  wheel-py3_14-cuda13_0-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6918,7 +6918,7 @@ jobs:
  wheel-py3_14-xpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7154,7 +7154,7 @@ jobs:
  wheel-py3_14t-cpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7389,7 +7389,7 @@ jobs:
  wheel-py3_14t-cuda12_6-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7627,7 +7627,7 @@ jobs:
  wheel-py3_14t-cuda12_8-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7865,7 +7865,7 @@ jobs:
  wheel-py3_14t-cuda13_0-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -8103,7 +8103,7 @@ jobs:
  wheel-py3_14t-xpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
							
								
								
									

.github/workflows/inductor-periodic.yml (1 changed line; vendored)
@ -88,6 +88,7 @@ jobs:
    with:
      build-environment: linux-jammy-rocm-py3_10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
							
								
								
									

.github/workflows/periodic.yml (15 changed lines; vendored)
@ -147,16 +147,15 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
      cuda-arch-list: 8.9
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
          { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
        ]}
    secrets: inherit
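The shard/num_shards fields in the matrix entries above split one test configuration across seven identical runners. As a rough sketch of the usual contract behind such fields (an assumption for illustration, not PyTorch's actual sharding code), each shard index selects its own slice of the test list:

# Hypothetical sketch: shard i of num_shards takes every num_shards-th test.
def select_for_shard(tests, shard, num_shards):
    # shard is 1-based, matching the matrix entries above
    return [t for i, t in enumerate(sorted(tests)) if i % num_shards == shard - 1]

tests = [f"test_{i}" for i in range(20)]
print(select_for_shard(tests, shard=1, num_shards=7))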
							
								
								
									

.github/workflows/pull.yml (3 changed lines; vendored)
@ -347,8 +347,7 @@ jobs:
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      # This should sync with the build in xpu.yml but xpu uses a larger runner
      # sync-tag: linux-xpu-n-build
      sync-tag: linux-xpu-n-build
      runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
      build-environment: linux-jammy-xpu-n-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
							
								
								
									

.github/workflows/rocm-mi300.yml (1 changed line; vendored)
@ -45,6 +45,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-noble-rocm-py3.12-mi300
      docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
							
								
								
									

.github/workflows/rocm-mi355.yml (1 changed line; vendored)
@ -42,6 +42,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-noble-rocm-py3.12-mi355
      docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
							
								
								
									

.github/workflows/rocm-navi31.yml (75 changed lines; vendored)
@ -1,75 +0,0 @@
name: rocm-navi31

on:
  push:
    tags:
      - ciflow/rocm-navi31/*
  workflow_dispatch:
  schedule:
    # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
    # Also run less frequently on weekends.
    - cron: 45 */2 * * 1-5
    - cron: 45 4,12 * * 0,6

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  target-determination:
    if: github.repository_owner == 'pytorch'
    name: before-test
    uses: ./.github/workflows/target_determination.yml
    permissions:
      id-token: write
      contents: read

  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-rocm-py3_10-build:
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
          { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-test:
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3_10
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
      tests-to-include: >-
         ${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
         test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
         inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
         inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
         inductor/test_flex_attention inductor/test_max_autotune' || '' }}
    secrets: inherit
							
								
								
									
										38
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										38
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							@ -26,23 +26,11 @@ jobs:
      id-token: write
      contents: read

  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-rocm-py3_10-build:
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      sync-tag: rocm-build
@@ -71,3 +59,29 @@ jobs:
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
    secrets: inherit

  linux-jammy-rocm-py3_10-gfx1100-test:
    if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3_10-gfx1100
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
          { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
        ]}
      tests-to-include: >
         test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
         test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
         inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
         inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
         inductor/test_flex_attention inductor/test_max_autotune
    secrets: inherit

.github/workflows/trunk-tagging.yml (149 lines changed)
@@ -58,10 +58,8 @@ jobs:
 | 
			
		||||
          else
 | 
			
		||||
            COMMIT_SHA="${{ github.sha }}"
 | 
			
		||||
          fi
 | 
			
		||||
          {
 | 
			
		||||
            echo "sha=${COMMIT_SHA}"
 | 
			
		||||
            echo "tag_name=trunk/${COMMIT_SHA}"
 | 
			
		||||
          } >> "${GITHUB_OUTPUT}"
 | 
			
		||||
          echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
          echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
 | 
			
		||||
      - name: Validate commit SHA
 | 
			
		||||
        run: |
 | 
			
		||||
@ -89,7 +87,7 @@ jobs:
 | 
			
		||||
            echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
      - name: Create and push tag(s) with retry
 | 
			
		||||
      - name: Create and push tag with retry
 | 
			
		||||
        id: check_tag
 | 
			
		||||
        env:
 | 
			
		||||
          TAG_NAME: ${{ steps.commit.outputs.tag_name }}
 | 
			
		||||
@ -114,23 +112,14 @@ jobs:
 | 
			
		||||
            return 1
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          # Counters for summary reporting
 | 
			
		||||
          created_count=0
 | 
			
		||||
          skipped_count=0
 | 
			
		||||
          failed_count=0
 | 
			
		||||
          # Exit early if tag already exists
 | 
			
		||||
          if check_tag_exists; then
 | 
			
		||||
            echo "✅ Tag already exists - no action needed"
 | 
			
		||||
            echo "exists=true" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          # Always write outputs once on exit
 | 
			
		||||
          finish() {
 | 
			
		||||
            set +e
 | 
			
		||||
            if [ -n "${GITHUB_OUTPUT:-}" ]; then
 | 
			
		||||
              {
 | 
			
		||||
                echo "created_count=${created_count}"
 | 
			
		||||
                echo "skipped_count=${skipped_count}"
 | 
			
		||||
                echo "failed_count=${failed_count}"
 | 
			
		||||
              } >> "${GITHUB_OUTPUT}"
 | 
			
		||||
            fi
 | 
			
		||||
          }
 | 
			
		||||
          trap finish EXIT
 | 
			
		||||
          echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
 | 
			
		||||
          # Retry configuration
 | 
			
		||||
          MAX_RETRIES=5
 | 
			
		||||
@ -205,111 +194,31 @@ jobs:
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          # New behavior for push events: enumerate commits in the push and tag each one.
 | 
			
		||||
          # For workflow_dispatch, retain existing single-SHA behavior.
 | 
			
		||||
 | 
			
		||||
          # Always fetch tags once up front to improve idempotency in loops
 | 
			
		||||
          git fetch origin --tags --quiet || true
 | 
			
		||||
 | 
			
		||||
          if [ "${{ github.event_name }}" = "push" ]; then
 | 
			
		||||
            BEFORE_SHA="${{ github.event.before }}"
 | 
			
		||||
            AFTER_SHA="${{ github.sha }}"  # same as event.after
 | 
			
		||||
 | 
			
		||||
            # List commits introduced by this push (old..new), oldest first for stable ordering
 | 
			
		||||
            commits_file="$(mktemp)"
 | 
			
		||||
            git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            if [ ! -s "${commits_file}" ]; then
 | 
			
		||||
              echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
 | 
			
		||||
              rm -f "${commits_file}"
 | 
			
		||||
              exit 0
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
            commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
 | 
			
		||||
            echo "Found ${commit_count} commit(s) to tag for push:"
 | 
			
		||||
            while IFS= read -r sha; do
 | 
			
		||||
              printf '  %s\n' "${sha}"
 | 
			
		||||
            done < "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            while IFS= read -r sha; do
 | 
			
		||||
              TAG_NAME="trunk/${sha}"
 | 
			
		||||
              COMMIT_SHA="${sha}"
 | 
			
		||||
 | 
			
		||||
              # If tag already exists locally or remotely, skip (idempotent)
 | 
			
		||||
              if check_tag_exists; then
 | 
			
		||||
                echo "✅ Tag ${TAG_NAME} already exists - skipping"
 | 
			
		||||
                skipped_count=$((skipped_count + 1))
 | 
			
		||||
                continue
 | 
			
		||||
              fi
 | 
			
		||||
 | 
			
		||||
              echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
 | 
			
		||||
              if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
                created_count=$((created_count + 1))
 | 
			
		||||
              else
 | 
			
		||||
                echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
 | 
			
		||||
                failed_count=$((failed_count + 1))
 | 
			
		||||
              fi
 | 
			
		||||
            done < "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            rm -f "${commits_file}"
 | 
			
		||||
 | 
			
		||||
            if [ "${failed_count}" -gt 0 ]; then
 | 
			
		||||
              exit 1
 | 
			
		||||
            fi
 | 
			
		||||
          # Execute with retry
 | 
			
		||||
          if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
            echo "exists=false" >> "${GITHUB_OUTPUT}"
 | 
			
		||||
            exit 0
 | 
			
		||||
          else
 | 
			
		||||
            # workflow_dispatch path (single SHA tagging preserved)
 | 
			
		||||
 | 
			
		||||
            # Exit early if tag already exists
 | 
			
		||||
            if check_tag_exists; then
 | 
			
		||||
              echo "✅ Tag already exists - no action needed"
 | 
			
		||||
              skipped_count=1
 | 
			
		||||
              exit 0
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
            echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
 | 
			
		||||
 | 
			
		||||
            if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
 | 
			
		||||
              created_count=1
 | 
			
		||||
              exit 0
 | 
			
		||||
            else
 | 
			
		||||
              echo "Tag creation failed after all retry attempts"
 | 
			
		||||
              failed_count=1
 | 
			
		||||
              exit 1
 | 
			
		||||
            fi
 | 
			
		||||
            echo "Tag creation failed after all retry attempts"
 | 
			
		||||
            exit 1
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
      - name: Tag creation summary
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          if [ "${{ github.event_name }}" = "push" ]; then
 | 
			
		||||
            echo "Trigger: push on main"
 | 
			
		||||
            echo "Created: ${{ steps.check_tag.outputs.created_count }}"
 | 
			
		||||
            echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
 | 
			
		||||
            echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
 | 
			
		||||
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
 | 
			
		||||
              echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
 | 
			
		||||
            else
 | 
			
		||||
              echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
          if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
 | 
			
		||||
            echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
 | 
			
		||||
          elif [ "${{ job.status }}" = "success" ]; then
 | 
			
		||||
            echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          else
 | 
			
		||||
            if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
 | 
			
		||||
              if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
 | 
			
		||||
                echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
 | 
			
		||||
              else
 | 
			
		||||
                echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
              fi
 | 
			
		||||
            else
 | 
			
		||||
              echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
 | 
			
		||||
            echo ""
 | 
			
		||||
            echo "Tag details:"
 | 
			
		||||
            echo "  Name: ${{ steps.commit.outputs.tag_name }}"
 | 
			
		||||
            echo "  Commit: ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
            echo "  Trigger: ${{ github.event_name }}"
 | 
			
		||||
            if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
 | 
			
		||||
              echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
 | 
			
		||||
            fi
 | 
			
		||||
            echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          echo ""
 | 
			
		||||
          echo "Tag details:"
 | 
			
		||||
          echo "  Name: ${{ steps.commit.outputs.tag_name }}"
 | 
			
		||||
          echo "  Commit: ${{ steps.commit.outputs.sha }}"
 | 
			
		||||
          echo "  Trigger: ${{ github.event_name }}"
 | 
			
		||||
          if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
 | 
			
		||||
            echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
 | 
			
		||||
          fi

.github/workflows/trunk.yml (34 lines changed)
@@ -190,40 +190,6 @@ jobs:
      runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    secrets: inherit

  linux-jammy-rocm-py3_10-build:
    if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-test:
    if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
      tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
    secrets: inherit

  inductor-build:
    name: inductor-build
    uses: ./.github/workflows/_linux-build.yml

.gitignore (1 line changed)
@@ -374,7 +374,6 @@ third_party/ruy/
third_party/glog/

# Virtualenv
.venv/
venv/

# Log files

@@ -1138,8 +1138,11 @@ command = [
[[linter]]
code = 'WORKFLOWSYNC'
include_patterns = [
    '.github/workflows/*.yml',
    '.github/workflows/*.yaml',
    '.github/workflows/pull.yml',
    '.github/workflows/trunk.yml',
    '.github/workflows/periodic.yml',
    '.github/workflows/mac-mps.yml',
    '.github/workflows/slow.yml',
]
command = [
    'python3',

CODEOWNERS (14 lines changed)
@@ -201,17 +201,3 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99

# FlexAttention
/torch/nn/attention/flex_attention.py @drisspg
/torch/_higher_order_ops/flex_attention.py @drisspg
/torch/_inductor/kernel/flex/ @drisspg
/torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg

# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

@@ -39,7 +39,7 @@ RUN chmod +x ~/miniconda.sh && \
    bash ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
    /opt/conda/bin/python -mpip install -r requirements.txt && \
    /opt/conda/bin/python -m pip install -r requirements.txt && \
    /opt/conda/bin/conda clean -ya

FROM dev-base as submodule-update

@@ -289,15 +289,14 @@ IF(USE_FBGEMM_GENAI)

    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)

    set(fbgemm_genai_cuh
    set(fbgemm_genai_mx8mx8bf16_grouped
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
      "${FBGEMM_GENAI_SRCS}/"
    )

    target_include_directories(fbgemm_genai PRIVATE
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${fbgemm_genai_cuh}
      ${fbgemm_genai_mx8mx8bf16_grouped}
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )
@@ -314,14 +313,13 @@ IF(USE_FBGEMM_GENAI)

    # Add additional HIPCC compiler flags for performance
    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
      -mllvm
      -amdgpu-coerce-illegal-types=1
      -mllvm
      -enable-post-misched=0
      -mllvm
      -greedy-reverse-local-assignment=1
      -fhip-new-launch-api)
    if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0")
        list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
      endif()

    # Only compile for gfx942 for now.
    # This is rather hacky, I could not figure out a clean solution :(

@@ -19,7 +19,6 @@
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XLAHooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
@@ -89,8 +88,6 @@ class TORCH_API Context {
      return at::detail::getHIPHooks();
    } else if (opt_device_type == at::kHPU) {
      return at::detail::getHPUHooks();
    } else if (opt_device_type == at::kXLA) {
      return at::detail::getXLAHooks();
    } else {
      TORCH_CHECK(
          false,
@@ -199,7 +196,7 @@ class TORCH_API Context {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
  }
  static bool hasXLA() {
    return detail::getXLAHooks().hasXLA();
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
  }
  static bool hasXPU() {
    return detail::getXPUHooks().hasXPU();

@ -39,7 +39,7 @@ struct HostBlock {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <typename B>
 | 
			
		||||
struct alignas(hardware_destructive_interference_size) FreeBlockList {
 | 
			
		||||
struct alignas(64) FreeBlockList {
 | 
			
		||||
  std::mutex mutex_;
 | 
			
		||||
  std::deque<B*> list_;
 | 
			
		||||
};
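
This hunk and the ones below swap alignas(hardware_destructive_interference_size) for a hard-coded alignas(64). The intent is unchanged: keep each independently-locked member on its own cache line so threads working on different lists do not false-share. A minimal standalone sketch of that idea, not taken from the allocator, assuming a 64-byte cache line (some CPUs use 128):

#include <cstddef>
#include <mutex>
#include <thread>

// Assumed cache-line size; std::hardware_destructive_interference_size would be the
// portable spelling, but a fixed 64 keeps the struct layout stable across toolchains.
constexpr std::size_t kCacheLine = 64;

// Each counter gets its own cache line, so locking one never bounces the other's line.
struct alignas(kCacheLine) PaddedCounter {
  std::mutex mutex;
  long value = 0;
};

int main() {
  PaddedCounter a, b;
  std::thread t1([&] { for (int i = 0; i < 1000000; ++i) { std::lock_guard<std::mutex> g(a.mutex); ++a.value; } });
  std::thread t2([&] { for (int i = 0; i < 1000000; ++i) { std::lock_guard<std::mutex> g(b.mutex); ++b.value; } });
  t1.join();
  t2.join();
  return a.value == b.value ? 0 : 1;
}

Hard-coding 64 trades strict portability for a value that does not vary with compiler or standard-library version, which appears to be the motivation for this change.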
 | 
			
		||||
@ -122,7 +122,7 @@ struct TORCH_API HostStats {
 | 
			
		||||
// Struct containing memory allocator summary statistics for host, as they
 | 
			
		||||
// are staged for reporting. This is a temporary struct that is used to
 | 
			
		||||
// avoid locking the allocator while collecting stats.
 | 
			
		||||
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
 | 
			
		||||
struct alignas(64) HostStatsStaged {
 | 
			
		||||
  std::mutex timing_mutex_;
 | 
			
		||||
  // COUNT: total allocations (active + free)
 | 
			
		||||
  // LOCK: access to this stat is protected by the allocator's blocks_mutex_
 | 
			
		||||
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
 | 
			
		||||
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
 | 
			
		||||
  alignas(64) std::mutex blocks_mutex_;
 | 
			
		||||
  ska::flat_hash_set<B*> blocks_; // block list
 | 
			
		||||
  ska::flat_hash_map<void*, B*> ptr_to_block_;
 | 
			
		||||
 | 
			
		||||
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
 | 
			
		||||
  // size. This allows us to quickly find a free block of the right size.
 | 
			
		||||
  // We use deque to store per size free list and guard the list with its own
 | 
			
		||||
  // mutex.
 | 
			
		||||
  alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
 | 
			
		||||
  alignas(64) std::vector<FreeBlockList<B>> free_list_ =
 | 
			
		||||
      std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
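
The comment above describes the layout this member implements: one free list per size class, each deque guarded by its own mutex, so freeing or reusing a block of one size never contends with another size class. A simplified sketch of that bucketing, with a made-up size-to-bucket mapping (the real MAX_SIZE_INDEX and indexing scheme live in the allocator):

#include <cstddef>
#include <deque>
#include <mutex>
#include <vector>

struct Block { std::size_t size = 0; };

class BucketedFreeList {
 public:
  explicit BucketedFreeList(std::size_t num_buckets) : buckets_(num_buckets) {}

  // Return a block to its size class.
  void push(Block* b) {
    Bucket& bucket = buckets_[index_for(b->size)];
    std::lock_guard<std::mutex> g(bucket.mutex);
    bucket.blocks.push_back(b);
  }

  // Grab a cached block of the right size class, or nullptr if none is free.
  Block* pop(std::size_t size) {
    Bucket& bucket = buckets_[index_for(size)];
    std::lock_guard<std::mutex> g(bucket.mutex);
    if (bucket.blocks.empty()) return nullptr;
    Block* b = bucket.blocks.back();
    bucket.blocks.pop_back();
    return b;
  }

 private:
  struct Bucket {
    std::mutex mutex;            // one lock per size class
    std::deque<Block*> blocks;   // the per-size free list
  };

  // Hypothetical power-of-two bucketing; only for illustration.
  std::size_t index_for(std::size_t size) const {
    std::size_t idx = 0;
    while ((std::size_t{1} << idx) < size && idx + 1 < buckets_.size()) ++idx;
    return idx;
  }

  std::vector<Bucket> buckets_;
};

int main() {
  BucketedFreeList fl(8);
  Block b{64};
  fl.push(&b);
  return fl.pop(64) == &b ? 0 : 1;
}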
 | 
			
		||||
 | 
			
		||||
  alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
 | 
			
		||||
  alignas(64) std::mutex events_mutex_;
 | 
			
		||||
  std::deque<std::pair<E, B*>> events_; // event queue paired with block
 | 
			
		||||
 | 
			
		||||
  // Indicates whether the object is active.
 | 
			
		||||
  // Set to false in the destructor to signal background threads to stop.
 | 
			
		||||
  std::atomic<bool> active_{true};
 | 
			
		||||
protected:
 | 
			
		||||
  alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
 | 
			
		||||
  alignas(64) HostStatsStaged stats_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TORCH_API HostAllocator : public at::Allocator {
 | 
			
		||||
 | 
			
		||||
@ -59,7 +59,9 @@ struct TORCH_API Generator {
 | 
			
		||||
 | 
			
		||||
  explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
 | 
			
		||||
   : impl_(std::move(gen_impl)) {
 | 
			
		||||
    TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
 | 
			
		||||
    if (impl_.get() == nullptr) {
 | 
			
		||||
      throw std::runtime_error("GeneratorImpl with nullptr is not supported");
 | 
			
		||||
    }
 | 
			
		||||
  }
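
Here (and in the TensorBase constructor further down) the null check moves from TORCH_CHECK to a plain throw std::runtime_error, so the validation no longer depends on the TORCH_CHECK macro machinery. A small sketch of the same pattern with a hypothetical Handle wrapper rather than the real Generator:

#include <memory>
#include <stdexcept>
#include <utility>

// A wrapper that refuses to hold a null implementation pointer.
class Handle {
 public:
  explicit Handle(std::shared_ptr<int> impl) : impl_(std::move(impl)) {
    if (impl_ == nullptr) {
      throw std::runtime_error("Handle with nullptr is not supported");
    }
  }
  int value() const { return *impl_; }

 private:
  std::shared_ptr<int> impl_;
};

int main() {
  Handle ok(std::make_shared<int>(42));  // accepted
  try {
    Handle bad(nullptr);                 // rejected with std::runtime_error
  } catch (const std::runtime_error&) {
    return ok.value() == 42 ? 0 : 1;
  }
  return 1;
}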
 | 
			
		||||
 | 
			
		||||
  bool operator==(const Generator& rhs) const {
 | 
			
		||||
 | 
			
		||||
@ -111,7 +111,9 @@ class TORCH_API TensorBase {
 | 
			
		||||
  explicit TensorBase(
 | 
			
		||||
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
 | 
			
		||||
      : impl_(std::move(tensor_impl)) {
 | 
			
		||||
    TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
 | 
			
		||||
    if (impl_.get() == nullptr) {
 | 
			
		||||
      throw std::runtime_error("TensorImpl with nullptr is not supported");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  TensorBase(const TensorBase&) = default;
 | 
			
		||||
  TensorBase(TensorBase&&) noexcept = default;
 | 
			
		||||
 | 
			
		||||
@ -68,7 +68,11 @@ Symbol InternedStrings::_symbol(const std::string& s) {
 | 
			
		||||
    return it->second;
 | 
			
		||||
 | 
			
		||||
  auto pos = s.find("::");
 | 
			
		||||
  TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
 | 
			
		||||
  if (pos == std::string::npos) {
 | 
			
		||||
    std::stringstream ss;
 | 
			
		||||
    ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
 | 
			
		||||
    throw std::runtime_error(ss.str());
 | 
			
		||||
  }
 | 
			
		||||
  Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
 | 
			
		||||
 | 
			
		||||
  Symbol sym(sym_to_info_.size());
 | 
			
		||||
@ -117,7 +121,12 @@ std::string Symbol::domainString() const {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
 | 
			
		||||
  TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
 | 
			
		||||
  if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
 | 
			
		||||
    std::ostringstream ss;
 | 
			
		||||
    ss << "Symbol: domain string is expected to be prefixed with '"
 | 
			
		||||
       << domain_prefix() << "', e.g. 'org.pytorch.aten'";
 | 
			
		||||
    throw std::runtime_error(ss.str());
 | 
			
		||||
  }
 | 
			
		||||
  std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
 | 
			
		||||
  return fromQualString(qualString);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,6 @@
 | 
			
		||||
#include <ATen/core/jit_type.h>
 | 
			
		||||
#include <ATen/core/stack.h>
 | 
			
		||||
#include <ATen/core/type_factory.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/StringUtil.h>
 | 
			
		||||
#include <c10/util/hash.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
@ -413,7 +412,7 @@ size_t IValue::hash(const IValue& v) {
 | 
			
		||||
    case Tag::Enum:
 | 
			
		||||
    case Tag::Stream:
 | 
			
		||||
    case Tag::Uninitialized:
 | 
			
		||||
      TORCH_CHECK(false,
 | 
			
		||||
      throw std::runtime_error(
 | 
			
		||||
          "unhashable type: '" + v.type()->repr_str() + "'");
 | 
			
		||||
  }
 | 
			
		||||
  // the above switch should be exhaustive
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,6 @@
 | 
			
		||||
#include <ATen/core/type_factory.h>
 | 
			
		||||
#include <ATen/core/qualified_name.h>
 | 
			
		||||
#include <c10/util/TypeList.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <c10/core/SymFloat.h>
 | 
			
		||||
#include <c10/core/SymBool.h>
 | 
			
		||||
@ -117,8 +116,10 @@ struct SingleElementType : public SharedType {
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
 | 
			
		||||
    TORCH_CHECK(this->elem, c10::str(
 | 
			
		||||
    if (!this->elem) {
 | 
			
		||||
      throw std::runtime_error(c10::str(
 | 
			
		||||
            "Can not create ", typeKindToString(Kind), " with None type"));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
@ -415,12 +416,16 @@ struct TORCH_API SymbolicShape {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ShapeSymbol operator[](size_t i) const {
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ShapeSymbol at(size_t i) const {
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -515,7 +520,9 @@ struct VaryingShape {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const std::optional<T> &operator[](size_t i) const {
 | 
			
		||||
    TORCH_CHECK(dims_, "Rank isn't fixed");
 | 
			
		||||
    if (!dims_) {
 | 
			
		||||
      throw std::runtime_error("Rank isn't fixed");
 | 
			
		||||
    }
 | 
			
		||||
    return (*dims_).at(i);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -950,7 +957,9 @@ struct TORCH_API DictType : public SharedType {
 | 
			
		||||
 | 
			
		||||
  TypePtr createWithContained(
 | 
			
		||||
      std::vector<TypePtr> contained_types) const override {
 | 
			
		||||
    TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
 | 
			
		||||
    if (contained_types.size() != 2) {
 | 
			
		||||
      throw std::runtime_error("Expected 2 contained types");
 | 
			
		||||
    }
 | 
			
		||||
    return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,6 @@
 | 
			
		||||
#include <ATen/core/jit_type.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/env.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/flat_hash_map.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <array>
 | 
			
		||||
@ -827,7 +826,9 @@ TupleType::TupleType(
 | 
			
		||||
    : NamedType(TypeKind::TupleType, std::move(name)),
 | 
			
		||||
      elements_(std::move(elements)),
 | 
			
		||||
      has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
 | 
			
		||||
        TORCH_CHECK(v, "Can not create tuple with None type");
 | 
			
		||||
        if (!v) {
 | 
			
		||||
          throw std::runtime_error("Can not create tuple with None type");
 | 
			
		||||
        }
 | 
			
		||||
        return v->hasFreeVariables();
 | 
			
		||||
      })), schema_(std::move(schema)) {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -6,11 +6,9 @@
#ifdef __aarch64__
#if !defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
#endif

#include <ATen/cpu/vec/vec128/vec128_convert.h>

@ -354,47 +354,9 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
 | 
			
		||||
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
 | 
			
		||||
  Vectorized frac() const;
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
 | 
			
		||||
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  Vectorized<c10::BFloat16> neg() const {
 | 
			
		||||
    return -values;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<c10::BFloat16> reciprocal() const {
 | 
			
		||||
    return 1.0f / values;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<c10::BFloat16> operator==(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values == other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator!=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values != other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator<(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values < other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator<=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values <= other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator>(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values > other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator>=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values >= other.values;
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
 | 
			
		||||
@ -402,7 +364,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 | 
			
		||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 | 
			
		||||
@ -451,52 +412,28 @@ template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator+(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x + y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator-(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x - y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator*(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x * y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator/(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x / y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// frac. Implement this here so we can use subtraction
 | 
			
		||||
@ -607,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return x * y + z;
 | 
			
		||||
#else
 | 
			
		||||
  // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also,
 | 
			
		||||
  // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
 | 
			
		||||
  // elements, not the bottom and top half, so they don't seem
 | 
			
		||||
  // particularly useful here. Ideally we would include dot product in
 | 
			
		||||
  // the Vectorized interface...
 | 
			
		||||
  return a * b + c;
 | 
			
		||||
#endif
 | 
			
		||||
}
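
The NOTE above explains the fallback path taken when __ARM_FEATURE_BF16 is unavailable: there is no FMA that accumulates into bfloat16, so a * b + c is computed through the float-backed operators and rounded back. A scalar sketch of that widen/compute/narrow round trip, using simple truncation when narrowing (the real conversions round to nearest-even):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bfloat16 is the top 16 bits of an IEEE-754 binary32 value.
float bf16_to_float(std::uint16_t b) {
  std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Truncating narrow; production kernels round to nearest-even instead.
std::uint16_t float_to_bf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<std::uint16_t>(bits >> 16);
}

// Emulated bf16 fmadd: widen to float, do the multiply-add there, narrow once at the end.
std::uint16_t bf16_fmadd(std::uint16_t a, std::uint16_t b, std::uint16_t c) {
  return float_to_bf16(bf16_to_float(a) * bf16_to_float(b) + bf16_to_float(c));
}

int main() {
  std::uint16_t two = float_to_bf16(2.0f);
  std::uint16_t three = float_to_bf16(3.0f);
  std::uint16_t one = float_to_bf16(1.0f);
  std::printf("%g\n", bf16_to_float(bf16_fmadd(two, three, one)));  // prints 7
  return 0;
}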
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -627,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return (-x) * y + z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return -a * b + c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -643,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return x * y - z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return a * b - c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -659,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return (-x) * y - z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return -a * b - c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
 | 
			
		||||
 | 
			
		||||
@ -1,586 +0,0 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/intrinsics.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec_base.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
#include <cmath>
 | 
			
		||||
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
// Note [CPU_CAPABILITY namespace]
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
// This header, and all of its subheaders, will be compiled with
 | 
			
		||||
// different architecture flags for each supported set of vector
 | 
			
		||||
// intrinsics. So we need to make sure they aren't inadvertently
 | 
			
		||||
// linked together. We do this by declaring objects in an `inline
 | 
			
		||||
// namespace` which changes the name mangling, but can still be
 | 
			
		||||
// accessed as `at::vec`.
 | 
			
		||||
inline namespace CPU_CAPABILITY {
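
This note (the file below it is being deleted, but the same note appears in the surviving vec128 headers) explains the inline-namespace trick: each translation unit is compiled with different vector-ISA flags, and the inline namespace changes the mangled names so those builds can never be linked against each other by accident, while callers still write the short at::vec:: names. A tiny sketch of the mechanism with a made-up capability macro:

#include <cstdio>

#ifndef CPU_CAPABILITY
#define CPU_CAPABILITY DEFAULT  // imagine the build system passing -DCPU_CAPABILITY=AVX2, etc.
#endif

#define VEC_STR_IMPL(x) #x
#define VEC_STR(x) VEC_STR_IMPL(x)

namespace vec {
inline namespace CPU_CAPABILITY {
// Mangled as vec::DEFAULT::kernel (or vec::AVX2::kernel, ...), so two builds of
// this function compiled with different flags cannot collide at link time.
void kernel() { std::printf("running the %s build\n", VEC_STR(CPU_CAPABILITY)); }
}  // namespace CPU_CAPABILITY
}  // namespace vec

int main() {
  vec::kernel();  // the inline namespace is transparent to callers
  return 0;
}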
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
class Vectorized<double> {
 | 
			
		||||
 private:
 | 
			
		||||
  float64x2_t values;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  using value_type = double;
 | 
			
		||||
  using size_type = int;
 | 
			
		||||
  static constexpr size_type size() {
 | 
			
		||||
    return 2;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized() {
 | 
			
		||||
    values = vdupq_n_f64(0.0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized(float64x2_t v) : values(v) {}
 | 
			
		||||
  Vectorized(double val) {
 | 
			
		||||
    values = vdupq_n_f64(val);
 | 
			
		||||
  }
 | 
			
		||||
  template <
 | 
			
		||||
      typename... Args,
 | 
			
		||||
      typename = std::enable_if_t<(sizeof...(Args) == size())>>
 | 
			
		||||
  Vectorized(Args... vals) {
 | 
			
		||||
    __at_align__ double buffer[size()] = {vals...};
 | 
			
		||||
    values = vld1q_f64(buffer);
 | 
			
		||||
  }
 | 
			
		||||
  operator float64x2_t() const {
 | 
			
		||||
    return values;
 | 
			
		||||
  }
 | 
			
		||||
  template <int64_t mask>
 | 
			
		||||
  static Vectorized<double> blend(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b) {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint64x2_t maskArray = {
 | 
			
		||||
        (mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
 | 
			
		||||
        (mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_f64(maskArray, b.values, a.values);
 | 
			
		||||
  }
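
blend() expands the compile-time integer mask into per-lane all-ones/all-zeros words and lets vbslq_f64 pick each lane from b or a. A scalar equivalent of that lane selection, kept generic over the lane count:

#include <array>
#include <cstddef>
#include <cstdio>

// Lane i comes from b if bit i of Mask is set, otherwise from a.
template <unsigned Mask, std::size_t N>
std::array<double, N> blend(const std::array<double, N>& a, const std::array<double, N>& b) {
  std::array<double, N> out{};
  for (std::size_t i = 0; i < N; ++i) {
    out[i] = ((Mask >> i) & 1u) ? b[i] : a[i];
  }
  return out;
}

int main() {
  std::array<double, 2> a{1.0, 2.0};
  std::array<double, 2> b{10.0, 20.0};
  auto r = blend<0b10>(a, b);          // lane 0 from a, lane 1 from b
  std::printf("%g %g\n", r[0], r[1]);  // prints "1 20"
  return 0;
}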
 | 
			
		||||
  static Vectorized<double> blendv(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b,
 | 
			
		||||
      const Vectorized<double>& mask_) {
 | 
			
		||||
    return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename step_t>
 | 
			
		||||
  static Vectorized<double> arange(
 | 
			
		||||
      double base = 0.,
 | 
			
		||||
      step_t step = static_cast<step_t>(1)) {
 | 
			
		||||
    return {base, base + static_cast<double>(step)};
 | 
			
		||||
  }
 | 
			
		||||
  static inline Vectorized<double> set(
 | 
			
		||||
      const Vectorized<double>& a,
 | 
			
		||||
      const Vectorized<double>& b,
 | 
			
		||||
      int64_t count = size()) {
 | 
			
		||||
    if (count == 0) {
 | 
			
		||||
      return a;
 | 
			
		||||
    } else if (count >= 2) {
 | 
			
		||||
      return b;
 | 
			
		||||
    } else {
 | 
			
		||||
      float64x2_t c = {b.values[0], a.values[1]};
 | 
			
		||||
      return c;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
 | 
			
		||||
    if (count == size()) {
 | 
			
		||||
      return vld1q_f64(reinterpret_cast<const double*>(ptr));
 | 
			
		||||
    } else if (count == 1) {
 | 
			
		||||
      float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
 | 
			
		||||
      float64x1_t z = {0.0};
 | 
			
		||||
      return vcombine_f64(x, z);
 | 
			
		||||
    } else {
 | 
			
		||||
      return vdupq_n_f64(0.0);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  void store(void* ptr, int64_t count = size()) const {
 | 
			
		||||
    if (count == size()) {
 | 
			
		||||
      vst1q_f64(reinterpret_cast<double*>(ptr), values);
 | 
			
		||||
    } else if (count == 1) {
 | 
			
		||||
      vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  const double& operator[](int idx) const = delete;
 | 
			
		||||
  double& operator[](int idx) = delete;
 | 
			
		||||
  int64_t zero_mask() const {
 | 
			
		||||
    // returns an integer mask where all zero elements are translated to 1-bit
 | 
			
		||||
    // and others are translated to 0-bit
 | 
			
		||||
    uint64x2_t cmpReg = vceqzq_f64(values);
 | 
			
		||||
    uint64x2_t mask = {1, 2};
 | 
			
		||||
    uint64x2_t res = vandq_u64(cmpReg, mask);
 | 
			
		||||
    return res[0] | res[1];
 | 
			
		||||
  }
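
zero_mask() packs a per-lane "is zero" comparison into an integer: bit i is set when lane i equals zero. A scalar sketch of the same packing:

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Bit i of the result is 1 iff lane i is exactly zero.
template <std::size_t N>
std::uint64_t zero_mask(const std::array<double, N>& v) {
  std::uint64_t mask = 0;
  for (std::size_t i = 0; i < N; ++i) {
    if (v[i] == 0.0) {
      mask |= std::uint64_t{1} << i;
    }
  }
  return mask;
}

int main() {
  std::array<double, 2> v{0.0, 3.5};
  std::printf("%llu\n", static_cast<unsigned long long>(zero_mask(v)));  // prints 1
  return 0;
}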
 | 
			
		||||
  Vectorized<double> isnan() const {
 | 
			
		||||
    // NaN check
 | 
			
		||||
    return vreinterpretq_f64_u32(
 | 
			
		||||
        vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
 | 
			
		||||
  }
 | 
			
		||||
  bool has_inf_nan() const {
 | 
			
		||||
    Vectorized<double> x = vsubq_f64(values, values);
 | 
			
		||||
    float64x2_t r = x.isnan();
 | 
			
		||||
    uint64x2_t u = vreinterpretq_u64_f64(r);
 | 
			
		||||
    return u[0] | u[1];
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> map(double (*f)(double)) const {
 | 
			
		||||
    float64x2_t result;
 | 
			
		||||
    result[0] = f(values[0]);
 | 
			
		||||
    result[1] = f(values[1]);
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> map2(
 | 
			
		||||
      const Vectorized<double>& second,
 | 
			
		||||
      double (*const f)(double, double)) const {
 | 
			
		||||
    float64x2_t result;
 | 
			
		||||
    result[0] = f(values[0], second.values[0]);
 | 
			
		||||
    result[1] = f(values[1], second.values[1]);
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> abs() const {
 | 
			
		||||
    return vabsq_f64(values);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> angle() const {
 | 
			
		||||
    auto zero = Vectorized<double>(0.0);
 | 
			
		||||
    auto pi = Vectorized<double>(c10::pi<double>);
 | 
			
		||||
    auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
 | 
			
		||||
    return blendv(tmp, *this, isnan());
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> real() const {
 | 
			
		||||
    return *this;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> imag() const {
 | 
			
		||||
    return Vectorized<double>(0.0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> conj() const {
 | 
			
		||||
    return *this;
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> acos() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> acosh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> asin() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> asinh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atan() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atanh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_b[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        b.store(tmp_b);
 | 
			
		||||
        for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
          tmp[i] = std::atan2(tmp[i], tmp_b[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> copysign(const Vectorized<double>& sign) const {
 | 
			
		||||
      USE_SLEEF(
 | 
			
		||||
          { return Vectorized<double>(Sleef_copysignd2(values, sign)); },
 | 
			
		||||
          {
 | 
			
		||||
            __at_align__ double tmp[size()];
 | 
			
		||||
            __at_align__ double tmp_sign[size()];
 | 
			
		||||
            store(tmp);
 | 
			
		||||
            sign.store(tmp_sign);
 | 
			
		||||
            for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
              tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
 | 
			
		||||
            }
 | 
			
		||||
            return loadu(tmp);
 | 
			
		||||
          })} Vectorized<double> erf() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> erfc() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> exp() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> exp2() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> expm1() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_fmodd2(values, q)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_q[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        q.store(tmp_q);
 | 
			
		||||
        for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
          tmp[i] = std::fmod(tmp[i], tmp_q[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> hypot(const Vectorized<double>& b) const {
 | 
			
		||||
      USE_SLEEF(
 | 
			
		||||
          { return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
 | 
			
		||||
          {
 | 
			
		||||
            __at_align__ double tmp[size()];
 | 
			
		||||
            __at_align__ double tmp_b[size()];
 | 
			
		||||
            store(tmp);
 | 
			
		||||
            b.store(tmp_b);
 | 
			
		||||
            for (int64_t i = 0; i < size(); i++) {
 | 
			
		||||
              tmp[i] = std::hypot(tmp[i], tmp_b[i]);
 | 
			
		||||
            }
 | 
			
		||||
            return loadu(tmp);
 | 
			
		||||
          })} Vectorized<double> i0() const {
 | 
			
		||||
    return map(calc_i0);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF(
 | 
			
		||||
      { return Vectorized<double>(Sleef_nextafterd2(values, b)); },
 | 
			
		||||
      {
 | 
			
		||||
        __at_align__ double tmp[size()];
 | 
			
		||||
        __at_align__ double tmp_b[size()];
 | 
			
		||||
        store(tmp);
 | 
			
		||||
        b.store(tmp_b);
 | 
			
		||||
        for (int64_t i = 0; i < size(); ++i) {
 | 
			
		||||
          tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return loadu(tmp);
 | 
			
		||||
      })} Vectorized<double> log() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log2() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log10() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> log1p() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> frac() const;
 | 
			
		||||
  Vectorized<double> sin() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> sinh() const {
 | 
			
		||||
    return USE_SLEEF(
 | 
			
		||||
        Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<double> cos() const {
 | 
			
		||||
    return USE_SLEEF(
        Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
  }
  Vectorized<double> cosh() const {
    return USE_SLEEF(
        Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
  }
  Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF(
      { return Vectorized<double>(Sleef_powd2_u10(values, b)); },
      {
        __at_align__ double tmp[size()];
        __at_align__ double tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (int64_t i = 0; i < size(); i++) {
          tmp[i] = std::pow(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      })} // Comparison using the _CMP_**_OQ predicate.
          //   `O`: get false if an operand is NaN
          //   `Q`: do not raise if an operand is NaN
  Vectorized<double> tan() const {
    return USE_SLEEF(
        Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
  }
  Vectorized<double> tanh() const {
    return USE_SLEEF(
        Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
  }
  Vectorized<double> lgamma() const {
    return USE_SLEEF(
        Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
  }
  Vectorized<double> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<double> exp_u20() const {
    return exp();
  }
  Vectorized<double> fexp_u20() const {
    return exp();
  }
  Vectorized<double> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<double> digamma() const {
    return map(calc_digamma);
  }
  Vectorized<double> igamma(const Vectorized<double>& x) const {
    __at_align__ double tmp[size()];
    __at_align__ double tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (int64_t i = 0; i < size(); i++) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<double> igammac(const Vectorized<double>& x) const {
    __at_align__ double tmp[size()];
    __at_align__ double tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (int64_t i = 0; i < size(); i++) {
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<double> ceil() const {
    return vrndpq_f64(values);
  }
  Vectorized<double> floor() const {
    return vrndmq_f64(values);
  }
  Vectorized<double> neg() const {
    return vnegq_f64(values);
  }
  Vectorized<double> round() const {
    return vrndiq_f64(values);
  }
  Vectorized<double> trunc() const {
    return vrndq_f64(values);
  }
  Vectorized<double> sqrt() const {
    return vsqrtq_f64(values);
  }
  Vectorized<double> reciprocal() const {
    return vdivq_f64(vdupq_n_f64(1.0), values);
  }
  Vectorized<double> rsqrt() const {
    return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
  }
  double reduce_add() const {
    return vaddvq_f64(values);
  }
  double reduce_max() const {
    return vmaxvq_f64(values);
  }
  Vectorized<double> operator==(const Vectorized<double>& other) const {
    return Vectorized<double>(
        vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
  }

  Vectorized<double> operator!=(const Vectorized<double>& other) const {
    float64x2_t r0 = vreinterpretq_f64_u32(
        vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
    return Vectorized<double>(r0);
  }

  Vectorized<double> operator<(const Vectorized<double>& other) const {
    return Vectorized<double>(
        vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
  }

  Vectorized<double> operator<=(const Vectorized<double>& other) const {
    return Vectorized<double>(
        vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
  }

  Vectorized<double> operator>(const Vectorized<double>& other) const {
    return Vectorized<double>(
        vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
  }

  Vectorized<double> operator>=(const Vectorized<double>& other) const {
    return Vectorized<double>(
        vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
  }

  Vectorized<double> eq(const Vectorized<double>& other) const;
  Vectorized<double> ne(const Vectorized<double>& other) const;
  Vectorized<double> gt(const Vectorized<double>& other) const;
  Vectorized<double> ge(const Vectorized<double>& other) const;
  Vectorized<double> lt(const Vectorized<double>& other) const;
  Vectorized<double> le(const Vectorized<double>& other) const;
};

template <>
Vectorized<double> inline operator+(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vaddq_f64(a, b);
}

template <>
Vectorized<double> inline operator-(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vsubq_f64(a, b);
}

template <>
Vectorized<double> inline operator*(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vmulq_f64(a, b);
}

template <>
Vectorized<double> inline operator/(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vdivq_f64(a, b);
}

// frac. Implement this here so we can use subtraction
Vectorized<double> inline Vectorized<double>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vmaxq_f64(a, b);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vminq_f64(a, b);
}

template <>
Vectorized<double> inline clamp(
    const Vectorized<double>& a,
    const Vectorized<double>& min,
    const Vectorized<double>& max) {
  return vminq_f64(max, vmaxq_f64(min, a));
}

template <>
Vectorized<double> inline clamp_max(
    const Vectorized<double>& a,
    const Vectorized<double>& max) {
  return vminq_f64(max, a);
}

template <>
Vectorized<double> inline clamp_min(
    const Vectorized<double>& a,
    const Vectorized<double>& min) {
  return vmaxq_f64(min, a);
}

template <>
Vectorized<double> inline operator&(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vreinterpretq_f64_u64(
      vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}

template <>
Vectorized<double> inline operator|(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vreinterpretq_f64_u64(
      vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}

template <>
Vectorized<double> inline operator^(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  return vreinterpretq_f64_u64(
      veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}

inline Vectorized<double> Vectorized<double>::eq(
    const Vectorized<double>& other) const {
  return (*this == other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ne(
    const Vectorized<double>& other) const {
  return (*this != other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::gt(
    const Vectorized<double>& other) const {
  return (*this > other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ge(
    const Vectorized<double>& other) const {
  return (*this >= other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::lt(
    const Vectorized<double>& other) const {
  return (*this < other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::le(
    const Vectorized<double>& other) const {
  return (*this <= other) & Vectorized<double>(1.0);
}

template <>
Vectorized<double> inline fmadd(
    const Vectorized<double>& a,
    const Vectorized<double>& b,
    const Vectorized<double>& c) {
  return vfmaq_f64(c, a, b);
}

template <>
Vectorized<double> inline fnmadd(
    const Vectorized<double>& a,
    const Vectorized<double>& b,
    const Vectorized<double>& c) {
  return vfmsq_f64(c, a, b);
}

template <>
Vectorized<double> inline fmsub(
    const Vectorized<double>& a,
    const Vectorized<double>& b,
    const Vectorized<double>& c) {
  return vfmaq_f64(vnegq_f64(c), a, b);
}

template <>
Vectorized<double> inline fnmsub(
    const Vectorized<double>& a,
    const Vectorized<double>& b,
    const Vectorized<double>& c) {
  return vfmsq_f64(vnegq_f64(c), a, b);
}

} // namespace CPU_CAPABILITY
} // namespace at::vec
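// A scalar sketch (not part of the header above) of the "(mask) & Vectorized<double>(1.0)"
// idiom used by eq()/ne()/gt()/... : the lane mask produced by vceqq_f64 and friends is
// all-ones or all-zeros, so bitwise-ANDing it with 1.0 leaves exactly 1.0 or 0.0 per lane.
#include <cassert>
#include <cstdint>
#include <cstring>

double mask_and_one(bool lane_matched) {
  const uint64_t mask = lane_matched ? ~uint64_t{0} : uint64_t{0};  // vceqq_f64 lane result
  double one = 1.0;
  uint64_t one_bits;
  std::memcpy(&one_bits, &one, sizeof(one));
  const uint64_t out_bits = mask & one_bits;  // vandq_u64 on the reinterpreted lanes
  double out;
  std::memcpy(&out, &out_bits, sizeof(out));
  return out;
}

int main() {
  assert(mask_and_one(true) == 1.0);
  assert(mask_and_one(false) == 0.0);
}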
@ -1,378 +0,0 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <ATen/cpu/vec/intrinsics.h>
 | 
			
		||||
#include <ATen/cpu/vec/vec_base.h>
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
 | 
			
		||||
namespace at::vec {
 | 
			
		||||
// Note [CPU_CAPABILITY namespace]
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
// This header, and all of its subheaders, will be compiled with
 | 
			
		||||
// different architecture flags for each supported set of vector
 | 
			
		||||
// intrinsics. So we need to make sure they aren't inadvertently
 | 
			
		||||
// linked together. We do this by declaring objects in an `inline
 | 
			
		||||
// namespace` which changes the name mangling, but can still be
 | 
			
		||||
// accessed as `at::vec`.
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
 | 
			
		||||
#define VEC_UINT_NEON_TEMPLATE(vl, bit)                                       \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
 | 
			
		||||
                                                                              \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  class Vectorized<uint##bit##_t> {                                           \
 | 
			
		||||
    using neon_type = uint##bit##x##vl##_t;                                   \
 | 
			
		||||
                                                                              \
 | 
			
		||||
   private:                                                                   \
 | 
			
		||||
    neon_type values;                                                         \
 | 
			
		||||
                                                                              \
 | 
			
		||||
   public:                                                                    \
 | 
			
		||||
    using value_type = uint##bit##_t;                                         \
 | 
			
		||||
    using size_type = int;                                                    \
 | 
			
		||||
    static constexpr size_type size() {                                       \
 | 
			
		||||
      return vl;                                                              \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized() {                                                            \
 | 
			
		||||
      values = vdupq_n_u##bit(0);                                             \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized(neon_type v) : values(v) {}                                    \
 | 
			
		||||
    Vectorized(uint##bit##_t val);                                            \
 | 
			
		||||
    template <                                                                \
 | 
			
		||||
        typename... Args,                                                     \
 | 
			
		||||
        typename = std::enable_if_t<(sizeof...(Args) == size())>>             \
 | 
			
		||||
    Vectorized(Args... vals) {                                                \
 | 
			
		||||
      __at_align__ uint##bit##_t buffer[size()] = {vals...};                  \
 | 
			
		||||
      values = vld1q_u##bit(buffer);                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    operator neon_type() const {                                              \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    static Vectorized<uint##bit##_t> loadu(                                   \
 | 
			
		||||
        const void* ptr,                                                      \
 | 
			
		||||
        uint64_t count = size());                                             \
 | 
			
		||||
    void store(void* ptr, uint64_t count = size()) const;                     \
 | 
			
		||||
    template <uint64_t mask>                                                  \
 | 
			
		||||
    static Vectorized<uint##bit##_t> blend(                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b);                                  \
 | 
			
		||||
    static Vectorized<uint##bit##_t> blendv(                                  \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& mask_) {                             \
 | 
			
		||||
      return vbslq_u##bit(mask_.values, b, a);                                \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    template <typename step_t>                                                \
 | 
			
		||||
    static Vectorized<uint##bit##_t> arange(                                  \
 | 
			
		||||
        value_type base = 0,                                                  \
 | 
			
		||||
        step_t step = static_cast<step_t>(1));                                \
 | 
			
		||||
    static Vectorized<uint##bit##_t> set(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& a,                                   \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& b,                                   \
 | 
			
		||||
        uint64_t count = size());                                             \
 | 
			
		||||
    const uint##bit##_t& operator[](uint idx) const = delete;                 \
 | 
			
		||||
    uint##bit##_t& operator[](uint idx) = delete;                             \
 | 
			
		||||
    Vectorized<uint##bit##_t> abs() const {                                   \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> real() const {                                  \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> imag() const {                                  \
 | 
			
		||||
      return vdupq_n_u##bit(0);                                               \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> conj() const {                                  \
 | 
			
		||||
      return values;                                                          \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> neg() const {                                   \
 | 
			
		||||
      return vreinterpretq_u##bit##_s##bit(                                   \
 | 
			
		||||
          vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values)));               \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    uint##bit##_t reduce_add() const {                                        \
 | 
			
		||||
      return vaddvq_u##bit(values);                                           \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    uint##bit##_t reduce_max() const;                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator==(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vceqq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator!=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator<(                                      \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcltq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator<=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcleq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator>(                                      \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcgtq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> operator>=(                                     \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const {                       \
 | 
			
		||||
      return Vectorized<value_type>(vcgeq_u##bit(values, other.values));      \
 | 
			
		||||
    }                                                                         \
 | 
			
		||||
    Vectorized<uint##bit##_t> eq(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> ne(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> gt(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> ge(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> lt(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
    Vectorized<uint##bit##_t> le(                                             \
 | 
			
		||||
        const Vectorized<uint##bit##_t>& other) const;                        \
 | 
			
		||||
  };                                                                          \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator+(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vaddq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator-(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vsubq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator&(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vandq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator|(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return vorrq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  template <>                                                                 \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline operator^(                                 \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& a,                                     \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& b) {                                   \
 | 
			
		||||
    return veorq_u##bit(a, b);                                                \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this == other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this != other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this > other) & Vectorized<uint##bit##_t>(1);                    \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this >= other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this < other) & Vectorized<uint##bit##_t>(1);                    \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le(             \
 | 
			
		||||
      const Vectorized<uint##bit##_t>& other) const {                         \
 | 
			
		||||
    return (*this <= other) & Vectorized<uint##bit##_t>(1);                   \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
VEC_UINT_NEON_TEMPLATE(16, 8)
 | 
			
		||||
 | 
			
		||||
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
 | 
			
		||||
  return vmaxvq_u8(values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator*(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vmulq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
 | 
			
		||||
  return vmvnq_u8(a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
 | 
			
		||||
    const Vectorized<uint8_t>& other) const {
 | 
			
		||||
  return ~(*this == other);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline minimum(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vminq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline maximum(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  return vmaxq_u8(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <uint64_t mask>
 | 
			
		||||
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  // Build an array of flags: each bit of element is 1 if the corresponding bit
 | 
			
		||||
  // in 'mask' is set, 0 otherwise.
 | 
			
		||||
  uint8x16_t maskArray = {
 | 
			
		||||
      (mask & 1LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 2LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 4LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 8LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 16LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 32LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 64LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 128LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 256LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 512LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 1024LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 2048LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 4096LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 8192LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 16384LL) ? 0xFF : 0,
 | 
			
		||||
      (mask & 32768LL) ? 0xFF : 0};
 | 
			
		||||
  // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
  return vbslq_u8(maskArray, b.values, a.values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define VEC_UINT_NEON_OPS(vl, bit)                                             \
 | 
			
		||||
  inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) {            \
 | 
			
		||||
    values = vdupq_n_u##bit(val);                                              \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu(           \
 | 
			
		||||
      const void* ptr, uint64_t count) {                                       \
 | 
			
		||||
    if (count == size()) {                                                     \
 | 
			
		||||
      return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr));        \
 | 
			
		||||
    } else {                                                                   \
 | 
			
		||||
      __at_align__ uint##bit##_t tmp_values[size()];                           \
 | 
			
		||||
      for (const auto i : c10::irange(size())) {                               \
 | 
			
		||||
        tmp_values[i] = 0;                                                     \
 | 
			
		||||
      }                                                                        \
 | 
			
		||||
      std::memcpy(                                                             \
 | 
			
		||||
          tmp_values,                                                          \
 | 
			
		||||
          reinterpret_cast<const uint##bit##_t*>(ptr),                         \
 | 
			
		||||
          count * sizeof(uint##bit##_t));                                      \
 | 
			
		||||
      return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
 | 
			
		||||
    }                                                                          \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count)      \
 | 
			
		||||
      const {                                                                  \
 | 
			
		||||
    if (count == size()) {                                                     \
 | 
			
		||||
      vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values);             \
 | 
			
		||||
    } else {                                                                   \
 | 
			
		||||
      uint##bit##_t tmp_values[size()];                                        \
 | 
			
		||||
      vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values);      \
 | 
			
		||||
      std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t));             \
 | 
			
		||||
    }                                                                          \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
VEC_UINT_NEON_OPS(16, 8)
 | 
			
		||||
 | 
			
		||||
template <typename step_t>
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
 | 
			
		||||
    uint8_t base,
 | 
			
		||||
    step_t step) {
 | 
			
		||||
  const Vectorized<uint8_t> base_vec(base);
 | 
			
		||||
  const Vectorized<uint8_t> step_vec(step);
 | 
			
		||||
  const uint8x16_t step_sizes = {
 | 
			
		||||
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
 | 
			
		||||
  return vmlaq_u8(base_vec, step_sizes, step_vec);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator>>(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t x = a;
 | 
			
		||||
  uint8x16_t bound = vdupq_n_u8(8);
 | 
			
		||||
  uint8x16_t z = vminq_u8(b, bound);
 | 
			
		||||
  return x >> z;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator<<(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t bound = vdupq_n_u8(8);
 | 
			
		||||
  uint8x16_t z = vminq_u8(b, bound);
 | 
			
		||||
  return vshlq_u8(a, vreinterpretq_s8_u8(z));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b,
 | 
			
		||||
    uint64_t count) {
 | 
			
		||||
  if (count == 0) {
 | 
			
		||||
    return a;
 | 
			
		||||
  } else if (count >= 16) {
 | 
			
		||||
    return b;
 | 
			
		||||
  } else {
 | 
			
		||||
    // Build an array of flags: each bit of element is 1 if the corresponding
 | 
			
		||||
    // bit in 'mask' is set, 0 otherwise.
 | 
			
		||||
    uint8x16_t maskArray = {
 | 
			
		||||
        static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
 | 
			
		||||
        static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
 | 
			
		||||
        0};
 | 
			
		||||
 | 
			
		||||
    // Use BSL to select elements from b where the mask is 1, else from a
 | 
			
		||||
    return vbslq_u8(maskArray, b.values, a.values);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline operator/(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& b) {
 | 
			
		||||
  uint8x16_t x = a;
 | 
			
		||||
  uint8x16_t y = b;
 | 
			
		||||
  return x / y;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& min,
 | 
			
		||||
    const Vectorized<uint8_t>& max) {
 | 
			
		||||
  return minimum(max, maximum(min, a));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp_max(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& max) {
 | 
			
		||||
  return minimum(max, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<uint8_t> inline clamp_min(
 | 
			
		||||
    const Vectorized<uint8_t>& a,
 | 
			
		||||
    const Vectorized<uint8_t>& min) {
 | 
			
		||||
  return maximum(min, a);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace CPU_CAPABILITY
 | 
			
		||||
} // namespace at::vec
 | 
			
		||||
@@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(

std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
    at::vec::Vectorized<uint8_t> src) {
  auto u8x8 = vget_low_u8(src);
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
  auto u16x8 = vmovl_u8(u8x8);
  auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
@@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(

Vectorized<float> inline convert_int8_half_register_to_float(
    at::vec::Vectorized<uint8_t> src) {
  auto u8x8 = vget_low_u8(src);
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
  auto u16x8 = vmovl_u8(u8x8);
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));

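// Scalar sketch of the zero-extension chain used by both conversion helpers in the hunks
// above (vmovl_u8 then vmovl_u16); the final float conversion is assumed to follow in the
// elided remainder of the hunk, so the last step here is an assumption.
#include <array>
#include <cstddef>
#include <cstdint>

std::array<float, 8> widen_u8_to_f32(const std::array<uint8_t, 8>& src) {
  std::array<float, 8> out{};
  for (std::size_t i = 0; i < src.size(); ++i) {
    const uint16_t u16 = src[i];        // vmovl_u8: zero-extend u8 -> u16
    const uint32_t u32 = u16;           // vmovl_u16: zero-extend u16 -> u32
    out[i] = static_cast<float>(u32);   // assumed float conversion of the widened lanes
  }
  return out;
}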
@@ -1,192 +0,0 @@
#include <ATen/cuda/CUDAGreenContext.h>

namespace at::cuda {
  GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
    int driver_version;
    C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
    TORCH_CHECK(
        driver_version >= 12080, "cuda driver too old to use green context!");
    CUcontext pctx = nullptr;
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
    if (C10_UNLIKELY(!pctx)) {
      TORCH_WARN(
          "Attempted to create a green context but"
          " there was no primary context! Creating a primary context...");

      cudaFree(0);
    }

    CUdevice device;
    device_id_ = device_id;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));

    // Get device resources
    CUdevResource device_resource;
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
        device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));

    // Split resources
    std::vector<CUdevResource> result(1);
    auto result_data = result.data();
    unsigned int nb_groups = 1;
    CUdevResource remaining;

    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
            result_data,
            &nb_groups,
            &device_resource,
            &remaining,
            0, // default flags
            num_sms));

    TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");

    // Generate resource descriptor
    CUdevResourceDesc desc;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
            &desc, result_data, 1));

    // Create green context
    // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
        &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));

    // Convert to regular context
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
    TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  std::unique_ptr<GreenContext> GreenContext::create(
      uint32_t num_sms,
      std::optional<uint32_t> device_id) {
#if CUDA_HAS_GREEN_CONTEXT
    if (!device_id.has_value()) {
      device_id = at::cuda::current_device();
    }
    return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Implement move operations
  GreenContext::GreenContext(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
    device_id_ = std::exchange(other.device_id_, -1);
    green_ctx_ = std::exchange(other.green_ctx_, nullptr);
    context_ = std::exchange(other.context_, nullptr);
    parent_stream_ = std::exchange(other.parent_stream_, nullptr);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
    if (this != &other) {
      // Clean up current resources
      if (green_ctx_) {
        CUcontext current = nullptr;
        C10_CUDA_DRIVER_CHECK(
            c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
        if (current == context_) {
          TORCH_CHECK(
              false,
              "attempting to overwrite current green ctx "
              "when it is active!");
        }
        C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
      }

      // Take ownership of other's resources
      device_id_ = std::exchange(other.device_id_, -1);
      green_ctx_ = std::exchange(other.green_ctx_, nullptr);
      context_ = std::exchange(other.context_, nullptr);
      parent_stream_ = std::exchange(other.parent_stream_, nullptr);
    }
    return *this;
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  GreenContext::~GreenContext() noexcept{
#if CUDA_HAS_GREEN_CONTEXT
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Get the underlying CUDA context
  CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
    return context_;
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
  CUgreenCtx GreenContext::getGreenContext() const {
    return green_ctx_;
  }
#endif

  // Make this context current
  void GreenContext::setContext() {
#if CUDA_HAS_GREEN_CONTEXT
    auto current_stream = c10::cuda::getCurrentCUDAStream();
    parent_stream_ = current_stream.stream();

    at::cuda::CUDAEvent ev;
    ev.record(current_stream);

    CUcontext current = nullptr;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
    if (!current) {
      C10_CUDA_DRIVER_CHECK(
          c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
    } else {
      C10_CUDA_DRIVER_CHECK(
          c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
    }
    // currently hardcodes the new green context to use the default stream
    // TODO(eqy): consider creating a new stream if e.g., it allows interop
    // with CUDA Graph captures etc.
    auto default_stream = c10::cuda::getDefaultCUDAStream();
    ev.block(default_stream);
    c10::cuda::setCurrentCUDAStream(default_stream);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  void GreenContext::popContext() {
#if CUDA_HAS_GREEN_CONTEXT
    // see above note about stream being hardcoded to the default stream
    at::cuda::CUDAEvent ev;
    ev.record(c10::cuda::getCurrentCUDAStream());
    CUcontext popped;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
    TORCH_INTERNAL_ASSERT(
        popped == context_, "expected popped context to be the current ctx");
    ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }
} // namespace at::cuda
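// A hypothetical usage sketch of the GreenContext class removed above (not taken from the
// diff itself): partition SMs, run work inside the green context, then restore the parent
// context. Assumes a CUDA 12.8+ driver; the function name is illustrative.
#include <ATen/cuda/CUDAGreenContext.h>
#include <optional>

void run_in_sm_partition() {
  // Reserve roughly 16 SMs on the current device (device_id left unset).
  auto ctx = at::cuda::GreenContext::create(/*num_sms=*/16, /*device_id=*/std::nullopt);
  ctx->setContext();  // records an event on the parent stream and switches context/stream
  // ... launch kernels that should be confined to the SM partition ...
  ctx->popContext();  // pops the context and re-synchronizes with the parent stream
}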
@@ -1,53 +0,0 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>

#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif

namespace at::cuda {

class TORCH_CUDA_CPP_API GreenContext {
 public:
  GreenContext(uint32_t device_id, uint32_t num_sms);

  static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);

  // Delete copy constructor and assignment
  GreenContext(const GreenContext&) = delete;
  GreenContext& operator=(const GreenContext&) = delete;

  // Implement move operations
  GreenContext(GreenContext&& other) noexcept;
  GreenContext& operator=(GreenContext&& other) noexcept;
  ~GreenContext() noexcept;

  // Get the underlying CUDA context
  CUcontext getContext() const;

  // Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
  CUgreenCtx getGreenContext() const;
#endif

  // Make this context current
  void setContext();

  void popContext();

 private:
#if CUDA_HAS_GREEN_CONTEXT
  int32_t device_id_ = -1;
  CUgreenCtx green_ctx_ = nullptr;
  CUcontext context_ = nullptr;
  cudaStream_t parent_stream_ = nullptr;
#endif
};
} // namespace at::cuda
@@ -183,6 +183,11 @@ struct CUDACachingHostAllocatorImpl
    return true;
  }

  bool pinned_use_background_threads() override {
    return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
        pinned_use_background_threads();
  }

  EventPool::Event create_event_internal(DeviceIndex idx) {
    // Leak the event pool to avoid shutdown issue.
    static auto* event_pool = new EventPool();

@@ -70,7 +70,11 @@
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
#endif

#if defined(USE_ROCM)
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)

#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif

// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16

@@ -92,6 +96,10 @@ template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
       ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};

#if !defined(USE_ROCM)
} // namespace at_cuda_detail
#endif

#endif

#if !defined(USE_ROCM)
@@ -113,7 +121,7 @@ struct cuda_type<c10::Half> {
  using type = __half;
};

#if !defined(USE_ROCM)
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()

template<>
struct cuda_type<c10::BFloat16> {
@@ -169,6 +177,7 @@ inline void segmented_sort_pairs(
  }
}

#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
  KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
@@ -184,6 +193,7 @@ inline void unique_by_key(
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
    keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif

namespace impl {

@@ -195,6 +205,36 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera
  *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}

#if !CUB_SUPPORTS_FUTURE_VALUE()
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
  using iterator_category = std::random_access_iterator_tag;
  using difference_type   = std::ptrdiff_t;
  using value_type        = ValueT;
  using pointer           = ValueT*;
  using reference         = ValueT&;

  InputIteratorT iter;
  ValueT *first;
  difference_type offset = 0;

  __device__ ValueT operator[](difference_type i) {
    i +=  offset;
    if (i == 0) {
      return *first;
    } else {
      return ValueT(iter[i - 1]);
    }
  }
  __device__ chained_iterator operator+(difference_type i) {
    return chained_iterator{iter, first, i};
  }
  __device__ ValueT operator*() {
    return (*this)[0];
  }
};
#endif

// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
@@ -239,6 +279,25 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
    using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
    using tuple = typename ArgIndexInputIterator::value_type;
    auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  {
      if (x.key == 0) {
        return *first_elem_ptr;
      } else {
        return x.value;
      }
    };
    auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
      ArgIndexInputIterator(input + i), input_iter_transform);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
#else
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + i + 1,
        output + i,
@@ -246,6 +305,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
        ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
        size_cub,
        at::cuda::getCurrentCUDAStream());
#endif
  }
#endif
}
@@ -497,6 +557,16 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
    auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
      input + i, first_elem_ptr};
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
#else
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + i,
        output + i,
@@ -504,10 +574,12 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
        ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
        size_cub,
        at::cuda::getCurrentCUDAStream());
#endif
  }
#endif
}

#if CUB_SUPPORTS_SCAN_BY_KEY()

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
@@ -535,6 +607,7 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT
#endif
}

#endif

template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,

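// Host-side sketch of the chunking strategy in the scan hunks above: inputs longer than
// max_cub_size are scanned in pieces, and each piece is seeded with the previous piece's
// running total (the role played on device by cub::FutureValue or the chained_iterator
// fallback). The helper name and chunk size are illustrative.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> chunked_inclusive_sum(const std::vector<int64_t>& in, std::size_t chunk) {
  std::vector<int64_t> out(in.size());
  int64_t carry = 0;  // analogous to *first_elem_ptr carried into the next chunk
  for (std::size_t i = 0; i < in.size(); i += chunk) {
    const std::size_t end = std::min(in.size(), i + chunk);
    int64_t running = carry;
    for (std::size_t j = i; j < end; ++j) {
      running += in[j];
      out[j] = running;
    }
    carry = running;
  }
  return out;
}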
@ -10,6 +10,14 @@
#define CUB_VERSION 200001
#endif

// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
// https://github.com/NVIDIA/cub/pull/306
#if CUB_VERSION >= 101300
#define CUB_SUPPORTS_NV_BFLOAT16() true
#else
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif

// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// https://github.com/NVIDIA/cub/pull/326
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
@ -20,6 +28,30 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif

// cub support for UniqueByKey is added to cub 1.16 in:
// https://github.com/NVIDIA/cub/pull/405
#if CUB_VERSION >= 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#endif

// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif

// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_FUTURE_VALUE() true
#else
#define CUB_SUPPORTS_FUTURE_VALUE() false
#endif

// There were many bc-breaking changes in major version release of CCCL v3.0.0
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
#if CUB_VERSION >= 200800

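(An added orientation note, not part of the diff.) CUB encodes its version as a single integer, CUB_VERSION = CUB_MAJOR_VERSION * 100000 + CUB_MINOR_VERSION * 100 + CUB_SUBMINOR_VERSION, so the thresholds above read as 1.13.0 (101300), 1.15.0 (101500), 1.16.0 (101600) and 2.8.0 (200800). A minimal sketch of that decoding, assuming the stock CUB encoding:

    // Sketch only: the integer thresholds compared against CUB_VERSION above.
    static_assert(1 * 100000 + 15 * 100 + 0 == 101500, "CUB 1.15.0");
    static_assert(2 * 100000 +  8 * 100 + 0 == 200800, "CCCL/CUB 2.8.0");
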
@ -1,23 +0,0 @@
#include <ATen/detail/XLAHooksInterface.h>

namespace at {
namespace detail {

const XLAHooksInterface& getXLAHooks() {
  auto create_impl = [] {
    // Create XLA hooks using the registry
    auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
    if (hooks) {
      return hooks;
    }
    // If hooks creation fails, fall back to default implementation
    return std::make_unique<XLAHooksInterface>();
  };
  static auto hooks = create_impl();
  return *hooks;
}
} // namespace detail

C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)

} // namespace at
@ -1,79 +0,0 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <c10/core/Device.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
#include <c10/util/Registry.h>
 | 
			
		||||
 | 
			
		||||
#include <ATen/detail/AcceleratorHooksInterface.h>
 | 
			
		||||
 | 
			
		||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
 | 
			
		||||
 | 
			
		||||
namespace at {
 | 
			
		||||
 | 
			
		||||
constexpr const char* XLA_HELP =
 | 
			
		||||
  "This error has occurred because you are trying "
 | 
			
		||||
  "to use some XLA functionality, but the XLA library has not been "
 | 
			
		||||
  "loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
 | 
			
		||||
 | 
			
		||||
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
 | 
			
		||||
  ~XLAHooksInterface() override = default;
 | 
			
		||||
 | 
			
		||||
  void init() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool hasXLA() const {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual std::string showConfig() const {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        false,
 | 
			
		||||
        "Cannot query detailed XLA version without torch_xla library. ",
 | 
			
		||||
        XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const Generator& getDefaultGenerator(
 | 
			
		||||
      [[maybe_unused]] DeviceIndex device_index = -1) const override {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Generator getNewGenerator(
 | 
			
		||||
      [[maybe_unused]] DeviceIndex device_index = -1) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual DeviceIndex getCurrentDevice() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Device getDeviceFromPtr(void* /*data*/) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Allocator* getPinnedMemoryAllocator() const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool isPinnedPtr(const void* data) const override {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool hasPrimaryContext(DeviceIndex device_index) const override {
 | 
			
		||||
    TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TORCH_API XLAHooksArgs {};
 | 
			
		||||
 | 
			
		||||
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
 | 
			
		||||
#define REGISTER_XLA_HOOKS(clsname) \
 | 
			
		||||
  C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
TORCH_API const XLAHooksInterface& getXLAHooks();
 | 
			
		||||
} // namespace detail
 | 
			
		||||
} // namespace at
 | 
			
		||||
C10_DIAGNOSTIC_POP()
 | 
			
		||||
@ -160,10 +160,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::CUDA,
  DispatchKey::CPU,
  DispatchKey::PrivateUse1,
  DispatchKey::SparseCPU,
  DispatchKey::SparseCUDA,
  DispatchKey::SparseCsrCPU,
  DispatchKey::SparseCsrCUDA,
});

inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {

@ -658,7 +658,6 @@ static void check_shape_forward(const at::Tensor& input,
  TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
  TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
  TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
  TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups);

  TORCH_CHECK(weight_dim == k,
           "Expected ", weight_dim, "-dimensional input for ", weight_dim,

@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
    try {
      mkldnn_matmul_i8i8i32(self, mat2, result);
      dispatched = true;
    } catch ([[maybe_unused]] const std::exception& e) {
    } catch (const std::exception& e) {
      TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
    }
  }

@ -11,8 +11,6 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
              "pixel_shuffle expects a positive upscale_factor, but got ",
              upscale_factor);
  int64_t c = self.size(-3);
  TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
        "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
  int64_t upscale_factor_squared = upscale_factor * upscale_factor;
  TORCH_CHECK(c % upscale_factor_squared == 0,
              "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "

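(Added illustration, not code from the diff.) The TORCH_CHECK_VALUE above guards the multiplication before it happens: a positive factor f can be squared safely only if f <= max / f. A self-contained sketch of the same guard, with a made-up helper name:

    #include <cstdint>
    #include <limits>
    #include <stdexcept>

    // Sketch only: reject factors whose square would overflow int64_t.
    inline int64_t checked_square(int64_t factor) {
      if (factor > 0 && factor > std::numeric_limits<int64_t>::max() / factor) {
        throw std::overflow_error("factor^2 would overflow");
      }
      return factor * factor;
    }
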
@ -259,20 +259,11 @@ inline void winograd_f2k3_input_transform_inplace__rvv(
 | 
			
		||||
  const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
 | 
			
		||||
  const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
 | 
			
		||||
  const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);
 | 
			
		||||
  /* GCC 14.2 (RISC-V RVV) ICE workaround:
 | 
			
		||||
   * Avoid single-statement read-modify-write on MEM_REF like:
 | 
			
		||||
   *   *input_tile_val =
 | 
			
		||||
   *     __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
 | 
			
		||||
   * This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin)
 | 
			
		||||
   * with -march=rv64gcv. Use a temporary then write back.
 | 
			
		||||
   * Do NOT refactor into the single-statement form. Clang is unaffected.
 | 
			
		||||
   */
 | 
			
		||||
  vfloat32m1x4_t tmp_input_tile_val = *input_tile_val;
 | 
			
		||||
  tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0);
 | 
			
		||||
  tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1);
 | 
			
		||||
  tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2);
 | 
			
		||||
  tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3);
 | 
			
		||||
  *input_tile_val = tmp_input_tile_val;
 | 
			
		||||
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
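(Added illustration, not part of the diff.) The workaround above amounts to never updating the vfloat32m1x4_t object through the pointer in one read-modify-write statement. A minimal sketch of the pattern, assuming <riscv_vector.h> is available as in the surrounding file and using a made-up helper name:

    // Sketch only: load the tuple once, modify a local copy, store back once.
    static inline void set_lane0_workaround(vfloat32m1x4_t* dst, vfloat32m1_t v) {
      vfloat32m1x4_t tmp = *dst;                       // single load
      tmp = __riscv_vset_v_f32m1_f32m1x4(tmp, 0, v);   // update the local copy
      *dst = tmp;                                      // single store back
    }
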
inline void winograd_f2k3_output_transform_inplace__rvv(
 | 
			
		||||
@ -286,15 +277,9 @@ inline void winograd_f2k3_output_transform_inplace__rvv(
 | 
			
		||||
  const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
 | 
			
		||||
  const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
 | 
			
		||||
  const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);
 | 
			
		||||
  /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
 | 
			
		||||
   * Keep the temporary + write-back pattern to avoid ICE.
 | 
			
		||||
   * Do NOT rewrite into:
 | 
			
		||||
   *   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
 | 
			
		||||
   */
 | 
			
		||||
  vfloat32m1x4_t tmp_output_tile_val = *input_tile_val;
 | 
			
		||||
  tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0);
 | 
			
		||||
  tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1);
 | 
			
		||||
  *input_tile_val = tmp_output_tile_val;
 | 
			
		||||
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
 | 
			
		||||
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline vfloat32m1_t
 | 
			
		||||
@ -315,17 +300,11 @@ inline void winograd_f2k3_kernel_transform__rvv(
 | 
			
		||||
  const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
 | 
			
		||||
  const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
 | 
			
		||||
  vfloat32m1_t half_g0_plus_g2 =  __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);
 | 
			
		||||
  /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
 | 
			
		||||
   * Keep the temporary + write-back pattern to avoid ICE.
 | 
			
		||||
   * Do NOT rewrite into:
 | 
			
		||||
   *   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val);
 | 
			
		||||
   */
 | 
			
		||||
  vfloat32m1x4_t tmp_transform = *transform;
 | 
			
		||||
  tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0);
 | 
			
		||||
  tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
 | 
			
		||||
  tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
 | 
			
		||||
  tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2);
 | 
			
		||||
  *transform = tmp_transform;
 | 
			
		||||
 | 
			
		||||
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
 | 
			
		||||
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
 | 
			
		||||
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
 | 
			
		||||
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {
 | 
			
		||||
 | 
			
		||||
@ -272,110 +272,28 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
 * Checks whether DISABLE_ADDMM_CUDA_LT is set.
 | 
			
		||||
 * Additionally, for ROCM we test whether the architecture supports the Lt.
 | 
			
		||||
 */
 | 
			
		||||
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
 | 
			
		||||
  // When hipBLASLt is not supported on the architecture, return true
 | 
			
		||||
  #ifdef USE_ROCM
 | 
			
		||||
  static const std::vector<std::string> archs = {
 | 
			
		||||
static bool getDisableAddmmCudaLt() {
 | 
			
		||||
    static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
 | 
			
		||||
    if (env_value == "1") {
 | 
			
		||||
      return true;
 | 
			
		||||
    }
 | 
			
		||||
    return false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
static bool isSupportedHipLtROCmArch(int index) {
 | 
			
		||||
    static const std::vector<std::string> archs = {
 | 
			
		||||
        "gfx90a", "gfx942",
 | 
			
		||||
    #if ROCM_VERSION >= 60300
 | 
			
		||||
#if ROCM_VERSION >= 60300
 | 
			
		||||
        "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
 | 
			
		||||
    #endif
 | 
			
		||||
    #if ROCM_VERSION >= 70000
 | 
			
		||||
#endif
 | 
			
		||||
#if ROCM_VERSION >= 70000
 | 
			
		||||
        "gfx950", "gfx1150", "gfx1151"
 | 
			
		||||
    #endif
 | 
			
		||||
  };
 | 
			
		||||
  const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
 | 
			
		||||
  if (!is_hipblas_lt_arch_supported) {
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
  #endif
 | 
			
		||||
 | 
			
		||||
  // Check whether it is disabled in the env
 | 
			
		||||
  static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
 | 
			
		||||
  if (is_addmm_cuda_lt_disabled == "1") {
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
 * Check whether for the given input we want to enable the Lt interface
 | 
			
		||||
 */
 | 
			
		||||
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
 | 
			
		||||
  // Implies a 2D bias, which we currently do not send through Lt.
 | 
			
		||||
  // TODO: this check is done pre col-major input preparation,
 | 
			
		||||
  // so, this condition can be relaxed in cases when a col-major
 | 
			
		||||
  // copy of result is needed.
 | 
			
		||||
  if (result.is_same(self)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  #if defined(USE_ROCM) && ROCM_VERSION == 60400
 | 
			
		||||
  // hipblaslt TT fp32 regression on ROCm 6.4, cannot use
 | 
			
		||||
  const auto args = cublasCommonArgs(mat1, mat2, result);
 | 
			
		||||
  if (args.transa == 't' && args.transb == 't') {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  #endif
 | 
			
		||||
 | 
			
		||||
  const auto mat1_sizes = mat1.sizes();
 | 
			
		||||
  const auto mat2_sizes = mat2.sizes();
 | 
			
		||||
  #if defined(CUDA_VERSION) || defined(USE_ROCM)
 | 
			
		||||
  const auto scalar_type = mat1.scalar_type();
 | 
			
		||||
  return (beta.toComplexDouble() == 1.0
 | 
			
		||||
    // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
 | 
			
		||||
    // is to use lt interface only when self is bias.
 | 
			
		||||
    && self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
 | 
			
		||||
    && result.dim() == 2 && result.is_contiguous()
 | 
			
		||||
    && ( // some dtype restrictions
 | 
			
		||||
      #ifndef USE_ROCM
 | 
			
		||||
      scalar_type == at::ScalarType::Double ||
 | 
			
		||||
      #endif
 | 
			
		||||
      scalar_type == at::ScalarType::Float ||
 | 
			
		||||
      scalar_type == at::ScalarType::Half ||
 | 
			
		||||
      scalar_type == at::ScalarType::BFloat16
 | 
			
		||||
    )
 | 
			
		||||
    && ( // some shape/stride restrictions
 | 
			
		||||
      // Strangely, if mat2 has only 1 row or column, we get
 | 
			
		||||
      // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
 | 
			
		||||
      // NOTE: extension to mat1 because mat1/mat2 can be swapped based off
 | 
			
		||||
      // their row-/col-majorness.
 | 
			
		||||
      mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
 | 
			
		||||
      mat2_sizes[0] > 1 && mat2_sizes[1] > 1
 | 
			
		||||
      // The last condition is to skip 16b transA and non-trans-B having
 | 
			
		||||
      // leading dim >> rows when they are sliced from a large tensor
 | 
			
		||||
      // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
 | 
			
		||||
      #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
 | 
			
		||||
      // Related to avoiding the leading stride >> leading dim problematic case
 | 
			
		||||
      // with 16b dtypes described above. For such dtypes we only allow inputs
 | 
			
		||||
      // which are either row- or col-major (i.e. non-overlapping, compact memory layout).
 | 
			
		||||
      // In that case the leading stride will be equal to the outer dim len.
 | 
			
		||||
      // Why do we catch this case here? The following `prepare_matrix_for_cublas` method
 | 
			
		||||
      // does not modify inputs as long as there is a stride of length 1
 | 
			
		||||
      // and the leading stride is at least max(1, other dim length), so we might
 | 
			
		||||
      // end up with contiguous cols but not rows (i.e. holes between different rows)
 | 
			
		||||
      // and vice versa.
 | 
			
		||||
      && mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
 | 
			
		||||
      mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32
 | 
			
		||||
      && (
 | 
			
		||||
        // filter by dtype
 | 
			
		||||
        (scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
 | 
			
		||||
        // check mat1/mat2 is row-/col-major
 | 
			
		||||
        (mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
 | 
			
		||||
      )
 | 
			
		||||
      #endif
 | 
			
		||||
    )
 | 
			
		||||
  );
 | 
			
		||||
  #endif
 | 
			
		||||
 | 
			
		||||
  // no compliance by default
 | 
			
		||||
  return false;
 | 
			
		||||
#endif
 | 
			
		||||
    };
 | 
			
		||||
    return at::detail::getCUDAHooks().isGPUArch(archs, index);
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
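(Added orientation note, not the literal control flow.) Read together with isInputCompliesAddmmCudaLt, the gating above reduces to roughly the following decision; the variable names here are made up for illustration:

    // Sketch only: when addmm_out_cuda_impl attempts the cublasLt gemm_and_bias path.
    bool use_lt_path =
        !globally_disabled          // DISABLE_ADDMM_CUDA_LT=1 or an unsupported ROCm arch
        && !lt_retry_override       // a failed Lt attempt recurses with the path forced off
        && !result.is_same(self)    // a 2-D (in-place) bias is not sent through Lt
        && beta_is_one              // the Lt epilogue adds the bias, so beta must be 1
        && bias_is_1d_contiguous    // self is a contiguous vector of length mat2.size(1)
        && result_is_2d_contiguous
        && dtype_and_shape_ok;      // dtype whitelist plus the size/stride limits above
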
template <typename scalar_t>
 | 
			
		||||
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
 | 
			
		||||
@ -417,70 +335,7 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, typename res_scalar_t = scalar_t>
 | 
			
		||||
bool launchGemmAndBiasCublasLt(
 | 
			
		||||
    // args contains result which is modified
 | 
			
		||||
    cublasCommonArgs& args,
 | 
			
		||||
    const Tensor& self,
 | 
			
		||||
    const Scalar& alpha,
 | 
			
		||||
    Activation activation = Activation::None
 | 
			
		||||
) {
 | 
			
		||||
  const auto* self_ptr = self.const_data_ptr<scalar_t>();
 | 
			
		||||
 | 
			
		||||
  const auto tuning_ctx = at::cuda::tunable::getTuningContext();
 | 
			
		||||
  if (tuning_ctx->IsTunableOpEnabled()) {
 | 
			
		||||
    // TODO: maybe also return some success state?
 | 
			
		||||
    launchTunableGemmAndBias<scalar_t>(
 | 
			
		||||
      args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
 | 
			
		||||
    );
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
 | 
			
		||||
    args.transa == 't',
 | 
			
		||||
    args.transb == 't',
 | 
			
		||||
    args.m,
 | 
			
		||||
    args.n,
 | 
			
		||||
    args.k,
 | 
			
		||||
    alpha.to<at::opmath_type<scalar_t>>(),
 | 
			
		||||
    args.mata->const_data_ptr<scalar_t>(),
 | 
			
		||||
    args.lda,
 | 
			
		||||
    args.matb->const_data_ptr<scalar_t>(),
 | 
			
		||||
    args.ldb,
 | 
			
		||||
    self_ptr,
 | 
			
		||||
    args.result->data_ptr<res_scalar_t>(),
 | 
			
		||||
    args.result_ld,
 | 
			
		||||
    activation_to_gemm_and_blas_arg(activation)
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, typename res_scalar_t = scalar_t>
 | 
			
		||||
bool launchGemmCublas(
 | 
			
		||||
    // args contains result which is modified
 | 
			
		||||
    cublasCommonArgs& args,
 | 
			
		||||
    const Scalar& alpha,
 | 
			
		||||
    const Scalar& beta
 | 
			
		||||
) {
 | 
			
		||||
  at::cuda::blas::gemm<scalar_t, res_scalar_t>(
 | 
			
		||||
    args.transa,
 | 
			
		||||
    args.transb,
 | 
			
		||||
    args.m,
 | 
			
		||||
    args.n,
 | 
			
		||||
    args.k,
 | 
			
		||||
    alpha.to<at::opmath_type<scalar_t>>(),
 | 
			
		||||
    args.mata->const_data_ptr<scalar_t>(),
 | 
			
		||||
    args.lda,
 | 
			
		||||
    args.matb->const_data_ptr<scalar_t>(),
 | 
			
		||||
    args.ldb,
 | 
			
		||||
    beta.to<at::opmath_type<scalar_t>>(),
 | 
			
		||||
    args.result->data_ptr<res_scalar_t>(),
 | 
			
		||||
    args.result_ld
 | 
			
		||||
  );
 | 
			
		||||
  return true; // success!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
 | 
			
		||||
  // Shape checks {
 | 
			
		||||
  // Make sure to keep addmm_cuda below in sync with this code; it
 | 
			
		||||
  // preflights a check to try to avoid actually needing to call
 | 
			
		||||
  // expand().
 | 
			
		||||
@ -490,62 +345,105 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 | 
			
		||||
    "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
 | 
			
		||||
  )
 | 
			
		||||
 | 
			
		||||
  if (result.is_same(self)) {
 | 
			
		||||
    TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
 | 
			
		||||
    TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
 | 
			
		||||
    TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
 | 
			
		||||
  }
 | 
			
		||||
  // } Shape checks
 | 
			
		||||
 | 
			
		||||
  // NOLINTNEXTLINE(*c-array*)
 | 
			
		||||
  TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
 | 
			
		||||
  checkAllSameGPU(__func__, targs);
 | 
			
		||||
 | 
			
		||||
  // Handle whether to use the Lt interface {
 | 
			
		||||
  static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
 | 
			
		||||
  IntArrayRef mat1_sizes = mat1.sizes();
 | 
			
		||||
  IntArrayRef mat2_sizes = mat2.sizes();
 | 
			
		||||
  IntArrayRef self__sizes;
 | 
			
		||||
  bool useLtInterface = false;
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
  // When hipBLASLt is not supported on the architecture,
 | 
			
		||||
  // disable_addmm_cuda_lt will always be set to true
 | 
			
		||||
  static bool disable_addmm_cuda_lt =
 | 
			
		||||
    !isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
 | 
			
		||||
#else
 | 
			
		||||
  static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
 | 
			
		||||
#endif
 | 
			
		||||
  // if lt path fails, we recurse back into this function here and force the lt path to off
 | 
			
		||||
  // we cannot update the variable disable_addmm_cuda_lt from above since it is static and would be permanent
 | 
			
		||||
  bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
 | 
			
		||||
  #ifdef USE_ROCM
 | 
			
		||||
  // Conditioned on the device index, which is not persistent
 | 
			
		||||
  disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
 | 
			
		||||
  #endif
 | 
			
		||||
  // Condition on the input
 | 
			
		||||
  disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
 | 
			
		||||
  // }
 | 
			
		||||
 | 
			
		||||
  bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
 | 
			
		||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
 | 
			
		||||
  // hipblaslt TT fp32 regression on ROCm 6.4, cannot use
 | 
			
		||||
  cublasCommonArgs _args(mat1, mat2, result);
 | 
			
		||||
  if (_args.transa == 't' && _args.transb == 't') {
 | 
			
		||||
    disable_addmm_cuda_lt_final = true;
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
  at::ScalarType scalar_type = mat1.scalar_type();
 | 
			
		||||
  bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
 | 
			
		||||
  c10::MaybeOwned<Tensor> self_;
 | 
			
		||||
  if (&result != &self) {
 | 
			
		||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
 | 
			
		||||
    // Strangely, if mat2 has only 1 row or column, we get
 | 
			
		||||
    // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
 | 
			
		||||
    // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
 | 
			
		||||
    // is to use lt interface only when self is bias.
 | 
			
		||||
    // for cuda 11.4, cublasLtMatmul is activated
 | 
			
		||||
    // the last two conditions are to skip 16b transA and non-trans-B having
 | 
			
		||||
    // leading dim >> rows when they are sliced from a large tensor
 | 
			
		||||
    // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
 | 
			
		||||
    if (!disable_addmm_cuda_lt_final) {
 | 
			
		||||
      useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
 | 
			
		||||
          result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
 | 
			
		||||
          self.is_contiguous() && result.is_contiguous() &&
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
          (scalar_type == at::ScalarType::Float ||
 | 
			
		||||
           scalar_type == at::ScalarType::Half ||
 | 
			
		||||
           scalar_type == at::ScalarType::BFloat16) &&
 | 
			
		||||
#else
 | 
			
		||||
          (scalar_type == at::ScalarType::Double ||
 | 
			
		||||
           scalar_type == at::ScalarType::Float ||
 | 
			
		||||
           scalar_type == at::ScalarType::Half ||
 | 
			
		||||
           scalar_type == at::ScalarType::BFloat16) &&
 | 
			
		||||
#endif
 | 
			
		||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
 | 
			
		||||
          mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
 | 
			
		||||
#else
 | 
			
		||||
          mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
 | 
			
		||||
          mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
 | 
			
		||||
          mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
 | 
			
		||||
          // avoid leading dim >> rows bugs
 | 
			
		||||
          ((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
 | 
			
		||||
           (mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
 | 
			
		||||
           (scalar_type != at::ScalarType::Half &&
 | 
			
		||||
            scalar_type != at::ScalarType::BFloat16)) &&
 | 
			
		||||
          ((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
 | 
			
		||||
           (mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
 | 
			
		||||
           (scalar_type != at::ScalarType::Half &&
 | 
			
		||||
            scalar_type != at::ScalarType::BFloat16));
 | 
			
		||||
#endif
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
    if (!useLtInterface) {
 | 
			
		||||
      self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
 | 
			
		||||
    }
 | 
			
		||||
    self__sizes = self_->sizes();
 | 
			
		||||
  } else {
 | 
			
		||||
    self_ = c10::MaybeOwned<Tensor>::borrowed(self);
 | 
			
		||||
    self__sizes = self_->sizes();
 | 
			
		||||
    TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
 | 
			
		||||
    TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
 | 
			
		||||
    TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Handle result/self shapes
 | 
			
		||||
  if (!result.is_same(self)) {
 | 
			
		||||
    at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
 | 
			
		||||
 | 
			
		||||
    const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
 | 
			
		||||
      if (disable_addmm_cuda_lt) {
 | 
			
		||||
        // When in non-Lt path we do expand self even before
 | 
			
		||||
        // check for beta != 0.0 to make sure that
 | 
			
		||||
        // test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
 | 
			
		||||
        // runs green.
 | 
			
		||||
        return expand_size(self, result.sizes(), "addmm");
 | 
			
		||||
      }
 | 
			
		||||
      // copy next, should broadcast
 | 
			
		||||
      return c10::MaybeOwned<Tensor>::borrowed(self);
 | 
			
		||||
    }();
 | 
			
		||||
    // We copy bias when in the non-Lt path
 | 
			
		||||
    if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
 | 
			
		||||
      // NOTE: self should broadcast over result
 | 
			
		||||
      at::native::copy_(result, *self_maybe_expanded);
 | 
			
		||||
  if (&result != &self) {
 | 
			
		||||
    at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
 | 
			
		||||
    if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
 | 
			
		||||
      at::native::copy_(result, *self_);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Short circuit on empty result
 | 
			
		||||
  if (result.numel() == 0) {
 | 
			
		||||
 | 
			
		||||
  IntArrayRef result_sizes = result.sizes();
 | 
			
		||||
  if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Short circuit if the reduction dim is empty
 | 
			
		||||
  if (mat1.sizes()[1] == 0) {
 | 
			
		||||
  cublasCommonArgs args(mat1, mat2, result);
 | 
			
		||||
 | 
			
		||||
  if (mat1.numel() == 0) {
 | 
			
		||||
    // By definition, when beta==0, values in self should be ignored. nans and infs
 | 
			
		||||
    // should not propagate
 | 
			
		||||
    if (beta.toComplexDouble() == 0.) {
 | 
			
		||||
@ -557,64 +455,158 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 | 
			
		||||
        result,
 | 
			
		||||
        self.expand(result.sizes()),
 | 
			
		||||
        at::native::scalar_tensor(
 | 
			
		||||
          beta,
 | 
			
		||||
          self.scalar_type(),
 | 
			
		||||
          std::nullopt /* layout */,
 | 
			
		||||
          at::kCPU,
 | 
			
		||||
          std::nullopt /* pin_memory */
 | 
			
		||||
        )
 | 
			
		||||
    );
 | 
			
		||||
            beta,
 | 
			
		||||
            self.scalar_type(),
 | 
			
		||||
            std::nullopt /* layout */,
 | 
			
		||||
            at::kCPU,
 | 
			
		||||
            std::nullopt /* pin_memory */));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  cublasCommonArgs args(mat1, mat2, result);
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
 | 
			
		||||
 | 
			
		||||
  // The Lt path
 | 
			
		||||
  if (!disable_addmm_cuda_lt) {
 | 
			
		||||
    bool lt_success = false;
 | 
			
		||||
  if (useLtInterface) {
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
    bool okay = true;
 | 
			
		||||
    if (is_float_output_with_half_input) {
 | 
			
		||||
      #ifdef USE_ROCM
 | 
			
		||||
      TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
 | 
			
		||||
      #else
 | 
			
		||||
      if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
 | 
			
		||||
       TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
 | 
			
		||||
      }
 | 
			
		||||
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
 | 
			
		||||
        scalar_type,
 | 
			
		||||
        "addmm_cuda_lt",
 | 
			
		||||
        [&] {
 | 
			
		||||
          lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
 | 
			
		||||
        }
 | 
			
		||||
      );
 | 
			
		||||
      #endif
 | 
			
		||||
    } else {
 | 
			
		||||
      // !is_float_output_with_half_input
 | 
			
		||||
      AT_DISPATCH_FLOATING_TYPES_AND2(
 | 
			
		||||
        at::ScalarType::Half,
 | 
			
		||||
        at::ScalarType::BFloat16,
 | 
			
		||||
        scalar_type,
 | 
			
		||||
        "addmm_cuda_lt",
 | 
			
		||||
        [&] {
 | 
			
		||||
          lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
 | 
			
		||||
        auto tuning_ctx = at::cuda::tunable::getTuningContext();
 | 
			
		||||
        if (tuning_ctx->IsTunableOpEnabled()) {
 | 
			
		||||
          launchTunableGemmAndBias<scalar_t>(
 | 
			
		||||
              args,
 | 
			
		||||
              alpha,
 | 
			
		||||
              (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
 | 
			
		||||
              activation_to_gemm_and_blas_arg(activation));
 | 
			
		||||
        } else {
 | 
			
		||||
          okay = at::cuda::blas::gemm_and_bias<scalar_t>(
 | 
			
		||||
            args.transa == 't',
 | 
			
		||||
            args.transb == 't',
 | 
			
		||||
            args.m,
 | 
			
		||||
            args.n,
 | 
			
		||||
            args.k,
 | 
			
		||||
            alpha.to<at::opmath_type<scalar_t>>(),
 | 
			
		||||
            args.mata->const_data_ptr<scalar_t>(),
 | 
			
		||||
            args.lda,
 | 
			
		||||
            args.matb->const_data_ptr<scalar_t>(),
 | 
			
		||||
            args.ldb,
 | 
			
		||||
            // This condition is needed for mm case on ROCm for hipblasLt path.
 | 
			
		||||
            // Passing the bias ptr as null to avoid accuracy issues for mm case.
 | 
			
		||||
            (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
 | 
			
		||||
            args.result->data_ptr<scalar_t>(),
 | 
			
		||||
            args.result_ld,
 | 
			
		||||
            activation_to_gemm_and_blas_arg(activation)
 | 
			
		||||
          );
 | 
			
		||||
        }
 | 
			
		||||
      );
 | 
			
		||||
    } // end is_float_output_with_half_input
 | 
			
		||||
 | 
			
		||||
    if (!lt_success) {
 | 
			
		||||
    // lt path failed; recurse but disable lt path
 | 
			
		||||
      });
 | 
			
		||||
    }
 | 
			
		||||
    if (!okay) {
 | 
			
		||||
      // lt path failed; recurse but disable lt path
 | 
			
		||||
      return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
 | 
			
		||||
    }
 | 
			
		||||
    // end Lt path
 | 
			
		||||
  } else {
 | 
			
		||||
    // No Lt, we use a GEMM instead
 | 
			
		||||
#else
 | 
			
		||||
    auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
 | 
			
		||||
    bool okay = true;
 | 
			
		||||
    if (is_float_output_with_half_input) {
 | 
			
		||||
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
 | 
			
		||||
        scalar_type,
 | 
			
		||||
        "addmm_cuda_lt",
 | 
			
		||||
        [&] {
 | 
			
		||||
        auto tuning_ctx = at::cuda::tunable::getTuningContext();
 | 
			
		||||
        if (tuning_ctx->IsTunableOpEnabled()) {
 | 
			
		||||
          TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
          okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
 | 
			
		||||
              args.transa == 't',
 | 
			
		||||
              args.transb == 't',
 | 
			
		||||
              args.m,
 | 
			
		||||
              args.n,
 | 
			
		||||
              args.k,
 | 
			
		||||
              alpha.to<at::opmath_type<scalar_t>>(),
 | 
			
		||||
              args.mata->const_data_ptr<scalar_t>(),
 | 
			
		||||
              args.lda,
 | 
			
		||||
              args.matb->const_data_ptr<scalar_t>(),
 | 
			
		||||
              args.ldb,
 | 
			
		||||
              self.const_data_ptr<scalar_t>(),
 | 
			
		||||
              args.result->data_ptr<float>(),
 | 
			
		||||
              args.result_ld,
 | 
			
		||||
              activation_epilogue
 | 
			
		||||
          );
 | 
			
		||||
        }});
 | 
			
		||||
    } else {
 | 
			
		||||
      AT_DISPATCH_FLOATING_TYPES_AND2(
 | 
			
		||||
        at::ScalarType::Half,
 | 
			
		||||
        at::ScalarType::BFloat16,
 | 
			
		||||
        scalar_type,
 | 
			
		||||
        "addmm_cuda_lt",
 | 
			
		||||
        [&] {
 | 
			
		||||
        auto tuning_ctx = at::cuda::tunable::getTuningContext();
 | 
			
		||||
        if (tuning_ctx->IsTunableOpEnabled()) {
 | 
			
		||||
          launchTunableGemmAndBias<scalar_t>(
 | 
			
		||||
              args,
 | 
			
		||||
              alpha,
 | 
			
		||||
              self.const_data_ptr<scalar_t>(),
 | 
			
		||||
              activation_epilogue);
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
          okay = at::cuda::blas::gemm_and_bias<scalar_t>(
 | 
			
		||||
              args.transa == 't',
 | 
			
		||||
              args.transb == 't',
 | 
			
		||||
              args.m,
 | 
			
		||||
              args.n,
 | 
			
		||||
              args.k,
 | 
			
		||||
              alpha.to<at::opmath_type<scalar_t>>(),
 | 
			
		||||
              args.mata->const_data_ptr<scalar_t>(),
 | 
			
		||||
              args.lda,
 | 
			
		||||
              args.matb->const_data_ptr<scalar_t>(),
 | 
			
		||||
              args.ldb,
 | 
			
		||||
              self.const_data_ptr<scalar_t>(),
 | 
			
		||||
              args.result->data_ptr<scalar_t>(),
 | 
			
		||||
              args.result_ld,
 | 
			
		||||
              activation_epilogue
 | 
			
		||||
          );
 | 
			
		||||
      }});
 | 
			
		||||
    }
 | 
			
		||||
    if (!okay) {
 | 
			
		||||
      // lt path failed; recurse but disable lt path
 | 
			
		||||
      return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
  } else
 | 
			
		||||
  {
 | 
			
		||||
    if (is_float_output_with_half_input) {
 | 
			
		||||
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
 | 
			
		||||
        scalar_type,
 | 
			
		||||
        "addmm_cuda",
 | 
			
		||||
        [&] {
 | 
			
		||||
          launchGemmCublas<scalar_t, float>(args, alpha, beta);
 | 
			
		||||
        }
 | 
			
		||||
      );
 | 
			
		||||
          using opmath_t = at::opmath_type<scalar_t>;
 | 
			
		||||
          opmath_t alpha_val = alpha.to<opmath_t>();
 | 
			
		||||
          opmath_t beta_val = beta.to<opmath_t>();
 | 
			
		||||
          const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
 | 
			
		||||
          const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
 | 
			
		||||
 | 
			
		||||
          float* result_ptr = args.result->mutable_data_ptr<float>();
 | 
			
		||||
          at::cuda::blas::gemm<scalar_t, float>(
 | 
			
		||||
              args.transa,
 | 
			
		||||
              args.transb,
 | 
			
		||||
              args.m,
 | 
			
		||||
              args.n,
 | 
			
		||||
              args.k,
 | 
			
		||||
              alpha_val,
 | 
			
		||||
              mat1_ptr,
 | 
			
		||||
              args.lda,
 | 
			
		||||
              mat2_ptr,
 | 
			
		||||
              args.ldb,
 | 
			
		||||
              beta_val,
 | 
			
		||||
              result_ptr,
 | 
			
		||||
              args.result_ld);
 | 
			
		||||
        });
 | 
			
		||||
    } else {
 | 
			
		||||
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
 | 
			
		||||
        at::ScalarType::Half,
 | 
			
		||||
@ -622,12 +614,28 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 | 
			
		||||
        scalar_type,
 | 
			
		||||
        "addmm_cuda",
 | 
			
		||||
        [&] {
 | 
			
		||||
          launchGemmCublas<scalar_t>(args, alpha, beta);
 | 
			
		||||
        }
 | 
			
		||||
      );
 | 
			
		||||
          using opmath_t = at::opmath_type<scalar_t>;
 | 
			
		||||
          opmath_t alpha_val = alpha.to<opmath_t>();
 | 
			
		||||
          opmath_t beta_val = beta.to<opmath_t>();
 | 
			
		||||
          const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
 | 
			
		||||
          const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
 | 
			
		||||
          scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
 | 
			
		||||
          at::cuda::blas::gemm<scalar_t>(
 | 
			
		||||
              args.transa,
 | 
			
		||||
              args.transb,
 | 
			
		||||
              args.m,
 | 
			
		||||
              args.n,
 | 
			
		||||
              args.k,
 | 
			
		||||
              alpha_val,
 | 
			
		||||
              mat1_ptr,
 | 
			
		||||
              args.lda,
 | 
			
		||||
              mat2_ptr,
 | 
			
		||||
              args.ldb,
 | 
			
		||||
              beta_val,
 | 
			
		||||
              result_ptr,
 | 
			
		||||
              args.result_ld);
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Apply epilogue
 | 
			
		||||
    switch (activation) {
 | 
			
		||||
      case Activation::RELU:
 | 
			
		||||
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
 | 
			
		||||
@ -639,14 +647,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 | 
			
		||||
        break;
 | 
			
		||||
      default: break;
 | 
			
		||||
    }
 | 
			
		||||
  } // end GEMM path
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
// Preprocessor gate here needs to match the inverse of the check
 | 
			
		||||
// gating activation_to_gemm_and_blas_arg above; here we are manually
 | 
			
		||||
// performing a post-GELU because we weren't able to use the GELU
 | 
			
		||||
// epilogue above.
 | 
			
		||||
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
 | 
			
		||||
  if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
 | 
			
		||||
  if (useLtInterface && activation == Activation::GELU) {
 | 
			
		||||
    at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
@ -2314,23 +2322,12 @@ _scaled_nvfp4_nvfp4(
 | 
			
		||||
          const Tensor& scale_b, const SwizzleType swizzle_b,
 | 
			
		||||
          const std::optional<Tensor>& bias,
 | 
			
		||||
          const c10::ScalarType out_dtype,
 | 
			
		||||
          Tensor& out,
 | 
			
		||||
          const std::optional<Tensor>& global_scale_a = std::nullopt,
 | 
			
		||||
          const std::optional<Tensor>& global_scale_b = std::nullopt) {
 | 
			
		||||
          const bool single_scale,
 | 
			
		||||
          Tensor& out) {
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM");
 | 
			
		||||
#endif
 | 
			
		||||
  std::optional<Tensor> alpha = std::nullopt;
 | 
			
		||||
  // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise,
 | 
			
		||||
  //       if this is "And" we would silently do nothing in the case where one global scale is
 | 
			
		||||
  //       passed and not the other.
 | 
			
		||||
  if (global_scale_a.has_value() || global_scale_b.has_value()) {
 | 
			
		||||
    TORCH_CHECK_VALUE(global_scale_a.has_value(),
 | 
			
		||||
        "For two-level-scaled NVFP4, global_scale_a must have a value");
 | 
			
		||||
    TORCH_CHECK_VALUE(global_scale_b.has_value(),
 | 
			
		||||
        "For two-level-scaled NVFP4, global_scale_b must have a value");
 | 
			
		||||
    alpha = global_scale_a.value().mul(global_scale_b.value());
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported");
 | 
			
		||||
  // Restrictions:
 | 
			
		||||
  // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
 | 
			
		||||
  // Scales must be swizzled
 | 
			
		||||
@ -2352,7 +2349,7 @@ _scaled_nvfp4_nvfp4(
 | 
			
		||||
 | 
			
		||||
  auto scaling_choice_a = ScalingType::BlockWise1x16;
 | 
			
		||||
  auto scaling_choice_b = ScalingType::BlockWise1x16;
 | 
			
		||||
  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha);
 | 
			
		||||
  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2558,10 +2555,9 @@ _scaled_mm_cuda_v2_out(
 | 
			
		||||
  } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) {
 | 
			
		||||
    return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
 | 
			
		||||
  } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) {
 | 
			
		||||
    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out,
 | 
			
		||||
                               scale_a[1], scale_b[1]);
 | 
			
		||||
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
 | 
			
		||||
  } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
 | 
			
		||||
    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
 | 
			
		||||
    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
 | 
			
		||||
  } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
 | 
			
		||||
    return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
 | 
			
		||||
  } else {
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,9 @@
 | 
			
		||||
#include <ATen/native/cuda/block_reduce.cuh>
 | 
			
		||||
#include <ATen/native/cuda/thread_constants.h>
 | 
			
		||||
 | 
			
		||||
#if CUB_SUPPORTS_SCAN_BY_KEY()
 | 
			
		||||
#include <thrust/iterator/reverse_iterator.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
@ -238,6 +240,10 @@ __global__ void renorm_kernel(
 | 
			
		||||
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
 | 
			
		||||
#if !CUB_SUPPORTS_SCAN_BY_KEY()
 | 
			
		||||
template<typename index_t>
 | 
			
		||||
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
 | 
			
		||||
                               int64_t num_weights, int64_t padding_idx,
 | 
			
		||||
@ -300,6 +306,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
 | 
			
		||||
 | 
			
		||||
  if (scale_grad_by_freq) {
 | 
			
		||||
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 | 
			
		||||
#if CUB_SUPPORTS_SCAN_BY_KEY()
 | 
			
		||||
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
 | 
			
		||||
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
 | 
			
		||||
@ -326,6 +333,11 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
 | 
			
		||||
        num_indices
 | 
			
		||||
      );
 | 
			
		||||
    });
 | 
			
		||||
#else
 | 
			
		||||
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
 | 
			
		||||
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
 | 
			
		||||
    });
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return embedding_backward_cuda_kernel(grad, orig_indices,
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,9 @@
 | 
			
		||||
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
 | 
			
		||||
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
 | 
			
		||||
#include <thrust/iterator/counting_iterator.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
			
		||||
#include <ATen/Functions.h>
 | 
			
		||||
@ -194,9 +196,18 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm
 | 
			
		||||
            partials_per_segment_offset[num_of_segments-1];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
 | 
			
		||||
__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
 | 
			
		||||
  *num_of_segments_ptr = num_of_segments;
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
} // anon namespace
 | 
			
		||||
 | 
			
		||||
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
 | 
			
		||||
template<typename index_t>
 | 
			
		||||
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
Tensor embedding_backward_cuda_kernel(
 | 
			
		||||
        const Tensor &grad,
 | 
			
		||||
@ -223,12 +234,20 @@ Tensor embedding_backward_cuda_kernel(
 | 
			
		||||
  auto segment_offsets = at::empty({numel}, orig_indices.options());
 | 
			
		||||
  auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
 | 
			
		||||
  int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>();
 | 
			
		||||
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
 | 
			
		||||
  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
 | 
			
		||||
    int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
 | 
			
		||||
    write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
 | 
			
		||||
    C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
  });
 | 
			
		||||
#else
 | 
			
		||||
  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
 | 
			
		||||
    cuda::cub::unique_by_key(
 | 
			
		||||
      sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0),
 | 
			
		||||
      segment_offsets.mutable_data_ptr<index_t>(),
 | 
			
		||||
      num_of_segments_ptr, sorted_indices.numel());
 | 
			
		||||
  });
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  int64_t max_segments = std::min<int64_t>(numel, num_weights);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -31,10 +31,16 @@
 | 
			
		||||
 | 
			
		||||
#include <c10/macros/Macros.h>
 | 
			
		||||
 | 
			
		||||
#if CUB_SUPPORTS_SCAN_BY_KEY()
 | 
			
		||||
#include <thrust/iterator/reverse_iterator.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace at::native {
 | 
			
		||||
 | 
			
		||||
#if !CUB_SUPPORTS_SCAN_BY_KEY()
 | 
			
		||||
template<typename index_t>
 | 
			
		||||
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
@ -193,6 +199,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(
 | 
			
		||||
 | 
			
		||||
  if (scale_grad_by_freq) {
 | 
			
		||||
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 | 
			
		||||
#if CUB_SUPPORTS_SCAN_BY_KEY()
 | 
			
		||||
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
 | 
			
		||||
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
 | 
			
		||||
@ -219,6 +226,11 @@ Tensor embedding_bag_backward_cuda_sum_avg(
 | 
			
		||||
        num_indices
 | 
			
		||||
      );
 | 
			
		||||
    });
 | 
			
		||||
#else
 | 
			
		||||
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
 | 
			
		||||
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
 | 
			
		||||
    });
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
  return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
 | 
			
		||||
      count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,

aten/src/ATen/native/cuda/LegacyThrustHelpers.cu (new file, 90 lines)
							@ -0,0 +1,90 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/cuda/cub_definitions.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <ATen/cuda/ThrustAllocator.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/constant_iterator.h>

namespace at::native {

#if !CUB_SUPPORTS_SCAN_BY_KEY()

template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

  auto num_indices = count.numel();

  // Compute an increasing sequence per unique item in sortedIndices:
  // sorted: 2 5 5 5 7 7 8 9 9
  //  count: 1 1 2 3 1 2 1 1 2
  auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
  thrust::inclusive_scan_by_key(
    policy,
    sorted_data,
    sorted_data + num_indices,
    thrust::make_constant_iterator(1),
    count_data
  );

  // Take the maximum of each count per unique key in reverse:
  // sorted: 2 5 5 5 7 7 8 9 9
  //  count: 1 3 3 3 2 2 1 2 2
  thrust::inclusive_scan_by_key(
    policy,
    thrust::make_reverse_iterator(sorted_data + num_indices),
    thrust::make_reverse_iterator(sorted_data),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::equal_to<index_t>(),
    thrust::maximum<index_t>()
  );
}

template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);

#endif

template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);
  const ptrdiff_t numel = sorted_indices.numel();
  auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
  auto ends = thrust::unique_by_key_copy(
          policy,
          sorted_indices_dev,
          sorted_indices_dev + numel,
          thrust::make_counting_iterator(0),
          dummy_dev,
          thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
  return thrust::get<0>(ends) - dummy_dev;
}

template
int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);

} // namespace at::native
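The two inclusive_scan_by_key calls above turn a sorted index list into per-element occurrence counts: a forward scan of constant 1s gives a running count within each run of equal keys, and a reverse max-scan broadcasts each run's total back over the whole run. A minimal host-side sketch of the same scheme, assuming plain C++ (no Thrust, no GPU) and the example values from the comments:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> sorted = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  std::vector<int> count(sorted.size());
  // Forward pass: running count within each run of equal keys -> 1 1 2 3 1 2 1 1 2
  for (size_t i = 0; i < sorted.size(); ++i)
    count[i] = (i > 0 && sorted[i] == sorted[i - 1]) ? count[i - 1] + 1 : 1;
  // Reverse max-scan by key: propagate each run's maximum backwards -> 1 3 3 3 2 2 1 2 2
  for (size_t i = sorted.size(); i-- > 1;)
    if (sorted[i - 1] == sorted[i])
      count[i - 1] = std::max(count[i - 1], count[i]);
  for (int c : count)
    std::printf("%d ", c);
  std::printf("\n");
}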
@ -1,17 +1,18 @@
#pragma once

#include <ATen/OpMathType.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/thread_constants.h>

#include <thrust/tuple.h>

#include <ATen/native/cuda/MemoryAccess.cuh>

#include <tuple>


namespace at::native {

template<int N>
@ -61,11 +62,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
  #pragma unroll
  for (int i = 0; i < elems_per_thread; i++) {
    if (policy.check_inbounds(i)) {
#if defined(__HIP__)
      results[i] = c10::guts::apply(f, args[i]);
#else
      results[i] = std::apply(f, args[i]);
#endif
    }
  }

@ -146,7 +146,6 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
  int64_t batch_size = target.size(0);
  int64_t H = target.size(1);
  int64_t W = target.size(2);
  int64_t n_classes = grad_input.size(1);

  CUDA_KERNEL_LOOP(index, n_threads) {
    const int64_t b = index % batch_size;
@ -157,7 +156,6 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
    if (cur_target == ignore_index) {
      continue;
    }
    CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
    scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
    grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
  }

@ -23,7 +23,7 @@ namespace at::native {

// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 1024;
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
@ -115,23 +115,9 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
  // first the reductions each thread does separately
  scalar_t sum = static_cast<scalar_t>(0);
  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
    constexpr int UNRL = 4; // load deserialize factor
    scalar_t tmp[UNRL];
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
#pragma unroll
      for (int u = 0; u < UNRL; u++)
        tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
#pragma unroll
      for (int u = 0; u < UNRL; u++)
        if (x+u*blockDim.x < tensor.size(2))
          sum += tmp[u];
    }
#else
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
#endif
  }
  __shared__ scalar_t shared[C10_WARP_SIZE];
  SumReduceOp<scalar_t> reduce_op;
@ -306,22 +292,6 @@ __global__ void batch_norm_collect_statistics_kernel(
  stat_accscalar_t var_n = 0;
  int n = 0;
  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
    constexpr int UNRL = 4;
    stat_accscalar_t v_[UNRL];
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
      for (int u = 0; u < UNRL; u++)
        v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
      for (int u = 0; u < UNRL; u++) {
        if (x+u*blockDim.x < input.size(2)) {
          stat_accscalar_t d1 = v_[u] - avg;
          n++;
          avg += d1 / n;
          var_n += d1 * (v_[u] - avg);
        }
      }
    }
#else
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
      stat_accscalar_t v = input[batch][plane][x];
      stat_accscalar_t d1 = v - avg;
@ -329,7 +299,6 @@ __global__ void batch_norm_collect_statistics_kernel(
      avg += d1 / n;
      var_n += d1 * (v - avg);
    }
#endif
  }

  // first warpSum to get one value per thread to

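In the ROCm branches above the inner loads are unrolled by UNRL = 4: the index fed to each load is clamped with min(size - 1, ...) so every unrolled access stays in bounds, and a second guarded loop accumulates only the lanes whose unclamped index was valid. A small host-side sketch of that clamp-then-mask pattern, assuming plain C++ with a single stand-in thread (stride plays the role of blockDim.x):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> data = {1, 2, 3, 4, 5};  // stands in for tensor.size(2) == 5
  const int size = static_cast<int>(data.size());
  const int stride = 2;      // stands in for blockDim.x
  constexpr int UNRL = 4;    // unroll factor from the hunk above
  float sum = 0.f;
  for (int x = 0; x < size; x += stride * UNRL) {
    float tmp[UNRL];
    // Clamp so every unrolled read stays in bounds (no branch in the load loop).
    for (int u = 0; u < UNRL; u++)
      tmp[u] = data[std::min(size - 1, x + u * stride)];
    // Accumulate only the lanes whose unclamped index was valid.
    for (int u = 0; u < UNRL; u++)
      if (x + u * stride < size)
        sum += tmp[u];
  }
  std::printf("partial sum for this thread = %g (expected 9)\n", sum);  // 1 + 3 + 5
}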
@ -413,12 +413,14 @@ struct ReduceOp {
      value = thread_reduce<output_vec_size>(input_slice);
    }

    if (config.should_block_x_reduce()) {
      value = block_x_reduce<output_vec_size>(value, shared_memory);
    }
    if (config.should_block_y_reduce()) {
      value = block_y_reduce<output_vec_size>(value, shared_memory);
    }
    __syncthreads();
    if (config.should_block_x_reduce()) {
      value = block_x_reduce<output_vec_size>(value, shared_memory);
    }

    using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = std::array<index_t, output_vec_size>;
    offset_vec_t base_offsets;
@ -655,8 +657,8 @@ struct ReduceOp {
    __syncthreads();
    // Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
    // matching Triton, etc.
    // TODO(PaulZhang12): AMD and internal
    #if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
    // todo for AMD
    #ifdef USE_ROCM
    for (int offset = 1; offset < dim_x; offset <<= 1) {
    #else
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {

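The hunk above only changes the shuffle-offset schedule of the intra-warp tree reduction (increasing 1, 2, 4, ... on one side; decreasing dim_x/2, ..., 1 on the other). Both schedules reduce the same values, but they add the partial sums in a different order, so floating-point results can differ in the low bits. A host-side sketch, assuming plain C++ and an 8-lane toy warp, mimicking the two loops on data whose exact sum is 4:

#include <cstdio>
#include <vector>

int main() {
  const int dim_x = 8;  // toy warp width; only lane 0's final value is used
  std::vector<float> inc = {1e8f, -1e8f, 1.f, 1.f, 1e8f, -1e8f, 1.f, 1.f};
  std::vector<float> dec = inc;
  // Increasing offsets: lane i adds lane i+offset, offset = 1, 2, 4
  for (int offset = 1; offset < dim_x; offset <<= 1)
    for (int i = 0; i + offset < dim_x; ++i)
      inc[i] += inc[i + offset];
  // Decreasing offsets: offset = 4, 2, 1
  for (int offset = dim_x >> 1; offset > 0; offset >>= 1)
    for (int i = 0; i + offset < dim_x; ++i)
      dec[i] += dec[i + offset];
  std::printf("lane 0, increasing offsets: %g\n", inc[0]);  // prints 4
  std::printf("lane 0, decreasing offsets: %g\n", dec[0]);  // prints 0
}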
@ -92,16 +92,6 @@ inline thrust::pair<int64_t, int64_t>  get_index_mapping2d(
    output_offset + output_y * output_dim_x + output_x);
}

__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
  const int64_t two = (len - 1) * 2;
  if (two <= 0) {
    return 0;
  }
  int64_t m = x % two;
  if (m < 0) m += two;
  return (m < len) ? m : (two - m);
}

template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
    const scalar_t * input, scalar_t * output,
@ -116,28 +106,6 @@ __global__ void reflection_pad1d_out_kernel(
  }
}

template <typename scalar_t>
__global__ void reflection_pad1d_flat(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    int64_t input_w, int64_t pad_l, int64_t pad_r,
    int64_t out_w, int64_t plane_count) {

  const int64_t bx = blockDim.x;
  const int64_t tx = threadIdx.x;

  const int64_t total = plane_count * out_w;
  const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
  int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;

  for (; linear < total; linear += grid_stride) {
    const int64_t plane = linear / out_w;
    const int64_t x = linear - plane * out_w;
    const int64_t j = reflect_index(x - pad_l, input_w);
    output[plane * out_w + x] = input[plane * input_w + j];
  }
}

template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
    scalar_t * grad_input, const scalar_t * grad_output,
@ -742,44 +710,25 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
  int64_t input_w = input_.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  dim3 block_size(output_w > 256 ? 256 : output_w);
  dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);

  Tensor input = input_.contiguous();

  const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
  const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  const int max_x = prop->maxGridSize[0];
  const int max_y = prop->maxGridSize[1];
  const int max_z = prop->maxGridSize[2];

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
    auto stream = at::cuda::getCurrentCUDAStream();

    const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));

    const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);

    if (fits3d) {
      dim3 block(block_x, 1, 1);
      dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
      reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
          input.const_data_ptr<scalar_t>(),
          output.mutable_data_ptr<scalar_t>(),
          input_w, pad_l, pad_r);
    } else {
      dim3 block(block_x, 1, 1);
      const int64_t plane_count = nplane * nbatch;
      const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
      const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
      dim3 grid(grid_x, 1, 1);

      reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
          input.const_data_ptr<scalar_t>(),
          output.mutable_data_ptr<scalar_t>(),
          input_w, pad_l, pad_r, output_w, plane_count);
    }

    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
        reflection_pad1d_out_kernel<<<
            grid_size,
            block_size,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            input.const_data_ptr<scalar_t>(),
            output.mutable_data_ptr<scalar_t>(),
            input_w,
            pad_l,
            pad_r);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,

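reflect_index above maps a possibly out-of-range coordinate into [0, len) by reflecting about the borders without repeating the edge value, using the period two = (len - 1) * 2. A host-side sketch, assuming plain C++, that applies the same formula to pad a 1-D row on both sides:

#include <cstdint>
#include <cstdio>
#include <vector>

// Same arithmetic as the device helper in the hunk above.
int64_t reflect_index(int64_t x, int64_t len) {
  const int64_t two = (len - 1) * 2;
  if (two <= 0) return 0;
  int64_t m = x % two;
  if (m < 0) m += two;
  return (m < len) ? m : (two - m);
}

int main() {
  const std::vector<int> row = {10, 20, 30, 40};  // input_w == 4
  const int64_t pad_l = 3, pad_r = 2;
  const int64_t out_w = row.size() + pad_l + pad_r;
  for (int64_t x = 0; x < out_w; ++x)
    std::printf("%d ", row[reflect_index(x - pad_l, row.size())]);
  std::printf("\n");  // prints: 40 30 20 10 20 30 40 30 20
}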
@ -19,6 +19,7 @@

namespace at::native {

// TODO: remove this when CUDA <11.6 is no longer supported
void topk_out_with_sort(
  const Tensor& self,
  int64_t k, int64_t dim, bool largest,
@ -30,12 +31,21 @@ void topk_out_with_sort(
  indices.copy_(sorted_indices.narrow(dim, 0, k));
}

// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk();
bool should_use_sort(const Tensor& self, int64_t dim) {
#if defined(USE_ROCM)
  if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972
  return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387
#else
  return false;
  if (disable_sort_for_topk()) return false;
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
  if (self.dim() == 0) return false;
  if (self.dtype() == kBool) return false; // Bool is not supported by topk
  int64_t slice_size = self.size(dim);
  if (slice_size == 0) return false;
  int64_t num_slices = self.numel() / slice_size;
  return num_slices <= 10 && slice_size >= 100000;
#endif
}

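In should_use_sort above, the ROCm branch falls back to a full sort only for a single long 1-D slice (numel equal to size(dim) and at least 10000 elements), while of the two CUDA-side lines shown, one simply returns false and the other keeps the heuristic from pull request 68632: sort only when there are at most 10 slices of at least 100000 elements each. A tiny host-side sketch, assuming plain C++ and only the slice geometry, of that CUDA-side heuristic:

#include <cstdint>
#include <cstdio>

// Restatement of the CUDA-side heuristic above on (num_slices, slice_size).
bool should_use_sort(int64_t num_slices, int64_t slice_size) {
  if (slice_size == 0) return false;
  return num_slices <= 10 && slice_size >= 100000;
}

int main() {
  std::printf("%d\n", should_use_sort(1, 1 << 20));   // 1 slice of ~1M   -> 1 (sort path)
  std::printf("%d\n", should_use_sort(64, 1 << 20));  // 64 slices of ~1M -> 0 (topk kernels)
  std::printf("%d\n", should_use_sort(1, 4096));      // short slice      -> 0 (topk kernels)
}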
@ -21,6 +21,11 @@ using namespace at::native;

namespace at::native {

// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk() {
  return CUB_SUPPORTS_SCAN_BY_KEY();
}

namespace sbtopk { // single_block_topk

template <typename T>
@ -413,6 +418,10 @@ __global__ void computeBlockwiseWithinKCounts(
  }
  __syncthreads();

#if !CUB_SUPPORTS_SCAN_BY_KEY()
  return;
#endif

  Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS);

  // if largest, then only threads that has tidx > desired_digit are active
@ -468,6 +477,7 @@ __global__ void computeBlockwiseWithinKCounts(
  }
}

#if CUB_SUPPORTS_SCAN_BY_KEY()
// Assumption: slice_size can not be larger than UINT32_MAX
template <typename Bitwise>
__global__ void computeBlockwiseKthCounts(
@ -599,6 +609,7 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> inpu
    }
  }
}
#endif

int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) {
  // occupancy of this kernel is limited by registers per threads
@ -676,12 +687,16 @@ void launch(
  uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get());
  AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream));

#if CUB_SUPPORTS_SCAN_BY_KEY()
  auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
  uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get());
  AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream));

  auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
  uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get());
#else
  uint32_t* withinKCounts = nullptr;
#endif

  Bitwise desiredMask = 0;
  dim3 grid;
@ -728,6 +743,7 @@ void launch(
  }
  desired = desired_in;

#if CUB_SUPPORTS_SCAN_BY_KEY()
  computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>(
    desired, counts, num_blocks, blocks_per_slice, kthCounts);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
@ -743,6 +759,28 @@ void launch(
    topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread,
    blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
  // Find topk values based on kth values
  {
    dim3 grid;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
    int warp_size = at::cuda::warp_size();
    dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
    sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>(
        input,
        inputSliceSize,
        outputSliceSize,
        largest,
        numInputSlices,
        inputWithinSliceStride,
        topK,
        topKWithinSliceStride,
        indices,
        indicesWithinSliceStride,
        kthValues);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
#endif
}

} // namespace mbtopk
@ -750,6 +788,7 @@ void launch(
bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
  if (num_slices > std::numeric_limits<uint32_t>::max() ||
      slice_size > std::numeric_limits<uint32_t>::max()) return false;
#if CUB_SUPPORTS_SCAN_BY_KEY()
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/74267
  return (num_slices <= 20 && slice_size >= 20000) ||
      (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) ||
@ -758,6 +797,12 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
      (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) ||
      (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) ||
      (num_slices > 4000 && slice_size >= 400);
#else
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/71081
  return (num_slices <= 400 && slice_size >= 5000) ||
      (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) ||
      (num_slices >= 4000 && slice_size >= 300);
#endif
}

void launch_gather_topk_kernel(

@ -44,7 +44,7 @@ __global__ void triu_tril_kernel(
    const int64_t k,
    const int64_t N_padded,
    const IndexType last_dim_padded) {
  int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread;
  int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
  if (linear_idx >= N_padded) {
    return;
  }

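The only difference between the two linear_idx lines above is whether blockIdx.x is widened to int64_t before the multiply. The CUDA built-ins are 32-bit unsigned, so without the cast the whole expression is evaluated modulo 2^32 and can wrap for large grids before it is stored into the 64-bit variable. A host-side sketch, assuming plain C++ stand-ins for the built-in variables:

#include <cstdint>
#include <cstdio>

int main() {
  // Stand-ins for the CUDA built-ins (all 32-bit unsigned, as on device).
  const unsigned block_idx = 9000000u;  // hypothetical large grid
  const unsigned block_dim = 256u;
  const unsigned thread_idx = 0u;
  const unsigned elements_per_thread = 8u;

  // 32-bit arithmetic wraps modulo 2^32 before the widening assignment.
  int64_t wrapped = (block_idx * block_dim + thread_idx) * elements_per_thread;
  // Widening one operand first keeps the whole expression in 64-bit.
  int64_t correct = ((int64_t)block_idx * block_dim + thread_idx) * elements_per_thread;

  std::printf("wrapped: %lld\n", (long long)wrapped);   // 1252130816
  std::printf("correct: %lld\n", (long long)correct);   // 18432000000
}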
@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
  }
}

#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
    int input_pos,
    accscalar_t scale,
    int output_size,
    bool align_corners,
    int& min_output,
    int& max_output) {
  accscalar_t lo, hi;
  if (align_corners) {
      lo = static_cast<accscalar_t>(input_pos - 1) / scale;
      hi = static_cast<accscalar_t>(input_pos + 1) / scale;
  } else {
      lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
      hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
  }
  min_output = max(0, static_cast<int>(ceil(lo)));
  max_output = min(output_size - 1, static_cast<int>(floor(hi)));
}
#endif

// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
    const bool align_corners,
    scalar_t* __restrict__ idata,
    const scalar_t* __restrict__ odata) {
  const size_t o_numel = nc * width2 * height2;
  // In C++, integer multiplication, like in standard arithmetic, is generally commutative.
  const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
       index += blockDim.x * gridDim.x) {
    // Decode input pixel coordinates
    size_t index_temp = index;
    const int w1 = index_temp % width1;
    index_temp /= width1;
    const int h1 = index_temp % height1;
    const size_t nc_idx = index_temp / height1;

    accscalar_t grad_sum = 0;

    // Find range of output pixels that could interpolate from this input pixel
    int h2_min, h2_max, w2_min, w2_max;
    compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
    compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);

    // Iterate over potential output pixels
    for (int h2 = h2_min; h2 <= h2_max; h2++) {
      for (int w2 = w2_min; w2 <= w2_max; w2++) {
        // Compute source coordinates for this output pixel
        const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
            rheight, h2, align_corners, /*cubic=*/false);
        const int h1_base = (int)h1r;
        const int h1p = (h1_base < height1 - 1) ? 1 : 0;
        const accscalar_t h1lambda = h1r - h1_base;
        const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;

        const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
            rwidth, w2, align_corners, /*cubic=*/false);
        const int w1_base = (int)w1r;
        const int w1p = (w1_base < width1 - 1) ? 1 : 0;
        const accscalar_t w1lambda = w1r - w1_base;
        const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;

        // Check if our input pixel participates in this interpolation and accumulate all weights
        // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
        // to the same pixel, so we need to accumulate weights from all matching positions
        accscalar_t weight = 0;

        // Check all four interpolation positions and accumulate weights
        if (h1 == h1_base && w1 == w1_base) {
          weight += h0lambda * w0lambda;  // top-left
        }
        if (h1 == h1_base && w1 == w1_base + w1p) {
          weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0)
        }
        if (h1 == h1_base + h1p && w1 == w1_base) {
          weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0)
        }
        if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
          weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions)
        }

        if (weight > 0) {
          const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
          grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
        }
      }
    }

    // Write accumulated gradient (no atomics needed)
    idata[index] = static_cast<scalar_t>(grad_sum);
  }
#else
  const size_t o_numel = nc * width2 * height2;
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
       index += blockDim.x * gridDim.x) {
    size_t index_temp = index;
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
        static_cast<scalar_t>(h1lambda * w1lambda * d2val),
        true);
  }
#endif
}

template <typename scalar_t, typename accscalar_t>
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
  // threads are not covering the whole input tensor.
  grad_input.zero_();

  const size_t num_kernels = nbatch * channels * output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
    return;
  }

#ifdef USE_ROCM
  constexpr bool use_input = true;
#else
  constexpr bool use_input = false;
#endif

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      const size_t num_kernels = nbatch * channels * output_height * output_width;

      upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
              input_height,
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);

      upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
             num_threads,

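compute_output_range above inverts the bilinear source-index mapping: for a given input row or column it bounds the output positions whose 2-tap stencil can reach it, which is what lets the ROCm backward path accumulate gradients per input pixel without atomics. A host-side sketch, assuming plain C++, float arithmetic, the align_corners == false formulas, and scale taken as input/output, cross-checking the range against a brute-force scan for a 4 -> 8 upsample:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Same bounds as the device helper above (align_corners == false branch).
void compute_output_range(int input_pos, float scale, int output_size,
                          int& min_output, int& max_output) {
  const float lo = (input_pos - 0.5f) / scale - 0.5f;
  const float hi = (input_pos + 1.5f) / scale - 0.5f;
  min_output = std::max(0, (int)std::ceil(lo));
  max_output = std::min(output_size - 1, (int)std::floor(hi));
}

int main() {
  const int in_size = 4, out_size = 8;
  const float scale = (float)in_size / out_size;  // 0.5 for this 2x upsample
  for (int in_pos = 0; in_pos < in_size; ++in_pos) {
    int lo, hi;
    compute_output_range(in_pos, scale, out_size, lo, hi);
    std::printf("input %d <- outputs [%d, %d]; brute force:", in_pos, lo, hi);
    for (int out = 0; out < out_size; ++out) {
      float src = (out + 0.5f) * scale - 0.5f;  // forward source index
      if (src < 0) src = 0;                     // clamped as in the forward mapping
      const int base = (int)src;
      const int upper = (base < in_size - 1) ? base + 1 : base;
      const float lambda = src - base;          // weight on the upper tap
      const float w = (in_pos == base ? 1.f - lambda : 0.f) + (in_pos == upper ? lambda : 0.f);
      if (w > 0)
        std::printf(" %d", out);
    }
    std::printf("\n");  // the bracketed range matches the brute-force list
  }
}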
@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
  using opmath_t = at::opmath_type<scalar_t>;

  C10_DEVICE __forceinline__ void operator()(
      int64_t chunk_size,
      int chunk_size,
      FusedOptimizerTensorListMetadata<3>& tl,
      const float* lr_ptr,
      const double& lr,
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {

} // namespace

} // namespace at::native
} // namespace at::native
@ -466,7 +466,7 @@ struct ReduceJitOp {

    __syncthreads();

    #if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
    #ifdef USE_ROCM
    for (int offset = 1; offset < dim_x; offset <<= 1) {
    #else
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {

@ -487,7 +487,9 @@ std::unique_ptr<fe::graph::Graph> build_graph(
  auto scaled_dot_product_flash_attention_options =
      fe::graph::SDPA_attributes()
          .set_name("CUDNN_SDPA")
          .set_generate_stats(return_softmaxstats)
          .set_is_inference(return_softmaxstats == false)
          // TODO(eqy): switch to this API once cuDNN FE is upgraded
          // .set_generate_stats(return_softmaxstats)
          .set_causal_mask(is_causal)
          .set_attn_scale(attn_scale);
  if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
@ -705,7 +707,9 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
  auto scaled_dot_product_flash_attention_options =
      fe::graph::SDPA_attributes()
          .set_name("CUDNN_SDPA_NESTEDTENSOR")
          .set_generate_stats(return_softmaxstats)
          .set_is_inference(return_softmaxstats == false)
          // TODO(eqy): switch to this API once cuDNN FE is upgraded
          // .set_generate_stats(return_softmaxstats)
          .set_causal_mask(is_causal)
          .set_attn_scale(attn_scale)
          .set_seq_len_q(SEQ_LEN_Q_)

@ -222,13 +222,6 @@ struct nextafter_functor {
  }
};

struct hypot_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b));
  }
};

// Complex binary functors
struct polar_functor {
  template <typename U>
@ -369,7 +362,6 @@ struct igammac_functor {
  REGISTER_OPMATH_BINARY_OP(NAME, half, half);   \
  REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)

REGISTER_FLOAT_BINARY_OP(hypot);
REGISTER_FLOAT_BINARY_OP(copysign);
REGISTER_INT2FLOAT_BINARY_OP(copysign);
REGISTER_FLOAT_BINARY_OP(fmax);

@ -441,7 +441,7 @@ kernel void applySYRK(
    uint3 tid [[thread_position_in_threadgroup]],
    uint3 tgid [[threadgroup_position_in_grid]],
    uint3 tpg [[threads_per_threadgroup]],
    uint warp_id [[simdgroup_index_in_threadgroup]]) {
    uint sgitg [[simdgroup_index_in_threadgroup]]) {
  const uint tx = tid.x;
  const uint ty = tid.y;
  const uint simdGroupsPerThreadgroup = (tpg.x * tpg.y + 31) / 32;
@ -474,8 +474,11 @@ kernel void applySYRK(
      (actSize_j % 8 == 0) && (actSize_h % 8 == 0) && (actSize_k % 8 == 0);

  if (use_simdgroup) {
    uint warp_id = sgitg;

    simdgroup_matrix<float, 8, 8> negative_identity =
        simdgroup_matrix<float, 8, 8>(-1.0);
    simdgroup_matrix<float, 8, 8> identity = simdgroup_matrix<float, 8, 8>(1.0);
    simdgroup_matrix<float, 8, 8> Prod;
    simdgroup_matrix<float, 8, 8> Afrag;
    simdgroup_matrix<float, 8, 8> Bfrag;
@ -518,7 +521,8 @@ kernel void applySYRK(
            /* transpose = */ upper);

        simdgroup_multiply(Prod, Afrag, Bfrag);
        simdgroup_multiply_accumulate(Cfrag, Prod, negative_identity, Cfrag);
        simdgroup_multiply(Prod, Prod, negative_identity);
        simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod);
      }

      simdgroup_store(

@ -5,21 +5,6 @@
using namespace metal;
using namespace c10::metal;

struct angle_functor {
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    return T(atan2(x.y, x.x), 0);
  }
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return T(isnan(x) ? x : x < 0 ? M_PI_F : 0.0);
  }
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return x < 0 ? M_PI_F : 0.0;
  }
};

// Implement exp wrapper for both real and complex types
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
inline T exp_(const T x) {
@ -560,7 +545,6 @@ REGISTER_UNARY_OP(abs, float, float);
REGISTER_UNARY_OP(abs, half, half);

#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
  REGISTER_UNARY_OP(angle, DTYPE1, DTYPE0);        \
  REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0);          \
  REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0);         \
  REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0);       \
@ -599,7 +583,6 @@ INSTANTIATE_UNARY_KERNELS2(float, int);
INSTANTIATE_UNARY_KERNELS2(float, long);

#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE)     \
  REGISTER_UNARY_OP(angle, DTYPE##2, DTYPE##2);   \
  REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2);     \
  REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2);     \
  REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2);   \

@ -92,8 +92,13 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
          }

          // upcasting to float32 if needed to improve precision when multiplying by the scale factor
          maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
          if ([maskedMM dataType] != MPSDataTypeFloat32) {
            maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
          }
          maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
          if ([maskedMM dataType] != qTensor.dataType) {
            maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
          }

          if (is_causal) {
            auto causalMask = [mpsGraph constantWithScalar:1.0f
@ -107,9 +112,7 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
                                                      name:nil];
          } else if (attn_mask) {
            graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
            maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
                                           secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
                                                      name:nil];
            maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
          }

          // Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
@ -130,8 +133,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
          graph->qTensor = qTensor;
          graph->kTensor = kTensor;
          graph->vTensor = vTensor;
          graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
          graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
          graph->outputTensor = output;
          graph->attnTensor = sm;
        });
    auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
    auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);

@ -202,10 +202,6 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "igammac");
}

static void hypot_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "hypot");
}

REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
@ -233,5 +229,4 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel)
REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel)
REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)
} // namespace at::native

Some files were not shown because too many files have changed in this diff.