Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-04 08:00:58 +08:00
Compare commits

114 commits

copilot/co... ... gh/eelliso...
| SHA1 |
|---|
| 21cc201861 |
| ab82456c16 |
| b23f4687fd |
| 2705937080 |
| c1eda348be |
| ba93d5636e |
| 722b2b86c9 |
| e1e8491b31 |
| 767199fd9b |
| 602ace5eb4 |
| 47804ce467 |
| e8cb34dd52 |
| e9d8973427 |
| 61d9a5180e |
| 8a8329b51f |
| 6b80c94901 |
| 8951df03de |
| 8139f33fa5 |
| a88587348b |
| 633a3b7f67 |
| fa0db212e7 |
| 15ff1cd28b |
| c73f5080de |
| 22ae059d32 |
| 1b121d636e |
| 1ba808dd97 |
| b2f5c25b27 |
| a1114beed2 |
| 4888ed440e |
| 5d62b63a76 |
| 57ba575242 |
| ceb11a584d |
| 33adb276fe |
| e939651972 |
| 3255e7872b |
| c4f6619330 |
| f18041cca8 |
| 35e51893bd |
| 1f43d17ce6 |
| 032bed95cd |
| d14cbb4476 |
| f510d0dbc0 |
| beb6b62e8c |
| 4740ce7787 |
| ad67170c8b |
| fdab48a7c1 |
| a0948d4d23 |
| 0bbdd6b8db |
| 24520b8386 |
| c79dfdc655 |
| e595136187 |
| aaac8cb0f5 |
| 0f0b4bf029 |
| b8194268a6 |
| f02e3947f6 |
| 9095a9dfae |
| d9f94e0d7d |
| 23417ae50f |
| e4d6c56ffb |
| 017d2985f3 |
| c6a8db0b9a |
| de09bab4b6 |
| c137e222d4 |
| cf3a787bbc |
| de3da77cf7 |
| 543ddbf44c |
| e9f4999985 |
| 29b029648e |
| a25a649e70 |
| 69c33898fa |
| 1b397420f2 |
| fe80f03726 |
| e50dc40d28 |
| 2e22b1a61e |
| 616c6bdf8f |
| c18ddfc572 |
| 86ebce1766 |
| 8cb2fb44f2 |
| ab65498d71 |
| 06d324365c |
| 6c9c6e0936 |
| 2bcd892c86 |
| 75e2a9fae3 |
| a16fd6b488 |
| 382b0150de |
| a664b299ac |
| 9c12651417 |
| 08c97b4a1f |
| fae74cd52f |
| 7a65770013 |
| e4454947e2 |
| 3806e9767b |
| b08d8c2e50 |
| ca5b7f8ded |
| 9a71d96256 |
| 0d4c2b71e8 |
| d659bbde62 |
| 58879bfafa |
| a032510db3 |
| 39e0a832c9 |
| dd3b48e85d |
| cff1b20771 |
| da8517fa63 |
| 45afaf08a1 |
| 080365b7d8 |
| 2928c5c572 |
| 630520b346 |
| 1dc9a05d03 |
| bfcdbd0a97 |
| faff826a46 |
| 85c5433d38 |
| 935ccdbe75 |
| 3af2f0c12a |
| d0f0ee68aa |
@@ -20,7 +20,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
 
 # cmake-3.18.4 from pip
 RUN yum install -y python3-pip && \
-    python3 -m pip install cmake==3.18.4 && \
+    python3 -mpip install cmake==3.18.4 && \
     ln -s /usr/local/bin/cmake /usr/bin/cmake3
 
 RUN rm -rf /usr/local/cuda-*
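A note on the recurring `-m pip` to `-mpip` edit throughout these hunks: CPython's short `-m` option accepts the module name fused to the flag, so both spellings run the same module. A minimal check of that equivalence, assuming a standard CPython with pip available:

```python
# Both invocations below run the pip module; CPython parses "-mpip" the same
# way as "-m pip" (short options may have their argument attached).
import subprocess
import sys

spaced = subprocess.run([sys.executable, "-m", "pip", "--version"],
                        capture_output=True, text=True)
fused = subprocess.run([sys.executable, "-mpip", "--version"],
                       capture_output=True, text=True)
assert spaced.stdout == fused.stdout  # identical pip version output
print(fused.stdout.strip())
```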
@@ -83,10 +83,6 @@ function build_cpython {
         py_suffix=${py_ver::-1}
         py_folder=$py_suffix
     fi
-    # Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4
-    if [ "$py_suffix" == "3.14.0" ]; then
-        py_suffix="3.14.0rc2"
-    fi
     wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz
     do_cpython_build $py_ver Python-$py_suffix
 
@@ -25,7 +25,7 @@ function install_torchbench() {
   python install.py --continue_on_fail
 
   echo "Print all dependencies after TorchBench is installed"
-  python -m pip freeze
+  python -mpip freeze
   popd
 
   chown -R jenkins torchbench
@@ -8,8 +8,8 @@ MKLROOT=/opt/intel
 mkdir -p ${MKLROOT}
 pushd /tmp
 
-python3 -m pip install wheel
-python3 -m pip download -d . mkl-static==${MKL_VERSION}
+python3 -mpip install wheel
+python3 -mpip download -d . mkl-static==${MKL_VERSION}
 python3 -m wheel unpack mkl_static-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
 python3 -m wheel unpack mkl_include-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
 mv mkl_static-${MKL_VERSION}/mkl_static-${MKL_VERSION}.data/data/lib ${MKLROOT}
@@ -11,5 +11,5 @@ ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
 python -m venv /var/lib/jenkins/ci_env
 source /var/lib/jenkins/ci_env/bin/activate
 
-python -m pip install --upgrade pip
-python -m pip install -r /opt/requirements-ci.txt
+python -mpip install --upgrade pip
+python -mpip install -r /opt/requirements-ci.txt
@@ -39,9 +39,13 @@ case ${DOCKER_TAG_PREFIX} in
         DOCKER_GPU_BUILD_ARG=""
         ;;
     rocm*)
+        # we want the patch version of 7.0 instead
+        if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
+            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
+        fi
         # we want the patch version of 6.4 instead
         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
+            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
        fi
         BASE_TARGET=rocm
         GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
@@ -14,7 +14,7 @@ ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/op
 
 # cmake-3.18.4 from pip
 RUN yum install -y python3-pip && \
-    python3 -m pip install cmake==3.18.4 && \
+    python3 -mpip install cmake==3.18.4 && \
     ln -s /usr/local/bin/cmake /usr/bin/cmake3
 
 FROM base as openssl
@@ -135,7 +135,7 @@ RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
 
 # cmake-3.18.4 from pip; force in case cmake3 already exists
 RUN yum install -y python3-pip && \
-    python3 -m pip install cmake==3.18.4 && \
+    python3 -mpip install cmake==3.18.4 && \
     ln -sf /usr/local/bin/cmake /usr/bin/cmake3
 
 FROM cpu_final as cuda_final
@@ -157,7 +157,7 @@ ENV ROCM_PATH /opt/rocm
 # cmake-3.28.4 from pip to get enable_language(HIP)
 # and avoid 3.21.0 cmake+ninja issues with ninja inserting "-Wl,--no-as-needed" in LINK_FLAGS for static linker
 RUN python3 -m pip install --upgrade pip && \
-    python3 -m pip install cmake==3.28.4
+    python3 -mpip install cmake==3.28.4
 # replace the libdrm in /opt/amdgpu with custom amdgpu.ids lookup path
 ADD ./common/install_rocm_drm.sh install_rocm_drm.sh
 RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh
@@ -174,7 +174,7 @@ FROM cpu_final as xpu_final
 ENV XPU_DRIVER_TYPE ROLLING
 # cmake-3.28.4 from pip
 RUN python3 -m pip install --upgrade pip && \
-    python3 -m pip install cmake==3.28.4
+    python3 -mpip install cmake==3.28.4
 ADD ./common/install_xpu.sh install_xpu.sh
 ENV XPU_VERSION 2025.2
 RUN bash ./install_xpu.sh && rm install_xpu.sh
@@ -113,7 +113,7 @@ RUN dnf install -y \
 RUN env GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=True pip3 install grpcio
 
 # cmake-3.28.0 from pip for onnxruntime
-RUN python3 -m pip install cmake==3.28.0
+RUN python3 -mpip install cmake==3.28.0
 
 ADD ./common/patch_libstdc.sh patch_libstdc.sh
 RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
@@ -75,9 +75,13 @@ case ${image} in
         DOCKERFILE_SUFFIX="_cuda_aarch64"
         ;;
     manylinux2_28-builder:rocm*)
+        # we want the patch version of 7.0 instead
+        if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
+            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
+        fi
         # we want the patch version of 6.4 instead
         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
-            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
+            GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
         fi
         TARGET=rocm_final
         MANY_LINUX_VERSION="2_28"
@@ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules
         logger.info("Successfully cloned %s", target)
         return r, commit
 
-    except GitCommandError as e:
-        logger.error("Git operation failed: %s", e)
+    except GitCommandError:
+        logger.exception("Git operation failed")
         raise
 
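For context on the hunk above: inside an `except` block, `logger.exception` logs at ERROR level and appends the active traceback, so the explicit `%s` interpolation of the caught exception becomes unnecessary. A minimal sketch of the difference, using only the standard `logging` module:

```python
# logger.error logs the message only; logger.exception adds the traceback.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise ValueError("boom")
except ValueError as e:
    logger.error("Git operation failed: %s", e)  # message only

try:
    raise ValueError("boom")
except ValueError:
    logger.exception("Git operation failed")     # message + traceback
```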
@@ -288,7 +288,7 @@ else
     # or building non-XLA tests.
     if [[ "$BUILD_ENVIRONMENT" != *rocm*  && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then
       # Install numpy-2.0.2 for builds which are backward compatible with 1.X
-      python -m pip install numpy==2.0.2
+      python -mpip install numpy==2.0.2
 
       WERROR=1 python setup.py clean
@@ -67,13 +67,13 @@ function pip_install_whl() {
     # Loop through each path and install individually
     for path in "${paths[@]}"; do
      echo "Installing $path"
-      python3 -m pip install --no-index --no-deps "$path"
+      python3 -mpip install --no-index --no-deps "$path"
     done
   else
     # Loop through each argument and install individually
     for path in "${args[@]}"; do
       echo "Installing $path"
-      python3 -m pip install --no-index --no-deps "$path"
+      python3 -mpip install --no-index --no-deps "$path"
     done
   fi
 }
@@ -182,7 +182,7 @@ checkout_install_torchbench() {
   pip uninstall -y torchao
 
   echo "Print all dependencies after TorchBench is installed"
-  python -m pip freeze
+  python -mpip freeze
 }
 
 torchbench_setup_macos() {
@@ -211,7 +211,7 @@ torchbench_setup_macos() {
 }
 
 pip_benchmark_deps() {
-  python -m pip install --no-input requests cython scikit-learn six
+  python -mpip install --no-input requests cython scikit-learn six
 }
 
@@ -1434,7 +1434,7 @@ EOF
   # shellcheck source=./common-build.sh
   source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh"
   python -m build --wheel --no-isolation -C--build-option=--bdist-dir="base_bdist_tmp" --outdir "base_dist"
-  python -m pip install base_dist/*.whl
+  python -mpip install base_dist/*.whl
   echo "::endgroup::"
 
   pushd test/forward_backward_compatibility
@@ -173,7 +173,7 @@ esac
 PINNED_PACKAGES=(
     "numpy${NUMPY_PINNED_VERSION}"
 )
-python -m venv ~/${desired_python}-build
+python -mvenv ~/${desired_python}-build
 source ~/${desired_python}-build/bin/activate
 retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements.txt"
 retry brew install libomp
.flake8 (6 changed lines)
@@ -7,16 +7,12 @@ max-line-length = 120
 # C408 ignored because we like the dict keyword argument syntax
 # E501 is not flexible enough, we're using B950 instead
 ignore =
-    E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
+    E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
     # to line this up with executable bit
     EXE001,
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
-    # these ignores are from flake8-comprehensions; please fix!
-    C407,
-    # these ignores are from flake8-logging-format; please fix!
-    G100,G101,G200
     # these ignores are from flake8-simplify. please fix or ignore with commented reason
     SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
     # SIM104 is already covered by pyupgrade ruff
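For context on the `.flake8` hunk above: dropping E721 from `ignore` re-enables pycodestyle's check against direct type comparisons. An illustrative sketch (hypothetical sample code, not from the repo) of what E721 flags:

```python
# E721 flags direct type comparisons; isinstance() is the preferred form.
x, y = 1, 2
print(type(x) == type(y))  # would now be reported as E721
print(isinstance(x, int))  # passes the check
```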
.github/ci_commit_pins/audio.txt (2 changed lines, vendored)
@@ -1 +1 @@
-1b013f5b5a87a1882eb143c26d79d091150d6a37
+69bbe7363897764f9e758d851cd0340147d27f94
.github/labeler.yml (29 changed lines, vendored)
@@ -133,3 +133,32 @@
 
 "ciflow/vllm":
 - .github/ci_commit_pins/vllm.txt
+
+"ciflow/b200":
+- test/test_matmul_cuda.py
+- test/test_scaled_matmul_cuda.py
+- test/inductor/test_fp8.py
+- aten/src/ATen/native/cuda/Blas.cpp
+- torch/**/*cublas*
+- torch/_inductor/kernel/mm.py
+- test/inductor/test_max_autotune.py
+- third_party/fbgemm
+
+"ciflow/h100":
+- test/test_matmul_cuda.py
+- test/test_scaled_matmul_cuda.py
+- test/inductor/test_fp8.py
+- aten/src/ATen/native/cuda/Blas.cpp
+- torch/**/*cublas*
+- torch/_inductor/kernel/mm.py
+- test/inductor/test_max_autotune.py
+- third_party/fbgemm
+
+"ciflow/rocm":
+- test/test_matmul_cuda.py
+- test/test_scaled_matmul_cuda.py
+- test/inductor/test_fp8.py
+- aten/src/ATen/native/cuda/Blas.cpp
+- torch/_inductor/kernel/mm.py
+- test/inductor/test_max_autotune.py
+- third_party/fbgemm
.github/scripts/generate_binary_build_matrix.py (30 changed lines, vendored)
@@ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
         "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'"
     ),
     "12.9": (
-        "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'"
+        "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | "
+        "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | "
+        "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | "
+        "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | "
+        "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | "
+        "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | "
+        "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | "
+        "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | "
+        "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | "
+        "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | "
+        "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | "
+        "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | "
+        "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | "
+        "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | "
+        "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'"
     ),
     "13.0": (
         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
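The hunk above drops the `platform_machine == 'x86_64'` clause from the CUDA 12.9 dependency markers. A minimal sketch of how such PEP 508 markers gate installation, assuming the third-party `packaging` library (pip vendors equivalent logic):

```python
# On a Linux/aarch64 machine the old marker evaluates False (dependency
# skipped); the new Linux-only marker evaluates True (NVIDIA wheels pulled in).
from packaging.markers import Marker

old = Marker("platform_system == 'Linux' and platform_machine == 'x86_64'")
new = Marker("platform_system == 'Linux'")

aarch64_env = {"platform_system": "Linux", "platform_machine": "aarch64"}
print(old.evaluate(aarch64_env))  # False
print(new.evaluate(aarch64_env))  # True
```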
.github/scripts/prepare_vllm_wheels.sh (6 changed lines, vendored)
@@ -24,7 +24,7 @@ change_wheel_version() {
   local t_version=$4
 
   # Extract the wheel
-  ${PYTHON_EXECUTABLE} -m wheel unpack $wheel
+  ${PYTHON_EXECUTABLE} -mwheel unpack $wheel
 
   mv "${package}-${f_version}" "${package}-${t_version}"
   # Change the version from f_version to t_version in the dist-info dir
@@ -47,7 +47,7 @@ change_wheel_version() {
   popd
 
   # Repack the wheel
-  ${PYTHON_EXECUTABLE} -m wheel pack "${package}-${t_version}"
+  ${PYTHON_EXECUTABLE} -mwheel pack "${package}-${t_version}"
 
   # Clean up
   rm -rf "${package}-${t_version}"
@@ -85,7 +85,7 @@ repackage_wheel() {
 }
 
 # Require to re-package the wheel
-${PYTHON_EXECUTABLE} -m pip install wheel==0.45.1
+${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1
 
 pushd externals/vllm/wheels
 for package in xformers flashinfer-python vllm; do
@@ -26,9 +26,8 @@ name: !{{ build_environment }}
       - name: Setup Python
         uses: actions/setup-python@v6
         with:
-          # TODO: Removeme once 3.14 is out
           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
-          python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}"
+          python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}"
           freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }}
 {%- endmacro %}
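To see what the template change above does, here are the two `!{{ ... }}` expressions as plain Python stand-ins, evaluated for a few representative `py_ver` values (a sketch; the values chosen are illustrative):

```python
# Old vs. new python-version expression: 3.14 moves from the rc2 prerelease
# to the final 3.14.0 release; other versions are unchanged.
def old_expr(py_ver: str) -> str:
    return (py_ver.strip("t") + ".4") if "3.14" not in py_ver else "3.14.0-rc.2"

def new_expr(py_ver: str) -> str:
    return py_ver.strip("t") + (".4" if "3.14" not in py_ver else ".0")

for v in ("3.10", "3.13t", "3.14", "3.14t"):
    print(f"{v:6} old={old_expr(v):12} new={new_expr(v)}")
# 3.10   old=3.10.4       new=3.10.4
# 3.13t  old=3.13.4       new=3.13.4
# 3.14   old=3.14.0-rc.2  new=3.14.0
# 3.14t  old=3.14.0-rc.2  new=3.14.0
```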
.github/workflows/_mac-test.yml (4 changed lines, vendored)
@@ -211,7 +211,7 @@ jobs:
             $tool --version
           done
 
-          python3 -m pip install --no-index --no-deps dist/*.whl
+          python3 -mpip install --no-index --no-deps dist/*.whl
 
           set +e
           pushd "${RUNNER_TEMP}"
@@ -222,7 +222,7 @@ jobs:
           popd
 
           if [ "${RC}" -ne 0 ]; then
-            python3 -m pip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
+            python3 -mpip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
           fi
           set -e
.github/workflows/_win-test.yml (2 changed lines, vendored)
@@ -204,7 +204,7 @@ jobs:
         run: |
           pushd "${PYTORCH_FINAL_PACKAGE_DIR}"
           # shellcheck disable=SC2046,SC2102
-          python3 -m pip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
+          python3 -mpip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
           popd
 
           .ci/pytorch/win-test.sh
.github/workflows/build-vllm-wheel.yml (4 changed lines, vendored)
@@ -126,13 +126,13 @@ jobs:
             "${MANYLINUX_IMAGE}"
           )
 
-          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install \
+          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip install \
             --pre torch torchvision torchaudio \
             --index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"
 
           # I wonder if there is a command to both download and install the wheels
           # in one go
-          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip download \
+          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip download \
             --pre torch torchvision torchaudio \
             --index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"
.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml (14 changed lines, generated, vendored)
@@ -224,7 +224,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_10-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -473,7 +473,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_11-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -722,7 +722,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_12-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -971,7 +971,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_13-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -1220,7 +1220,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_13t-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -1469,7 +1469,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_14-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -1718,7 +1718,7 @@ jobs:
       ALPINE_IMAGE: "arm64v8/alpine"
       build_name: manywheel-py3_14t-cuda-aarch64-12_9
       build_environment: linux-aarch64-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
       timeout-minutes: 420
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/generated-linux-binary-manywheel-nightly.yml (14 changed lines, generated, vendored)
@@ -259,7 +259,7 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build_name: manywheel-py3_10-cuda12_9
       build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_10-cuda12_9-test:  # Testing
@@ -925,7 +925,7 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build_name: manywheel-py3_11-cuda12_9
       build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_11-cuda12_9-test:  # Testing
@@ -1591,7 +1591,7 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build_name: manywheel-py3_12-cuda12_9
       build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_12-cuda12_9-test:  # Testing
@@ -2257,7 +2257,7 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build_name: manywheel-py3_13-cuda12_9
       build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_13-cuda12_9-test:  # Testing
@@ -2923,7 +2923,7 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build_name: manywheel-py3_13t-cuda12_9
       build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_13t-cuda12_9-test:  # Testing
@@ -3589,7 +3589,7 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build_name: manywheel-py3_14-cuda12_9
       build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_14-cuda12_9-test:  # Testing
 | 
			
		||||
@ -4255,7 +4255,7 @@ jobs:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_14t-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_14t-cuda12_9-test:  # Testing
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -63,7 +63,6 @@ jobs:
 | 
			
		||||
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.10.4"
          freethreaded: false

25 changes: .github/workflows/generated-macos-arm64-binary-wheel-nightly.yml (generated, vendored)
@ -59,7 +59,6 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.10.4"
          freethreaded: false
@ -106,7 +105,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -m venv test_venv
          python -mvenv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -169,7 +168,6 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.11.4"
          freethreaded: false
@ -216,7 +214,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -m venv test_venv
          python -mvenv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -279,7 +277,6 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.12.4"
          freethreaded: false
@ -326,7 +323,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -m venv test_venv
          python -mvenv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -389,7 +386,6 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.13.4"
          freethreaded: false
@ -436,7 +432,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -m venv test_venv
          python -mvenv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -499,7 +495,6 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.13.4"
          freethreaded: true
@ -546,7 +541,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -m venv test_venv
          python -mvenv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -609,9 +604,8 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.14.0-rc.2"
          python-version: "3.14.0"
          freethreaded: false
      - name: Checkout PyTorch
        uses: actions/checkout@v4
@ -656,7 +650,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -m venv test_venv
          python -mvenv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

@ -719,9 +713,8 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
          # TODO: Removeme once 3.14 is out
          # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
          python-version: "3.14.0-rc.2"
          python-version: "3.14.0"
          freethreaded: true
      - name: Checkout PyTorch
        uses: actions/checkout@v4
@ -766,7 +759,7 @@ jobs:
          SMOKE_TEST_PARAMS=""

          # shellcheck disable=SC2086
          python -m venv test_venv
          python -mvenv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

34 changes: .github/workflows/trunk.yml (vendored)
@ -190,6 +190,40 @@ jobs:
      runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    secrets: inherit

  linux-jammy-rocm-py3_10-build:
    if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-test:
    if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
      tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
    secrets: inherit

  inductor-build:
    name: inductor-build
    uses: ./.github/workflows/_linux-build.yml

1 change: .gitignore (vendored)
@ -374,6 +374,7 @@ third_party/ruy/
third_party/glog/

# Virtualenv
.venv/
venv/

# Log files

14 changes: CODEOWNERS
@ -201,3 +201,17 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99

# FlexAttention
/torch/nn/attention/flex_attention.py @drisspg
/torch/_higher_order_ops/flex_attention.py @drisspg
/torch/_inductor/kernel/flex/ @drisspg
/torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg

# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

@ -39,7 +39,7 @@ RUN chmod +x ~/miniconda.sh && \
    bash ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
    /opt/conda/bin/python -m pip install -r requirements.txt && \
    /opt/conda/bin/python -mpip install -r requirements.txt && \
    /opt/conda/bin/conda clean -ya

FROM dev-base as submodule-update

@ -289,14 +289,15 @@ IF(USE_FBGEMM_GENAI)

    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)

    set(fbgemm_genai_mx8mx8bf16_grouped
    set(fbgemm_genai_cuh
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
      "${FBGEMM_GENAI_SRCS}/"
    )

    target_include_directories(fbgemm_genai PRIVATE
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${fbgemm_genai_mx8mx8bf16_grouped}
      ${fbgemm_genai_cuh}
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )

@ -183,11 +183,6 @@ struct CUDACachingHostAllocatorImpl
    return true;
  }

  bool pinned_use_background_threads() override {
    return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
        pinned_use_background_threads();
  }

  EventPool::Event create_event_internal(DeviceIndex idx) {
    // Leak the event pool to avoid shutdown issue.
    static auto* event_pool = new EventPool();
@ -177,7 +177,6 @@ inline void segmented_sort_pairs(
  }
}

#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
  KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
@ -193,7 +192,6 @@ inline void unique_by_key(
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
    keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif

namespace impl {

@ -579,7 +577,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
#endif
}

#if CUB_SUPPORTS_SCAN_BY_KEY()

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
@ -607,7 +604,6 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT
#endif
}

#endif

template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,
@ -28,22 +28,6 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif

// cub support for UniqueByKey is added to cub 1.16 in:
// https://github.com/NVIDIA/cub/pull/405
#if CUB_VERSION >= 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#endif

// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif

// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500
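The hunk above deletes the CUB_SUPPORTS_UNIQUE_BY_KEY() and CUB_SUPPORTS_SCAN_BY_KEY() version gates, so call sites elsewhere in this diff can assume the primitives exist unconditionally. For readers unfamiliar with the underlying call, here is a hedged, self-contained sketch of cub::DeviceSelect::UniqueByKey and CUB's two-phase convention (size query with a null scratch buffer, then the real run); the data and buffer handling are illustrative, not PyTorch's wrapper code, and error checks/frees are omitted for brevity.

#include <cub/device/device_select.cuh>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int n = 8;
  int h_keys[n] = {1, 1, 2, 2, 2, 3, 5, 5};
  int h_vals[n] = {0, 1, 2, 3, 4, 5, 6, 7};
  int *d_keys, *d_vals, *d_keys_out, *d_vals_out, *d_num;
  cudaMalloc(&d_keys, n * sizeof(int));
  cudaMalloc(&d_vals, n * sizeof(int));
  cudaMalloc(&d_keys_out, n * sizeof(int));
  cudaMalloc(&d_vals_out, n * sizeof(int));
  cudaMalloc(&d_num, sizeof(int));
  cudaMemcpy(d_keys, h_keys, n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vals, h_vals, n * sizeof(int), cudaMemcpyHostToDevice);

  // Phase 1: null temp buffer -> CUB only reports the scratch size it needs.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSelect::UniqueByKey(d_temp, temp_bytes, d_keys, d_vals,
                                 d_keys_out, d_vals_out, d_num, n);
  cudaMalloc(&d_temp, temp_bytes);
  // Phase 2: run the selection; keeps the first value of each run of equal keys.
  cub::DeviceSelect::UniqueByKey(d_temp, temp_bytes, d_keys, d_vals,
                                 d_keys_out, d_vals_out, d_num, n);

  int m = 0;
  cudaMemcpy(&m, d_num, sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("unique key runs: %d\n", m);  // prints 4
  return 0;
}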
@ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::CUDA,
  DispatchKey::CPU,
  DispatchKey::PrivateUse1,
  DispatchKey::SparseCPU,
  DispatchKey::SparseCUDA,
  DispatchKey::SparseCsrCPU,
  DispatchKey::SparseCsrCUDA,
});

inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {

@ -658,6 +658,7 @@ static void check_shape_forward(const at::Tensor& input,
  TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
  TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
  TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
  TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups);

  TORCH_CHECK(weight_dim == k,
           "Expected ", weight_dim, "-dimensional input for ", weight_dim,
@ -2322,12 +2322,23 @@ _scaled_nvfp4_nvfp4(
          const Tensor& scale_b, const SwizzleType swizzle_b,
          const std::optional<Tensor>& bias,
          const c10::ScalarType out_dtype,
          const bool single_scale,
          Tensor& out) {
          Tensor& out,
          const std::optional<Tensor>& global_scale_a = std::nullopt,
          const std::optional<Tensor>& global_scale_b = std::nullopt) {
#ifdef USE_ROCM
  TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM");
#endif
  TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported");
  std::optional<Tensor> alpha = std::nullopt;
  // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise,
  //       if this is "And" we would silently do nothing in the case where one global scale is
  //       passed and not the other.
  if (global_scale_a.has_value() || global_scale_b.has_value()) {
    TORCH_CHECK_VALUE(global_scale_a.has_value(),
        "For two-level-scaled NVFP4, global_scale_a must have a value");
    TORCH_CHECK_VALUE(global_scale_b.has_value(),
        "For two-level-scaled NVFP4, global_scale_b must have a value");
    alpha = global_scale_a.value().mul(global_scale_b.value());
  }
  // Restrictions:
  // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
  // Scales must be swizzled
@ -2349,7 +2360,7 @@ _scaled_nvfp4_nvfp4(

  auto scaling_choice_a = ScalingType::BlockWise1x16;
  auto scaling_choice_b = ScalingType::BlockWise1x16;
  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
  return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha);
}

@ -2555,9 +2566,10 @@ _scaled_mm_cuda_v2_out(
  } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) {
    return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
  } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out,
                               scale_a[1], scale_b[1]);
  } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
    return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
  } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) {
    return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
  } else {
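The hunk above threads optional per-tensor global scales through _scaled_nvfp4_nvfp4 and folds them into a single alpha handed to _scaled_gemm. A hedged sketch of just the check-and-fold shape, with a plain float standing in for the ATen Tensor scales (names and types here are illustrative only):

#include <optional>
#include <stdexcept>

// Hypothetical scalar stand-in; the real code multiplies two Tensor scales.
std::optional<float> fold_global_scales(std::optional<float> global_scale_a,
                                        std::optional<float> global_scale_b) {
  if (!global_scale_a && !global_scale_b)
    return std::nullopt;                      // single-level scaling path
  if (!global_scale_a || !global_scale_b)     // "or" catches a lone scale loudly
    throw std::invalid_argument("two-level NVFP4 needs both global scales");
  return *global_scale_a * *global_scale_b;   // alpha = gs_a * gs_b, one GEMM epilogue scalar
}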
@ -15,9 +15,7 @@
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/thread_constants.h>

#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -240,10 +238,6 @@ __global__ void renorm_kernel(

} // anonymous namespace

#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif

Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
                               int64_t num_weights, int64_t padding_idx,
@ -306,7 +300,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice

  if (scale_grad_by_freq) {
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

@ -333,11 +326,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
        num_indices
      );
    });
#else
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
    });
#endif
  }

  return embedding_backward_cuda_kernel(grad, orig_indices,
@ -10,9 +10,7 @@

#include <c10/macros/Macros.h>

#if CUB_SUPPORTS_UNIQUE_BY_KEY()
#include <thrust/iterator/counting_iterator.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -196,18 +194,9 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm
            partials_per_segment_offset[num_of_segments-1];
}

#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
  *num_of_segments_ptr = num_of_segments;
}
#endif

} // anon namespace

#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
#endif

Tensor embedding_backward_cuda_kernel(
        const Tensor &grad,
@ -234,20 +223,12 @@ Tensor embedding_backward_cuda_kernel(
  auto segment_offsets = at::empty({numel}, orig_indices.options());
  auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
  int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>();
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
    int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
    write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
#else
  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
    cuda::cub::unique_by_key(
      sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0),
      segment_offsets.mutable_data_ptr<index_t>(),
      num_of_segments_ptr, sorted_indices.numel());
  });
#endif

  int64_t max_segments = std::min<int64_t>(numel, num_weights);
@ -31,16 +31,10 @@

#include <c10/macros/Macros.h>

#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif

namespace at::native {

#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif

namespace {

@ -199,7 +193,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(

  if (scale_grad_by_freq) {
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

@ -226,11 +219,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
        num_indices
      );
    });
#else
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
    });
#endif
  }
  return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
      count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,
@ -1,90 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/cuda/cub_definitions.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <ATen/cuda/ThrustAllocator.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/constant_iterator.h>

namespace at::native {

#if !CUB_SUPPORTS_SCAN_BY_KEY()

template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

  auto num_indices = count.numel();

  // Compute an increasing sequence per unique item in sortedIndices:
  // sorted: 2 5 5 5 7 7 8 9 9
  //  count: 1 1 2 3 1 2 1 1 2
  auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
  thrust::inclusive_scan_by_key(
    policy,
    sorted_data,
    sorted_data + num_indices,
    thrust::make_constant_iterator(1),
    count_data
  );

  // Take the maximum of each count per unique key in reverse:
  // sorted: 2 5 5 5 7 7 8 9 9
  //  count: 1 3 3 3 2 2 1 2 2
  thrust::inclusive_scan_by_key(
    policy,
    thrust::make_reverse_iterator(sorted_data + num_indices),
    thrust::make_reverse_iterator(sorted_data),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::equal_to<index_t>(),
    thrust::maximum<index_t>()
  );
}

template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);

#endif

template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);
  const ptrdiff_t numel = sorted_indices.numel();
  auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
  auto ends = thrust::unique_by_key_copy(
          policy,
          sorted_indices_dev,
          sorted_indices_dev + numel,
          thrust::make_counting_iterator(0),
          dummy_dev,
          thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
  return thrust::get<0>(ends) - dummy_dev;
}

template
int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);

} // namespace at::native
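The file deleted above was the pre-CUB Thrust fallback for these index-counting primitives. The trick documented in its comments (a forward inclusive count scan keyed on the sorted indices, then a reverse max scan to broadcast each run's total backwards) is easy to verify on the CPU; here is a hedged plain C++ re-creation of the two passes on the file's own example data:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> sorted = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  std::vector<int> count(sorted.size());
  // Pass 1 (forward scan by key): number duplicates 1..n within each equal-key run.
  for (size_t i = 0; i < sorted.size(); ++i)
    count[i] = (i > 0 && sorted[i] == sorted[i - 1]) ? count[i - 1] + 1 : 1;
  // Pass 2 (reverse max scan by key): the last element of each run holds the
  // run total after pass 1; propagate it backwards over the run.
  for (size_t i = sorted.size() - 1; i-- > 0;)
    if (sorted[i] == sorted[i + 1]) count[i] = count[i + 1];
  for (int c : count) std::printf("%d ", c);  // prints: 1 3 3 3 2 2 1 2 2
  std::printf("\n");
  return 0;
}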
@ -146,6 +146,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
  int64_t batch_size = target.size(0);
  int64_t H = target.size(1);
  int64_t W = target.size(2);
  int64_t n_classes = grad_input.size(1);

  CUDA_KERNEL_LOOP(index, n_threads) {
    const int64_t b = index % batch_size;
@ -156,6 +157,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
    if (cur_target == ignore_index) {
      continue;
    }
    CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
    scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
    grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
  }
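The added CUDA_KERNEL_ASSERT checks the class label before it is used to index grad_input, so an out-of-range target traps the kernel instead of silently writing out of bounds. A hedged standalone sketch of the same idea using the plain CUDA device assert (CUDA_KERNEL_ASSERT is PyTorch's wrapper; this kernel is illustrative only):

#include <cassert>

// Trap on an invalid class index instead of corrupting memory.
// Compile without -DNDEBUG so the device-side assert stays active.
__global__ void check_targets(const long* target, long n_classes, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  long t = target[i];
  assert(t >= 0 && t < n_classes);  // device assert aborts the kernel on failure
}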
@ -413,14 +413,12 @@ struct ReduceOp {
      value = thread_reduce<output_vec_size>(input_slice);
    }

    if (config.should_block_y_reduce()) {
      value = block_y_reduce<output_vec_size>(value, shared_memory);
    }
    __syncthreads();
    if (config.should_block_x_reduce()) {
      value = block_x_reduce<output_vec_size>(value, shared_memory);
    }

    if (config.should_block_y_reduce()) {
      value = block_y_reduce<output_vec_size>(value, shared_memory);
    }
    using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = std::array<index_t, output_vec_size>;
    offset_vec_t base_offsets;
@ -655,14 +653,8 @@ struct ReduceOp {
    }

    __syncthreads();
    // Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
    // matching Triton, etc.
    // todo for AMD
    #ifdef USE_ROCM

    for (int offset = 1; offset < dim_x; offset <<= 1) {
    #else
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
    #endif
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = ops.warp_shfl_down(value[i], offset);
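The hunk above switches CUDA to a halving-offset ("tree") intra-warp reduction while ROCm keeps the increasing-offset loop. A hedged sketch of the halving-offset pattern using the raw warp shuffle (the diff goes through PyTorch's ops.warp_shfl_down abstraction instead); it assumes all 32 lanes of the warp are active:

__device__ float warp_reduce_sum(float v) {
  // Pair lanes at offsets 16, 8, 4, 2, 1 -- the deterministic pairing order
  // the comment in the hunk says better matches Triton's numerics.
  for (int offset = warpSize >> 1; offset > 0; offset >>= 1)
    v += __shfl_down_sync(0xffffffffu, v, offset);
  return v;  // lane 0 ends up holding the warp-wide sum
}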
@ -19,7 +19,6 @@

namespace at::native {

// TODO: remove this when CUDA <11.6 is no longer supported
void topk_out_with_sort(
  const Tensor& self,
  int64_t k, int64_t dim, bool largest,
@ -31,21 +30,12 @@ void topk_out_with_sort(
  indices.copy_(sorted_indices.narrow(dim, 0, k));
}

// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk();
bool should_use_sort(const Tensor& self, int64_t dim) {
#if defined(USE_ROCM)
  if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972
  return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387
#else
  if (disable_sort_for_topk()) return false;
  // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
  if (self.dim() == 0) return false;
  if (self.dtype() == kBool) return false; // Bool is not support by topk
  int64_t slice_size = self.size(dim);
  if (slice_size == 0) return false;
  int64_t num_slices = self.numel() / slice_size;
  return num_slices <= 10 && slice_size >= 100000;
  return false;
#endif
}

@ -21,11 +21,6 @@ using namespace at::native;

namespace at::native {

// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk() {
  return CUB_SUPPORTS_SCAN_BY_KEY();
}

namespace sbtopk { // single_block_topk

template <typename T>
@ -418,10 +413,6 @@ __global__ void computeBlockwiseWithinKCounts(
  }
  __syncthreads();

#if !CUB_SUPPORTS_SCAN_BY_KEY()
  return;
#endif

  Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS);

  // if largest, then only threads that has tidx > desired_digit are active
@ -477,7 +468,6 @@ __global__ void computeBlockwiseWithinKCounts(
  }
}

#if CUB_SUPPORTS_SCAN_BY_KEY()
// Assumption: slice_size can not be larger than UINT32_MAX
template <typename Bitwise>
__global__ void computeBlockwiseKthCounts(
@ -609,7 +599,6 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> inpu
    }
  }
}
#endif

int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) {
  // occupancy of this kernel is limited by registers per threads
@ -687,16 +676,12 @@ void launch(
  uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get());
  AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream));

#if CUB_SUPPORTS_SCAN_BY_KEY()
  auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
  uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get());
  AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream));

  auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
  uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get());
#else
  uint32_t* withinKCounts = nullptr;
#endif

  Bitwise desiredMask = 0;
  dim3 grid;
@ -743,7 +728,6 @@ void launch(
  }
  desired = desired_in;

#if CUB_SUPPORTS_SCAN_BY_KEY()
  computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>(
    desired, counts, num_blocks, blocks_per_slice, kthCounts);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
@ -759,28 +743,6 @@ void launch(
    topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread,
    blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
  // Find topk values based on kth values
  {
    dim3 grid;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
    int warp_size = at::cuda::warp_size();
    dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
    sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>(
        input,
        inputSliceSize,
        outputSliceSize,
        largest,
        numInputSlices,
        inputWithinSliceStride,
        topK,
        topKWithinSliceStride,
        indices,
        indicesWithinSliceStride,
        kthValues);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
#endif
}

} // namespace mbtopk
@ -788,7 +750,6 @@ void launch(
bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
  if (num_slices > std::numeric_limits<uint32_t>::max() ||
      slice_size > std::numeric_limits<uint32_t>::max()) return false;
#if CUB_SUPPORTS_SCAN_BY_KEY()
  // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267
  return (num_slices <= 20 && slice_size >= 20000) ||
      (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) ||
@ -797,12 +758,6 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
      (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) ||
      (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) ||
      (num_slices > 4000 && slice_size >= 400);
#else
  // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081
  return (num_slices <= 400 && slice_size >= 5000) ||
      (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) ||
      (num_slices >= 4000 && slice_size >= 300);
#endif
}

void launch_gather_topk_kernel(
@ -44,7 +44,7 @@ __global__ void triu_tril_kernel(
    const int64_t k,
    const int64_t N_padded,
    const IndexType last_dim_padded) {
  int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
  int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread;
  if (linear_idx >= N_padded) {
    return;
  }
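The one-cast fix above guards 64-bit indexing: blockIdx.x and blockDim.x are 32-bit unsigned, so without the cast their product is computed modulo 2^32 before the assignment widens it. A hedged standalone kernel showing the corrected pattern (names illustrative only):

__global__ void linear_index_demo(int64_t* out, int elements_per_thread) {
  // Widen before multiplying; the uncast product can wrap for large grids.
  int64_t linear_idx =
      (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread;
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    out[0] = linear_idx;
  }
}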
@ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
// Helper function to compute output pixel range that can contribute to input pixel
 | 
			
		||||
template <typename accscalar_t>
 | 
			
		||||
__device__ __forceinline__ void compute_output_range(
 | 
			
		||||
    int input_pos,
 | 
			
		||||
    accscalar_t scale,
 | 
			
		||||
    int output_size,
 | 
			
		||||
    bool align_corners,
 | 
			
		||||
    int& min_output,
 | 
			
		||||
    int& max_output) {
 | 
			
		||||
  accscalar_t lo, hi;
 | 
			
		||||
  if (align_corners) {
 | 
			
		||||
      lo = static_cast<accscalar_t>(input_pos - 1) / scale;
 | 
			
		||||
      hi = static_cast<accscalar_t>(input_pos + 1) / scale;
 | 
			
		||||
  } else {
 | 
			
		||||
      lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
 | 
			
		||||
      hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
 | 
			
		||||
  }
 | 
			
		||||
  min_output = max(0, static_cast<int>(ceil(lo)));
 | 
			
		||||
  max_output = min(output_size - 1, static_cast<int>(floor(hi)));
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
 | 
			
		||||
template <typename scalar_t, typename accscalar_t>
 | 
			
		||||
C10_LAUNCH_BOUNDS_1(1024)
 | 
			
		||||
@ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame(
 | 
			
		||||
    const bool align_corners,
 | 
			
		||||
    scalar_t* __restrict__ idata,
 | 
			
		||||
    const scalar_t* __restrict__ odata) {
 | 
			
		||||
  // In C++, integer multiplication, like in standard arithmetic, is generally commutative.
 | 
			
		||||
  const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
       index += blockDim.x * gridDim.x) {
    // Decode input pixel coordinates
    size_t index_temp = index;
    const int w1 = index_temp % width1;
    index_temp /= width1;
    const int h1 = index_temp % height1;
    const size_t nc_idx = index_temp / height1;

    accscalar_t grad_sum = 0;

    // Find range of output pixels that could interpolate from this input pixel
    int h2_min, h2_max, w2_min, w2_max;
    compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
    compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);

    // Iterate over potential output pixels
    for (int h2 = h2_min; h2 <= h2_max; h2++) {
      for (int w2 = w2_min; w2 <= w2_max; w2++) {
        // Compute source coordinates for this output pixel
        const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
            rheight, h2, align_corners, /*cubic=*/false);
        const int h1_base = (int)h1r;
        const int h1p = (h1_base < height1 - 1) ? 1 : 0;
        const accscalar_t h1lambda = h1r - h1_base;
        const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;

        const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
            rwidth, w2, align_corners, /*cubic=*/false);
        const int w1_base = (int)w1r;
        const int w1p = (w1_base < width1 - 1) ? 1 : 0;
        const accscalar_t w1lambda = w1r - w1_base;
        const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;

        // Check if our input pixel participates in this interpolation and accumulate all weights
        // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
        // to the same pixel, so we need to accumulate weights from all matching positions
        accscalar_t weight = 0;

        // Check all four interpolation positions and accumulate weights
        if (h1 == h1_base && w1 == w1_base) {
          weight += h0lambda * w0lambda;  // top-left
        }
        if (h1 == h1_base && w1 == w1_base + w1p) {
          weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0)
        }
        if (h1 == h1_base + h1p && w1 == w1_base) {
          weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0)
        }
        if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
          weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions)
        }

        if (weight > 0) {
          const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
          grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
        }
      }
    }

    // Write accumulated gradient (no atomics needed)
    idata[index] = static_cast<scalar_t>(grad_sum);
  }
#else
  const size_t o_numel = nc * width2 * height2;
  const size_t i_numel = nc * width1 * height1;
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
       index += blockDim.x * gridDim.x) {
    size_t index_temp = index;
@@ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame(
        static_cast<scalar_t>(h1lambda * w1lambda * d2val),
        true);
  }
#endif
}
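(A hedged aside before the diff resumes: the ROCm branch above turns the backward into a per-input-pixel gather, so it can be sanity-checked against autograd. Test scaffolding only, not PyTorch source.)

import torch

x = torch.randn(1, 1, 5, 7, dtype=torch.double, requires_grad=True)
y = torch.nn.functional.interpolate(x, size=(9, 13), mode="bilinear", align_corners=False)
g = torch.randn_like(y)
(grad,) = torch.autograd.grad(y, x, g)
# Each input pixel's gradient is the weighted sum of the output gradients whose
# bilinear stencil touches it -- exactly what the gather loop accumulates.
print(grad.shape)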

template <typename scalar_t, typename accscalar_t>
@@ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template(
  // threads are not covering the whole input tensor.
  grad_input.zero_();

  const size_t num_kernels = nbatch * channels * output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
    return;
  }

#ifdef USE_ROCM
  constexpr bool use_input = true;
#else
  constexpr bool use_input = false;
#endif

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@@ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      const size_t num_kernels = nbatch * channels * output_height * output_width;

      upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
              input_height,
@@ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);

      upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
             num_threads,

@@ -466,11 +466,7 @@ struct ReduceJitOp {

    __syncthreads();

    #ifdef USE_ROCM
    for (int offset = 1; offset < dim_x; offset <<= 1) {
    #else
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
    #endif
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = reducer::warp_shfl_down(value[i], offset);

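(Hedged Python model of the two shuffle schedules above -- offsets ascending on ROCm, descending elsewhere; warp_shfl_down(v, off) reads lane + off. Both schedules leave the full sum in lane 0.)

def reduce_shfl_down(vals, ascending):
    n = len(vals)  # warp width, a power of two
    steps = n.bit_length() - 1
    offsets = [1 << s for s in range(steps)] if ascending else [n >> (s + 1) for s in range(steps)]
    for off in offsets:
        # lanes past the end read garbage on hardware; lane 0 stays valid
        vals = [vals[i] + (vals[i + off] if i + off < n else 0) for i in range(n)]
    return vals[0]

assert reduce_shfl_down(list(range(64)), ascending=True) == sum(range(64))
assert reduce_shfl_down(list(range(64)), ascending=False) == sum(range(64))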
@@ -487,9 +487,7 @@ std::unique_ptr<fe::graph::Graph> build_graph(
  auto scaled_dot_product_flash_attention_options =
      fe::graph::SDPA_attributes()
          .set_name("CUDNN_SDPA")
          .set_is_inference(return_softmaxstats == false)
          // TODO(eqy): switch to this API once cuDNN FE is upgraded
          // .set_generate_stats(return_softmaxstats)
          .set_generate_stats(return_softmaxstats)
          .set_causal_mask(is_causal)
          .set_attn_scale(attn_scale);
  if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
@@ -707,9 +705,7 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
  auto scaled_dot_product_flash_attention_options =
      fe::graph::SDPA_attributes()
          .set_name("CUDNN_SDPA_NESTEDTENSOR")
          .set_is_inference(return_softmaxstats == false)
          // TODO(eqy): switch to this API once cuDNN FE is upgraded
          // .set_generate_stats(return_softmaxstats)
          .set_generate_stats(return_softmaxstats)
          .set_causal_mask(is_causal)
          .set_attn_scale(attn_scale)
          .set_seq_len_q(SEQ_LEN_Q_)

@@ -196,6 +196,28 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)
       other.size(0) > max_stride_size || other.size(1) > max_stride_size);
}

void map_mps_decomposition_error_code_to_blas(const Tensor& status) {
  const auto& status_flat = status.view(-1);

  for (const auto i : c10::irange(status_flat.size(0))) {
    int code = status_flat[i].item<int>();
    switch (code) {
      case MPSMatrixDecompositionStatusSuccess:
        status_flat[i] = 0;
        break;
      case MPSMatrixDecompositionStatusNonPositiveDefinite:
      case MPSMatrixDecompositionStatusSingular:
        status_flat[i] = 2;
        break;
      case MPSMatrixDecompositionStatusFailure:
        status_flat[i] = -1;
        break;
      default:
        TORCH_INTERNAL_ASSERT(false, "Unknown MPSMatrixDecompositionStatus enum value: ", code);
    }
  }
}

} // anonymous namespace

static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
@@ -487,6 +509,9 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
                  "mpsmatrixdecompositionstatus for details.");
    }
  }

  map_mps_decomposition_error_code_to_blas(info);

  if (!left) {
    // If this was a right solve, transpose the result back
    result.copy_(result_t.transpose(-2, -1).contiguous());

@@ -1370,6 +1370,7 @@
  dispatch:
    SparseCPU: bmm_sparse_cpu
    SparseCUDA: bmm_sparse_cuda
    SparseMPS: bmm_sparse_mps
    NestedTensorCPU: bmm_nested
    NestedTensorCUDA: bmm_nested_cuda
  tags: core
@@ -1385,6 +1386,7 @@
    MTIA: bmm_out_mtia
    SparseCPU: bmm_out_sparse_cpu
    SparseCUDA: bmm_out_sparse_cuda
    SparseMPS: bmm_out_sparse_mps
    SparseCsrCUDA: bmm_out_sparse_csr_cuda

- func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
@@ -4173,7 +4175,7 @@
  structured_delegate: mm.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: _sparse_mm
    SparseCPU, SparseCUDA, SparseMPS: _sparse_mm
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm
  tags: core

@@ -7112,6 +7114,7 @@
    MTIA: addmm_out_mtia
    SparseCPU: addmm_out_sparse_dense_cpu
    SparseCUDA: addmm_out_sparse_dense_cuda
    SparseMPS: addmm_out_sparse_dense_mps
    SparseCsrCPU: addmm_out_sparse_compressed_cpu
    SparseCsrCUDA: addmm_out_sparse_compressed_cuda

@@ -7121,6 +7124,7 @@
  dispatch:
    SparseCPU: addmm_sparse_dense_cpu
    SparseCUDA: addmm_sparse_dense_cuda
    SparseMPS: addmm_sparse_dense_mps
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense
  tags: core

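(The SparseMPS dispatch entries above are what route COO sparse matmuls to the new Metal kernels further down; a hedged smoke test, which needs an Apple-silicon build with MPS available:)

import torch

if torch.backends.mps.is_available():
    a = torch.randn(4, 8).relu().to_sparse().to("mps")
    b = torch.randn(8, 3, device="mps")
    t = torch.randn(4, 3, device="mps")
    print(torch.addmm(t, a, b).shape)   # addmm_sparse_dense_mps path
    s = torch.randn(2, 4, 8).relu().to_sparse(3).to("mps")
    d = torch.randn(2, 8, 3, device="mps")
    print(torch.bmm(s, d).shape)        # bmm_sparse_mps path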
@@ -1,5 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
@@ -18,6 +19,8 @@
#include <ATen/ops/ones_like.h>
#include <ATen/ops/argsort.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/copy_sparse_to_sparse.h>
#include <ATen/ops/mul.h>
#endif
@@ -33,6 +36,305 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/Mul_metallib.h>
#endif

static Tensor& s_addmm_out_sparse_dense_mps(
    Tensor& r,
    const Tensor& t,
    const SparseTensor& sparse_,
    const Tensor& dense,
    const Scalar& beta,
    const Scalar& alpha) {
  TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: sparse_dim must be 2, got ", sparse_.sparse_dim());
  TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: sparse values must be 0-dense-dim, got ", sparse_.dense_dim());
  TORCH_CHECK(dense.dim() == 2, "addmm: 'dense' must be 2D, got ", dense.dim());
  TORCH_CHECK(t.dim() == 2, "addmm: 't' must be 2D, got ", t.dim());

  const int64_t I = sparse_.size(0);
  const int64_t J = sparse_.size(1);
  const int64_t K = dense.size(1);

  TORCH_CHECK(dense.size(0) == J,
      "addmm: dense (mat2) dim0 must be ", J, ", got ", dense.size(0));
  TORCH_CHECK(t.size(0) == I && t.size(1) == K,
      "addmm: 't' shape must be (", I, ", ", K, "), got (", t.size(0), ", ", t.size(1), ")");

  r.resize_({I, K});

  auto sparse = sparse_.coalesce();
  const int64_t nnz = sparse._nnz();

  if (nnz == 0 || I == 0 || K == 0) {
    at::mul_out(r, t, beta);
    return r;
  }

  const auto v_dtype = sparse._values().scalar_type();
  const auto d_dtype = dense.scalar_type();
  const auto t_dtype = t.scalar_type();
  auto compute_dtype = c10::promoteTypes(c10::promoteTypes(v_dtype, d_dtype), t_dtype);

  TORCH_CHECK(canCast(compute_dtype, r.scalar_type()),
              "Can't convert computed type ", compute_dtype, " to output ", r.scalar_type());

  auto indices2d = sparse._indices().contiguous();
  auto values = sparse._values().to(compute_dtype);
  auto dense_c = dense.to(compute_dtype).contiguous();
  auto t_c = t.to(compute_dtype).contiguous();

  const bool out_needs_cast = (r.scalar_type() != compute_dtype) || !r.is_contiguous();
  Tensor out_buf = out_needs_cast
      ? at::empty({I, K}, r.options().dtype(compute_dtype))
      : r;
  auto out_contig = out_buf.contiguous();

  auto device = r.device();
  auto stream = getCurrentMPSStream();

  const float alpha_f = alpha.to<float>();
  const float beta_f  = beta.to<float>();

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      const std::string func = "spmm_addmm_coo_" + mps::scalarToMetalTypeString(values);
      auto pso = lib.getPipelineStateForFunc(func);
      auto enc = stream->commandEncoder();
      [enc setComputePipelineState:pso];

      const uint32_t tew = pso.threadExecutionWidth;
      const uint32_t gridX = static_cast<uint32_t>(K);
      const uint32_t gridZ = static_cast<uint32_t>(I);
      const uint32_t tgW = std::min<uint32_t>(gridX, tew);

      MTLSize grid = MTLSizeMake(gridX, 1, gridZ);
      MTLSize tgs = MTLSizeMake(tgW, 1, 1);

      mtl_setArgs(enc,
                  indices2d,
                  values,
                  dense_c,
                  t_c,
                  out_contig,
                  std::array<uint32_t, 3>{static_cast<uint32_t>(I),
                                           static_cast<uint32_t>(J),
                                           static_cast<uint32_t>(K)},
                  std::array<float, 2>{alpha_f, beta_f},
                  static_cast<uint32_t>(nnz));
      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
    }
  });

  if (out_needs_cast) {
    r.copy_(out_contig.to(r.scalar_type()));
  }

  return r;
}

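(Hedged reference for the host function above: per output element the Metal kernel computes r[i, k] = beta * t[i, k] + alpha * sum over p with row[p] == i of vals[p] * dense[col[p], k]. In plain PyTorch that is simply:)

import torch

def addmm_coo_reference(t, indices, values, dense, beta, alpha):
    rows, cols = indices  # coalesced COO indices, shape (2, nnz)
    out = beta * t
    # scatter-add each nonzero's contribution into its output row
    out.index_add_(0, rows, alpha * values.unsqueeze(1) * dense[cols])
    return out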
static void build_batch_ptr_mps(
    const Tensor& indices_dim0,
    int64_t B,
    Tensor& batch_ptr
) {
  // Builds an array of offsets marking where each batch's elements begin. Example:
  // idx_b = [0, 0, 0, 1, 1, 2, 2, 2, 2]  // 9 non-zero elements
  //          └─────┘  └──┘  └─────────┘
  //          batch 0  batch 1  batch 2
  // batch_ptr = [0, 3, 5, 9]
  //              │  │  │  └─ end of batch 2 (total nnz)
  //              │  │  └──── batch 2 starts at index 5
  //              │  └─────── batch 1 starts at index 3
  //              └────────── batch 0 starts at index 0
  TORCH_CHECK(indices_dim0.is_mps() && batch_ptr.is_mps(), "MPS device expected");
  auto device = indices_dim0.device();
  auto stream = getCurrentMPSStream();

  const int64_t nnz = indices_dim0.numel();

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pso = lib.getPipelineStateForFunc("build_batch_ptr_from_sorted_batches");
      auto enc = stream->commandEncoder();
      [enc setComputePipelineState:pso];

      const uint32_t tew = pso.threadExecutionWidth;
      const uint32_t Q = static_cast<uint32_t>(B + 1);
      const uint32_t tgW = std::min<uint32_t>(Q, tew);
      MTLSize grid = MTLSizeMake(Q, 1, 1);
      MTLSize tgs  = MTLSizeMake(tgW, 1, 1);

      mtl_setArgs(enc,
                  indices_dim0,
                  batch_ptr,
                  std::array<uint32_t, 2>{static_cast<uint32_t>(nnz),
                                          static_cast<uint32_t>(B)});
      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
    }
  });
}
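(Hedged note: on the host side the same batch_ptr could be produced with torch.searchsorted over the sorted batch indices; the Metal kernel performs one such binary search per batch_ptr entry.)

import torch

idx_b = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
B = 3
batch_ptr = torch.searchsorted(idx_b, torch.arange(B + 1))
print(batch_ptr)  # tensor([0, 3, 5, 9])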

static void build_row_ptr_per_batch_mps(
    const Tensor& rows,
    const Tensor& batch_ptr,
    int64_t B,
    int64_t I,
    Tensor& row_ptr
) {
  // Build per-batch CSR-style row pointer arrays from row indices sorted by batch
  // Given:
  //   rows: 1-D array of length nnz with row ids in [0, I), sorted within each batch
  //   batch_ptr: length B+1, where [batch_ptr[b], batch_ptr[b+1]) is the subrange for batch b
  // Produces:
  //   - row_ptr: shape [B, I+1]
  //
  // Example (B = 2, I = 4):
  // rows       = [0,   0,   1,  3,  0,   2,    2]   // 7 non-zero elements
  //               └─── batch 0 ──┘  └─ batch 1 ─┘
  // batch_ptr  = [0, 4, 7]
  //               │  │  └─ end of batch 1 (total nnz)
  //               │  └──── end of batch 0/start of batch 1
  //               └─────── start of batch 0
  //
  // per-batch row pointers (I+1 entries each):
  //   row_ptr[0] = [0, 2, 3, 3, 4]
  //   row_ptr[1] = [4, 5, 5, 7, 7]   (absolute offsets into the batch-sorted arrays,
  //                                   matching what the kernel's lower_bound writes)
  // laid out in memory: [0, 2, 3, 3, 4,  4, 5, 5, 7, 7]
  TORCH_CHECK(rows.is_mps() && batch_ptr.is_mps() && row_ptr.is_mps(), "MPS device expected");
  auto stream = getCurrentMPSStream();

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pso = lib.getPipelineStateForFunc("build_row_ptr_from_sorted_rows_by_batch");
      auto enc = stream->commandEncoder();
      [enc setComputePipelineState:pso];

      const uint32_t tew = pso.threadExecutionWidth;
      const uint32_t Qx = static_cast<uint32_t>(I + 1);
      const uint32_t Qy = static_cast<uint32_t>(B);
      const uint32_t tgW = std::min<uint32_t>(Qx, tew);

      MTLSize grid = MTLSizeMake(Qx, Qy, 1);
      MTLSize tgs = MTLSizeMake(tgW, 1, 1);

      mtl_setArgs(enc,
                  rows,
                  batch_ptr,
                  row_ptr,
                  std::array<uint32_t, 2>{static_cast<uint32_t>(I),
                                           static_cast<uint32_t>(B)});
      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
    }
  });
}

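(Same idea for the per-batch row pointers, hedged: each batch's row_ptr row is a searchsorted over that batch's slice of the row indices, shifted back to absolute positions.)

import torch

rows = torch.tensor([0, 0, 1, 3, 0, 2, 2])
batch_ptr = torch.tensor([0, 4, 7])
I = 4
row_ptr = torch.stack([
    torch.searchsorted(rows[batch_ptr[b]:batch_ptr[b + 1]], torch.arange(I + 1)) + batch_ptr[b]
    for b in range(2)
])
print(row_ptr)  # tensor([[0, 2, 3, 3, 4], [4, 5, 5, 7, 7]])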
Tensor& bmm_out_sparse_mps(const SparseTensor& self_, const Tensor& mat2_, Tensor& result_) {
  TORCH_CHECK(result_.is_mps(), "bmm_sparse: expected 'out' to be MPS, got ", result_.device());
  TORCH_CHECK(self_.is_mps(),  "bmm_sparse: expected 'self' to be MPS, got ", self_.device());
  TORCH_CHECK(mat2_.is_mps(),  "bmm_sparse: expected 'mat2' to be MPS, got ", mat2_.device());

  TORCH_CHECK(self_.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self_.dense_dim());
  TORCH_CHECK(self_.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self_.sparse_dim());
  TORCH_CHECK(mat2_.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2_.dim());

  TORCH_CHECK(self_.size(0) == mat2_.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match");
  TORCH_CHECK(self_.size(2) == mat2_.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match");

  const int64_t B = self_.size(0);
  const int64_t I = self_.size(1);
  const int64_t J = self_.size(2);
  const int64_t K = mat2_.size(2);

  auto self = self_.coalesce();
  const int64_t nnz = self._nnz();
  if (nnz == 0) {
    return result_.zero_();
  }

  const auto computeDtype = at::kFloat;

  auto indices = self._indices();
  auto values  = self._values();

  auto values_c = values.scalar_type() == computeDtype ? values : values.to(computeDtype);
  auto mat2_c = mat2_.scalar_type()   == computeDtype ? mat2_   : mat2_.to(computeDtype);
  auto mat2_contig = mat2_c.contiguous();

  auto idx_b = indices.select(0, 0).contiguous();
  auto idx_i = indices.select(0, 1).contiguous();
  auto idx_j = indices.select(0, 2).contiguous();

  // build an array of offsets marking where each batch's entries start and end
  // (see build_batch_ptr_mps above for a worked example)
  auto batch_ptr = at::empty({B + 1}, at::device(result_.device()).dtype(kLong));
  build_batch_ptr_mps(idx_b, B, batch_ptr);
  // build row_ptr per batch: for each (b, i) get [start, end) into rows/cols/vals
  auto row_ptr = at::empty({B * (I + 1)}, at::device(result_.device()).dtype(kLong));
  build_row_ptr_per_batch_mps(idx_i, batch_ptr, B, I, row_ptr);

  const bool out_needs_cast = (result_.scalar_type() != computeDtype) || !result_.is_contiguous();
  Tensor out_buf = out_needs_cast
      ? at::empty({B, I, K}, result_.options().dtype(computeDtype))
      : result_;
  auto out_contig = out_buf.contiguous();

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pso = lib.getPipelineStateForFunc("spmm_bmm_coo_rows_grouped_" + mps::scalarToMetalTypeString(values));
      auto enc = stream->commandEncoder();
      [enc setComputePipelineState:pso];

      const uint32_t tew = pso.threadExecutionWidth;
      const uint32_t tgW = std::min<uint32_t>((uint32_t)K, tew);

      // One threadgroup per (row i, batch b), lanes cover K
      MTLSize grid = MTLSizeMake(tgW, (uint32_t)I, (uint32_t)B);
      MTLSize tgs  = MTLSizeMake(tgW, 1, 1);

      mtl_setArgs(enc,
                  idx_i,
                  idx_j,
                  values_c,
                  mat2_contig,
                  out_contig,
                  row_ptr,
                  std::array<uint32_t, 4>{(uint32_t)B, (uint32_t)I, (uint32_t)J, (uint32_t)K});
      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
    }
  });
  if (out_needs_cast) {
    result_.copy_(out_contig.to(result_.scalar_type()));
  }
  return result_;
}

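(A hedged pure-Python model of the dispatch above: one threadgroup per (batch b, row i), lanes striding over K, with row_ptr bounding each row's slice of cols/vals:)

def spmm_bmm_rows_grouped_ref(idx_j, vals, dense, row_ptr, B, I, K):
    # dense is indexed as [b][j][k]; row_ptr is flat with B * (I + 1) entries.
    out = [[[0.0] * K for _ in range(I)] for _ in range(B)]
    for b in range(B):
        for i in range(I):
            start = row_ptr[b * (I + 1) + i]
            end = row_ptr[b * (I + 1) + i + 1]
            for k in range(K):  # the "lanes"
                out[b][i][k] = sum(vals[p] * dense[b][idx_j[p]][k]
                                   for p in range(start, end))
    return out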
Tensor bmm_sparse_mps(const Tensor& self, const Tensor& mat2) {
  Tensor result = at::zeros({self.size(0), self.size(1), mat2.size(2)}, mat2.options());
  return bmm_out_sparse_mps(self, mat2, result);
}

Tensor& addmm_out_sparse_dense_mps(
    const Tensor& self,
    const SparseTensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
  return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha);
}

Tensor addmm_sparse_dense_mps(
    const Tensor& self,
    const SparseTensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha
) {
  c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
  Tensor result = at::empty({0}, self.options());
  return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha);
}

static SparseTensor& mul_out_dense_sparse_mps(
    const Tensor& dense,
    const Tensor& sparse,

@@ -1,10 +1,105 @@
#include <metal_stdlib>
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
using namespace c10::metal;
using namespace metal;

inline uint lower_bound_i64(device const long* arr, uint lo, uint hi, long key) {
  uint l = lo, r = hi;
  while (l < r) {
    uint m = (l + r) >> 1;
    long v = arr[m];
    if (v < key) {
      l = m + 1;
    } else {
      r = m;
    }
  }
  return l;
}

template <typename T> struct MulAccum { using type = float; };
template <> struct MulAccum<float2> { using type = float2; };
inline uint upper_bound_i64(device const long* arr, uint lo, uint hi, long key) {
  uint l = lo, r = hi;
  while (l < r) {
    uint m = (l + r) >> 1;
    long v = arr[m];
    if (v <= key) {
      l = m + 1;
    } else {
      r = m;
    }
  }
  return l;
}

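(Hedged equivalence note: lower_bound_i64 and upper_bound_i64 above behave like Python's bisect_left and bisect_right restricted to [lo, hi).)

from bisect import bisect_left, bisect_right

arr = [0, 0, 1, 3, 0, 2, 2]  # row ids, sorted within each batch
assert bisect_left(arr, 1, 0, 4) == 2    # lower_bound_i64(arr, 0, 4, 1)
assert bisect_right(arr, 2, 4, 7) == 7   # upper_bound_i64(arr, 4, 7, 2)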
kernel void build_row_ptr_from_sorted_rows_by_batch(
    device const long* rows        [[buffer(0)]],
    device const long* batch_ptr   [[buffer(1)]],
    device long*       row_ptr     [[buffer(2)]],
    constant uint2&    dims        [[buffer(3)]],
    uint3              tid         [[thread_position_in_grid]])
{
  const uint I = dims.x;
  const uint B = dims.y;

  const uint i = tid.x;
  const uint b = tid.y;

  if (b >= B || i > I) return;

  const uint base = (uint)batch_ptr[b];
  const uint lim  = (uint)batch_ptr[b + 1];

  const ulong out_base = (ulong)b * (ulong)(I + 1);

  if (i == I) {
    row_ptr[out_base + (ulong)I] = (long)lim;
  } else {
    const long key = (long)i;
    const uint pos = lower_bound_i64(rows, base, lim, key);
    row_ptr[out_base + (ulong)i] = (long)pos;
  }
}

template <typename T>
kernel void spmm_bmm_coo_rows_grouped(
    device const long*   rows      [[buffer(0)]],
    device const long*   cols      [[buffer(1)]],
    device const T*      vals      [[buffer(2)]],
    device const T*      dense     [[buffer(3)]],
    device T*            out       [[buffer(4)]],
    device const long*   row_ptr   [[buffer(5)]],
    constant uint4&      dims      [[buffer(6)]],
    uint3                tid       [[thread_position_in_grid]],
    uint3                ltid      [[thread_position_in_threadgroup]],
    uint3                tptg      [[threads_per_threadgroup]])
{
  const uint B = dims.x;
  const uint I = dims.y;
  const uint J = dims.z;
  const uint K = dims.w;

  const uint b = tid.z;
  const uint i = tid.y;
  const uint lane = ltid.x;
  const uint tgW  = tptg.x;

  const ulong rp_base = (ulong)b * (ulong)(I + 1);
  const uint start = (uint)row_ptr[rp_base + (ulong)i];
  const uint end   = (uint)row_ptr[rp_base + (ulong)i + 1];

  for (uint k = lane; k < K; k += tgW) {
    auto acc = static_cast<accum_t<T>>(T(0));
    for (uint p = start; p < end; ++p) {
      const uint c = (uint)cols[p];
      const auto v = static_cast<accum_t<T>>(vals[p]);
      const uint d_off = ((b * J) + c) * K + k;
      const auto d = static_cast<accum_t<T>>(dense[d_off]);
      acc += mul(v, d);
    }
    const uint y_off = ((b * I) + i) * K + k;
    out[y_off] = static_cast<T>(acc);
  }
}

template <typename T>
kernel void dense_sparse_mul_kernel(
@@ -32,10 +127,9 @@ kernel void dense_sparse_mul_kernel(
  ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
  ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;

  using accum_t = typename MulAccum<T>::type;
  const accum_t a = static_cast<accum_t>(values[val_idx]);
  const accum_t b = static_cast<accum_t>(dense[dense_idx]);
  out_values[val_idx] = static_cast<T>(a * b);
  const auto a = static_cast<accum_t<T>>(values[val_idx]);
  const auto b = static_cast<accum_t<T>>(dense[dense_idx]);
  out_values[val_idx] = static_cast<T>(mul(a, b));
}

kernel void intersect_binary_search(
@@ -120,6 +214,76 @@ kernel void fused_gather_mul_kernel(
  }
}

kernel void build_batch_ptr_from_sorted_batches(
    device const long* batches       [[buffer(0)]],
    device long*       batch_ptr     [[buffer(1)]],
    constant uint2&    nnz_B         [[buffer(2)]],
    uint3              tid           [[thread_position_in_grid]])
{
  uint b = tid.x;
  uint nnz = nnz_B.x;
  uint batch = nnz_B.y;

  if (b == batch) {
    batch_ptr[b] = (long)nnz;
    return;
  }

  uint lo = 0;
  uint hi = nnz;
  long key = (long)b;
  while (lo < hi) {
    uint mid = (lo + hi) >> 1;
    long v = batches[mid];
    if (v < key) lo = mid + 1;
    else         hi = mid;
  }
  batch_ptr[b] = (long)lo;
}

template <typename T>
kernel void spmm_addmm_coo(
    device const long*   indices2d   [[buffer(0)]],
    device const T*      vals        [[buffer(1)]],
    device const T*      dense       [[buffer(2)]],
    device const T*      t_in        [[buffer(3)]],
    device T*            out         [[buffer(4)]],
    constant uint3&      dims        [[buffer(5)]],
    constant float2&     alpha_beta  [[buffer(6)]],
    constant uint&       nnz         [[buffer(7)]],
    uint3                tid         [[thread_position_in_grid]])
{
  const uint K = dims.z;
  const uint k = tid.x;
  const uint i = tid.z;
  const float alpha = alpha_beta.x;
  const float beta = alpha_beta.y;

  device const long* rows = indices2d;
  device const long* cols = indices2d + nnz;

  const uint start = lower_bound_i64(rows, 0u, nnz, (long)i);
  const uint end = upper_bound_i64(rows, 0u, nnz, (long)i);

  // accumulator is float for scalar/half/bfloat and float2 for float2
  auto acc = static_cast<accum_t<T>>(T(0));

  for (uint p = start; p < end; ++p) {
    const uint c = (uint)cols[p];
    const auto v = static_cast<accum_t<T>>(vals[p]);
    const uint dense_off = c * K + k;
    const auto d = static_cast<accum_t<T>>(dense[dense_off]);
    acc += mul(v, d);
  }

  const uint off = i * K + k;
  const auto base = (beta != 0.0f) ? (static_cast<accum_t<T>>(t_in[off]) * beta) : static_cast<accum_t<T>>(T(0));
  const auto y = base + alpha * acc;
  out[off] = static_cast<T>(y);
}

#define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE)                                 \
  template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void     \
  dense_sparse_mul_kernel<DTYPE>(                                           \
@@ -151,6 +315,36 @@ INSTANTIATE_DENSE_SPARSE_MUL(float2);
      constant uint2&     dims_output   [[buffer(8)]],                       \
      uint3               gid           [[thread_position_in_grid]]);

INSTANTIATE_FUSED_GATHER_MUL(float);
INSTANTIATE_FUSED_GATHER_MUL(half);
INSTANTIATE_FUSED_GATHER_MUL(bfloat);
INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);

#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE)                         \
  template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void    \
  spmm_bmm_coo_rows_grouped<DTYPE>(                                          \
      device const long*   rows      [[buffer(0)]],                          \
      device const long*   cols      [[buffer(1)]],                          \
      device const DTYPE*  vals      [[buffer(2)]],                          \
      device const DTYPE*  dense     [[buffer(3)]],                          \
      device DTYPE*        out       [[buffer(4)]],                          \
      device const long*   row_ptr   [[buffer(5)]],                          \
      constant uint4&      dims      [[buffer(6)]],                          \
      uint3                tid       [[thread_position_in_grid]],            \
      uint3                ltid      [[thread_position_in_threadgroup]],     \
      uint3                tptg      [[threads_per_threadgroup]]);

INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED);

#define INSTANTIATE_SPMM_ADDMM_COO(DTYPE) \
  template [[host_name("spmm_addmm_coo_" #DTYPE)]] kernel void  \
  spmm_addmm_coo<DTYPE>(                                        \
    device const long*   indices2d   [[buffer(0)]],             \
    device const DTYPE*  vals        [[buffer(1)]],             \
    device const DTYPE*  dense       [[buffer(2)]],             \
    device const DTYPE*  t_in        [[buffer(3)]],             \
    device DTYPE*        out         [[buffer(4)]],             \
    constant uint3&      dims        [[buffer(5)]],             \
    constant float2&     alpha_beta  [[buffer(6)]],             \
    constant uint&       nnz         [[buffer(7)]],             \
    uint3                tid         [[thread_position_in_grid]]);

INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_ADDMM_COO);

@@ -1751,8 +1751,8 @@ def maybe_snapshot_memory(should_snapshot_memory, suffix):
                        f"{output_filename.rstrip('.csv')}_{suffix}.pickle",
                    )
                )
            except Exception as e:
                log.error("Failed to save memory snapshot, %s", e)
            except Exception:
                log.exception("Failed to save memory snapshot")

            torch.cuda.memory._record_memory_history(enabled=None)

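(Hedged aside on the change above: logging.Logger.exception logs at ERROR level and appends the active traceback, so the handler no longer needs to interpolate the exception by hand.)

import logging

logging.basicConfig()
log = logging.getLogger("demo")
try:
    raise ValueError("boom")
except Exception:
    log.exception("Failed to save memory snapshot")  # message plus traceback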
@@ -2284,9 +2284,11 @@ class BenchmarkRunner:
                    )
                ):
                    is_same = False
            except Exception:
            except Exception as e:
                # Sometimes torch.allclose may throw RuntimeError
                is_same = False
                exception_string = str(e)
                accuracy_status = f"fail_exception: {exception_string}"
                return record_status(accuracy_status, dynamo_start_stats=start_stats)

            if not is_same:
                accuracy_status = "eager_two_runs_differ"
@@ -2403,9 +2405,11 @@ class BenchmarkRunner:
                    force_max_multiplier=force_max_multiplier,
                ):
                    is_same = False
            except Exception:
            except Exception as e:
                # Sometimes torch.allclose may throw RuntimeError
                is_same = False
                exception_string = str(e)
                accuracy_status = f"fail_exception: {exception_string}"
                return record_status(accuracy_status, dynamo_start_stats=start_stats)

            if not is_same:
                if self.args.skip_accuracy_check:

@@ -124,7 +124,7 @@ with open(MODELS_FILENAME) as fh:
            continue
        batch_size = int(batch_size)
        BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
assert len(BATCH_SIZE_KNOWN_MODELS)
assert BATCH_SIZE_KNOWN_MODELS


try:

@@ -296,8 +296,8 @@ class OperatorInputsLoader:
        for key in self.operator_db.keys():
            try:
                op = eval(key)
            except AttributeError as ae:
                log.warning("Evaluating an op name into an OpOverload: %s", ae)
            except AttributeError:
                log.warning("Evaluating an op name into an OpOverload", exc_info=True)
                continue
            yield op

@@ -3,6 +3,7 @@ import sys
from benchmark_base import BenchmarkBase

import torch
from torch._dynamo.utils import CompileTimeInstructionCounter


class Benchmark(BenchmarkBase):
@@ -32,7 +33,11 @@ class Benchmark(BenchmarkBase):
    def _work(self):
        # enable_cpp_symbolic_shape_guards has impact on this benchmark
        # Keep using False value for consistency.
        with torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False):
        with (
            torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False),
            torch._export.config.patch(use_new_tracer_experimental=True),
            CompileTimeInstructionCounter.record(),
        ):
            torch.export.export(self.m, (self.input,), strict=True)


@@ -38,7 +38,7 @@ update_hint_regression,compile_time_instruction_count,1719000000,0.1



sum_floordiv_regression,compile_time_instruction_count,966100000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1


@@ -85,7 +85,7 @@ class WeightOnlyInt8QuantHandler:
                cur_state_dict[f"{fqn}.weight"] = int8_weight
                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
            elif isinstance(mod, ConditionalFeedForward):
                for weight_idx in range(0, 3):
                for weight_idx in range(3):
                    weight_name = f"w{weight_idx + 1}"
                    scales_name = f"scales{weight_idx + 1}"
                    weight = getattr(mod, weight_name)

@@ -1729,10 +1729,8 @@ def define_buck_targets(
            "torch/csrc/jit/backends/backend_debug_info.cpp",
            "torch/csrc/jit/backends/backend_interface.cpp",
        ],
        compiler_flags = get_pt_compiler_flags() + select({
            "DEFAULT": [],
            "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
        }),
        compiler_flags = get_pt_compiler_flags(),
        fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = get_no_as_needed_linker_flag(),
@@ -2025,9 +2023,6 @@ def define_buck_targets(
                "ovr_config//os:android-x86_64": [
                    "-mssse3",
                ],
            }) + select({
                "DEFAULT": [],
                "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
            }),
            exported_preprocessor_flags = get_aten_preprocessor_flags(),
            exported_deps = [

@@ -1,5 +1,4 @@
#include <c10/core/AllocatorConfig.h>
#include <c10/core/DeviceType.h>
#include <c10/util/env.h>

namespace c10::CachingAllocator {
@@ -47,7 +46,7 @@ size_t AcceleratorAllocatorConfig::roundup_power2_divisions(size_t size) {
      63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoStart);
  const size_t interval_end =
      63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoEnd);
  TORCH_CHECK(
  TORCH_CHECK_VALUE(
      interval_end - interval_start == kRoundUpPowerOfTwoIntervals,
      "kRoundUpPowerOfTwoIntervals mismatch");

@@ -66,7 +65,7 @@ size_t AcceleratorAllocatorConfig::parseMaxSplitSize(
      std::numeric_limits<size_t>::max() / kMB;

  size_t val_env = tokenizer.toSizeT(++i);
  TORCH_CHECK(
  TORCH_CHECK_VALUE(
      val_env >= min_allowed_split_size_mb,
      "CachingAllocator option max_split_size_mb too small, must be >= ",
      min_allowed_split_size_mb);
@@ -85,7 +84,7 @@ size_t AcceleratorAllocatorConfig::parseMaxNonSplitRoundingSize(
      std::numeric_limits<size_t>::max() / kMB;

  size_t val_env = tokenizer.toSizeT(++i);
  TORCH_CHECK(
  TORCH_CHECK_VALUE(
      val_env >= min_allowed_split_size_mb,
      "CachingAllocator option max_non_split_rounding_mb too small, must be >= ",
      min_allowed_split_size_mb);
@@ -100,7 +99,7 @@ size_t AcceleratorAllocatorConfig::parseGarbageCollectionThreshold(
    size_t i) {
  tokenizer.checkToken(++i, ":");
  double val_env = tokenizer.toDouble(++i);
  TORCH_CHECK(
  TORCH_CHECK_VALUE(
      val_env > 0 && val_env < 1.0,
      "garbage_collect_threshold is invalid, set it in (0.0, 1.0)");
  garbage_collection_threshold_ = val_env;
@@ -121,7 +120,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
      size_t value_index = i;
      tokenizer.checkToken(++i, ":");
      size_t value = tokenizer.toSizeT(++i);
      TORCH_CHECK(
      TORCH_CHECK_VALUE(
          value == 0 || llvm::isPowerOf2_64(value),
          "For roundups, the divisions has to be power of 2 or 0 to disable roundup ");

@@ -129,12 +128,13 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
        std::fill(
            std::next(
                roundup_power2_divisions_.begin(),
                static_cast<std::vector<size_t>::difference_type>(last_index)),
                static_cast<std::vector<size_t>::difference_type>(
                    last_index + 1)),
            roundup_power2_divisions_.end(),
            value);
      } else {
        size_t boundary = tokenizer.toSizeT(value_index);
        TORCH_CHECK(
        TORCH_CHECK_VALUE(
            llvm::isPowerOf2_64(boundary),
            "For roundups, the intervals have to be power of 2 ");

@@ -164,7 +164,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
        "Expected closing bracket ']' in ConfigTokenizer but reached end of config");
  } else { // Keep this for backwards compatibility
    size_t value = tokenizer.toSizeT(i);
    TORCH_CHECK(
    TORCH_CHECK_VALUE(
        llvm::isPowerOf2_64(value),
        "For roundups, the divisions has to be power of 2 ");
    std::fill(
@@ -224,7 +224,7 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
      // If a device-specific configuration parser hook is registered, it will
      // check if the key is unrecognized.
      if (device_config_parser_hook_) {
        TORCH_CHECK(
        TORCH_CHECK_VALUE(
            getKeys().find(key) != getKeys().end(),
            "Unrecognized key '",
            key,

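(A hedged note on the TORCH_CHECK -> TORCH_CHECK_VALUE swaps in this file: TORCH_CHECK surfaces in Python as RuntimeError, while TORCH_CHECK_VALUE surfaces as ValueError, a better fit for rejecting a malformed allocator-config string. Illustrative only: the private helper below exists in recent builds but is subject to change, and it needs a CUDA runtime.)

import torch

if torch.cuda.is_available():
    try:
        torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:2.0")
    except ValueError as e:  # previously surfaced as RuntimeError
        print("rejected:", e)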
@@ -76,7 +76,7 @@ class ConfigTokenizer {
    } else if (token == "False") {
      return false;
    } else {
      TORCH_CHECK(
      TORCH_CHECK_VALUE(
          false,
          "Expected 'True' or 'False' at index ",
          i,
@@ -253,7 +253,7 @@ class C10_API AcceleratorAllocatorConfig {
    device_config_parser_hook_ = std::move(hook);
    auto& mutable_keys = getMutableKeys();
    for (auto& key : keys) {
      TORCH_CHECK(
      TORCH_CHECK_VALUE(
          mutable_keys.insert(key).second,
          "Duplicated key '",
          key,

@@ -4,7 +4,6 @@
#include <c10/core/SymNodeImpl.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/safe_numerics.h>
#include <functional>

namespace c10 {


@@ -9,7 +9,6 @@
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Logging.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <optional>

#include <utility>

@@ -1,9 +1,5 @@
#include <c10/core/TensorOptions.h>

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/util/Optional.h>

#include <iostream>

namespace c10 {

@@ -2,7 +2,6 @@

#include <c10/core/Allocator.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/alignment.h>
#include <c10/core/impl/COWDeleter.h>
#include <c10/util/Exception.h>
#include <c10/util/ParallelGuard.h>

@@ -1,5 +1,4 @@
#include <c10/core/DispatchKey.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/irange.h>

@@ -20,7 +20,7 @@ size_t CUDAAllocatorConfig::parseAllocatorConfig(
  tokenizer.checkToken(++i, ":");
  i++; // Move to the value after the colon
#ifdef USE_ROCM
  TORCH_CHECK(
  TORCH_CHECK_VALUE(
      ((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
       (tokenizer[i] == PYTORCH_TOKEN2)),
      "Unknown allocator backend, "
@@ -36,7 +36,7 @@ size_t CUDAAllocatorConfig::parseAllocatorConfig(
      " != ",
      get()->name());
#else // USE_ROCM
  TORCH_CHECK(
  TORCH_CHECK_VALUE(
      ((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1)),
      "Unknown allocator backend, "
      "options are native and " PYTORCH_TOKEN1);
@@ -109,7 +109,7 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) {
    } else {
      const auto& keys =
          c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
      TORCH_CHECK(
      TORCH_CHECK_VALUE(
          keys.find(key) != keys.end(),
          "Unrecognized key '",
          key,
@@ -151,12 +151,12 @@ size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
    size_t i) {
  tokenizer.checkToken(++i, ":");
  size_t val2 = tokenizer.toSizeT(++i);
  TORCH_CHECK(
  TORCH_CHECK_VALUE(
      llvm::isPowerOf2_64(val2),
      "Number of register threads has to be power of 2, got ",
      val2);
  auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
  TORCH_CHECK(
  TORCH_CHECK_VALUE(
      val2 <= maxThreads,
      "Number of register threads should be less than or equal to ",
      maxThreads,
@@ -171,7 +171,8 @@ size_t CUDAAllocatorConfig::parsePinnedReserveSegmentSize(
    size_t i) {
  tokenizer.checkToken(++i, ":");
  size_t val2 = tokenizer.toSizeT(++i);
  TORCH_CHECK(val2 > 0, "Pinned reserve segment size has to be greater than 0");
  TORCH_CHECK_VALUE(
      val2 > 0, "Pinned reserve segment size has to be greater than 0");
  m_pinned_reserve_segment_size_mb = val2;
  return i;
}

@@ -3,6 +3,7 @@
#include <c10/core/AllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>

@@ -17,9 +18,14 @@ enum class Expandable_Segments_Handle_Type : int {
// Environment config parser
class C10_CUDA_API CUDAAllocatorConfig {
 public:
  C10_DEPRECATED_MESSAGE(
      "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.")
  static size_t max_split_size() {
    return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
  }

  C10_DEPRECATED_MESSAGE(
      "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.")
  static double garbage_collection_threshold() {
    return c10::CachingAllocator::AcceleratorAllocatorConfig::
        garbage_collection_threshold();
@@ -64,6 +70,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
    return instance().m_pinned_num_register_threads;
  }

  C10_DEPRECATED_MESSAGE(
      "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.")
  static bool pinned_use_background_threads() {
    return c10::CachingAllocator::AcceleratorAllocatorConfig::
        pinned_use_background_threads();
@@ -80,11 +88,15 @@ class C10_CUDA_API CUDAAllocatorConfig {
    return 128;
  }

  C10_DEPRECATED_MESSAGE(
      "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
  static size_t roundup_power2_divisions(size_t size) {
    return c10::CachingAllocator::AcceleratorAllocatorConfig::
        roundup_power2_divisions(size);
  }

  C10_DEPRECATED_MESSAGE(
      "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
  static std::vector<size_t> roundup_power2_divisions() {
    return c10::CachingAllocator::AcceleratorAllocatorConfig::
        roundup_power2_divisions();
@@ -95,6 +107,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
        max_non_split_rounding_size();
  }

  C10_DEPRECATED_MESSAGE(
      "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.")
  static std::string last_allocator_settings() {
    return c10::CachingAllocator::getAllocatorSettings();
  }

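Every C10_DEPRECATED_MESSAGE above points callers at the device-agnostic config. A hedged migration sketch (the wrapper function is made up for illustration):

#include <c10/core/AllocatorConfig.h>

size_t query_split_limit() {
  // Before (still compiles, now warns at build time):
  //   c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size();
  // After: call the shared accelerator config that backs the shim.
  return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
}
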
@@ -1260,6 +1260,9 @@ class DeviceCachingAllocator {
  // thread local compile context for each device
  static thread_local std::stack<std::string> compile_context;

  // thread local user metadata for annotating allocations
  static thread_local std::string user_metadata;

 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit DeviceCachingAllocator(c10::DeviceIndex id)
@@ -1267,7 +1270,7 @@ class DeviceCachingAllocator {
        large_blocks(/*small=*/false),
        small_blocks(/*small=*/true) {
    stats.max_split_size =
        static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
        static_cast<int64_t>(AcceleratorAllocatorConfig::max_split_size());
    context_recorder_.store(nullptr);
  }

@@ -1302,6 +1305,14 @@ class DeviceCachingAllocator {
    }
  }

  void setUserMetadata(const std::string& metadata) {
    user_metadata = metadata;
  }

  std::string getUserMetadata() {
    return user_metadata;
  }

  bool checkPoolLiveAllocations(
      MempoolId_t mempool_id,
      const std::unordered_set<void*>& expected_live_allocations) const {
@@ -1394,7 +1405,8 @@ class DeviceCachingAllocator {
      // Do garbage collection if the flag is set.
      if (C10_UNLIKELY(
              set_fraction &&
              CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
              AcceleratorAllocatorConfig::garbage_collection_threshold() >
                  0.0)) {
        garbage_collect_cached_blocks(context);
      }
      // Attempt allocate
@@ -1646,7 +1658,7 @@ class DeviceCachingAllocator {
      stats.active_bytes[stat_type].increase(block->size);
      stats.requested_bytes[stat_type].increase(block->requested_size);
    });
    if (block->size >= CUDAAllocatorConfig::max_split_size())
    if (block->size >= AcceleratorAllocatorConfig::max_split_size())
      stats.oversize_allocations.increase(1);

    auto allocated_bytes_gauge =
@@ -1915,7 +1927,7 @@ class DeviceCachingAllocator {
        block->pool->owner_MempoolId(),
        context ? context : block->context_when_allocated);

    if (block->size >= CUDAAllocatorConfig::max_split_size())
    if (block->size >= AcceleratorAllocatorConfig::max_split_size())
      stats.oversize_allocations.decrease(1);

    // If the block has been used on more than one stream, handle accordingly.
@@ -2488,7 +2500,8 @@ class DeviceCachingAllocator {
    if (size < kMinBlockSize) {
      return kMinBlockSize;
    } else {
      auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
      auto divisions =
          AcceleratorAllocatorConfig::roundup_power2_divisions(size);
      if (divisions > 1 && size > (kMinBlockSize * divisions)) {
        return roundup_power2_next_division(size, divisions);
      } else {
@@ -2982,7 +2995,7 @@ class DeviceCachingAllocator {
    if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) {
      return remaining >= kMinBlockSize;
    } else {
      return (size < CUDAAllocatorConfig::max_split_size()) &&
      return (size < AcceleratorAllocatorConfig::max_split_size()) &&
          (remaining > kSmallSize);
    }
  }
@@ -3002,7 +3015,7 @@ class DeviceCachingAllocator {

    if (C10_UNLIKELY(
            set_fraction &&
            CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
            AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
      // Track block reuse interval only when garbage collection is enabled.
      ++pool.get_free_blocks_call_count;
    }
@@ -3044,13 +3057,13 @@ class DeviceCachingAllocator {
    }

    // Do not return an oversized block for a large request
    if ((p.size() < CUDAAllocatorConfig::max_split_size()) &&
        ((*it)->size >= CUDAAllocatorConfig::max_split_size()))
    if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) &&
        ((*it)->size >= AcceleratorAllocatorConfig::max_split_size()))
      return false;
    // Allow oversized block size to be rounded up but within a limit
    if ((p.size() >= CUDAAllocatorConfig::max_split_size()) &&
    if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) &&
        ((*it)->size >=
         p.size() + CUDAAllocatorConfig::max_non_split_rounding_size()))
         p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size()))
      return false;
    p.block = *it;
    pool.blocks.erase(it);
@@ -3073,7 +3086,7 @@ class DeviceCachingAllocator {
    // therefore should be of less overheads.

    size_t gc_threshold = static_cast<size_t>(
        CUDAAllocatorConfig::garbage_collection_threshold() *
        AcceleratorAllocatorConfig::garbage_collection_threshold() *
        static_cast<double>(allowed_memory_maximum));
    // No need to trigger GC yet
    if (total_allocated_memory <= gc_threshold) {
@@ -3221,7 +3234,7 @@ class DeviceCachingAllocator {
      stats.segment[stat_type].increase(1);
      stats.reserved_bytes[stat_type].increase(size);
    });
    if (size >= CUDAAllocatorConfig::max_split_size())
    if (size >= AcceleratorAllocatorConfig::max_split_size())
      stats.oversize_segments.increase(1);
    auto reserved_bytes_gauge =
        STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
@@ -3250,7 +3263,7 @@ class DeviceCachingAllocator {
  bool release_available_cached_blocks(
      const AllocParams& p,
      const std::shared_ptr<GatheredContext>& context) {
    if (CUDAAllocatorConfig::max_split_size() ==
    if (AcceleratorAllocatorConfig::max_split_size() ==
        std::numeric_limits<size_t>::max())
      return false;
    BlockPool& pool = *p.pool;
@@ -3258,8 +3271,8 @@ class DeviceCachingAllocator {
    // because of std::unique_ptr, block cannot be trivially copied
    // Use constructor for search key.
    Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
    key.size = (key.size < CUDAAllocatorConfig::max_split_size())
        ? CUDAAllocatorConfig::max_split_size()
    key.size = (key.size < AcceleratorAllocatorConfig::max_split_size())
        ? AcceleratorAllocatorConfig::max_split_size()
        : key.size;
    auto it = pool.blocks.lower_bound(&key);
    if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
@@ -3272,7 +3285,7 @@ class DeviceCachingAllocator {
      --it; // Back up one item.  Now on the largest block for the correct
            // stream
      while ((totalReleased < key.size) &&
             ((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
             ((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) &&
             ((*it)->stream == p.stream())) {
        auto cur = it;
        bool is_first = cur == pool.blocks.begin();
@@ -3397,7 +3410,7 @@ class DeviceCachingAllocator {
        stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
            .current);

    if (block->size >= CUDAAllocatorConfig::max_split_size())
    if (block->size >= AcceleratorAllocatorConfig::max_split_size())
      stats.oversize_segments.decrease(1);
    pool->blocks.erase(block);
    delete block;
@@ -3682,7 +3695,8 @@ class DeviceCachingAllocator {
        mempool_id,
        getApproximateTime(),
        record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr,
        compile_string);
        compile_string,
        user_metadata);

    // Callbacks should not include any Pytorch call
    for (const auto& cb : trace_trackers_) {
@@ -3737,6 +3751,7 @@ static void uncached_delete(void* ptr) {

static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
@@ -3934,6 +3949,18 @@ class NativeCachingAllocator : public CUDAAllocator {
    device_allocator[device]->popCompileContext();
  }

  void setUserMetadata(const std::string& metadata) override {
    c10::DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    device_allocator[device]->setUserMetadata(metadata);
  }

  std::string getUserMetadata() override {
    c10::DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    return device_allocator[device]->getUserMetadata();
  }

  bool isHistoryEnabled() override {
    c10::DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
@@ -4034,8 +4061,8 @@ class NativeCachingAllocator : public CUDAAllocator {

    auto& md = result.config_metadata;
    md.garbage_collection_threshold =
        CUDAAllocatorConfig::garbage_collection_threshold();
    md.max_split_size = CUDAAllocatorConfig::max_split_size();
        AcceleratorAllocatorConfig::garbage_collection_threshold();
    md.max_split_size = AcceleratorAllocatorConfig::max_split_size();
    md.pinned_num_register_threads =
        CUDAAllocatorConfig::pinned_num_register_threads();
    md.expandable_segments = CUDAAllocatorConfig::expandable_segments();
@@ -4043,11 +4070,12 @@ class NativeCachingAllocator : public CUDAAllocator {
        CUDAAllocatorConfig::release_lock_on_cudamalloc();
    md.pinned_use_host_register =
        CUDAAllocatorConfig::pinned_use_cuda_host_register();
    md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
    md.last_allocator_settings =
        AcceleratorAllocatorConfig::last_allocator_settings();
    md.graph_capture_record_stream_reuse =
        CUDAAllocatorConfig::graph_capture_record_stream_reuse();
    md.roundup_power2_divisions =
        CUDAAllocatorConfig::roundup_power2_divisions();
        AcceleratorAllocatorConfig::roundup_power2_divisions();

    return result;
  }
@@ -4425,11 +4453,12 @@ CUDAAllocator* allocator();
} // namespace CudaMallocAsync

struct BackendStaticInitializer {
  // Parses env for backend at load time, duplicating some logic from
  // CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at
  // runtime). Defers verbose exceptions and error checks, including Cuda
  // version checks, to CUDAAllocatorConfig's runtime doublecheck. If this
  // works, maybe we should move all of CUDAAllocatorConfig here?
  // Parses the environment configuration for CUDA/ROCm allocator backend at
  // load time. This duplicates some logic from CUDAAllocatorConfig to ensure
  // lazy initialization without triggering global static constructors. The
  // function looks for the key "backend" and returns the appropriate allocator
  // instance based on its value. If no valid configuration is found, it falls
  // back to the default Native allocator.
  CUDAAllocator* parseEnvForBackend() {
    auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
#ifdef USE_ROCM
@@ -4438,34 +4467,35 @@ struct BackendStaticInitializer {
      val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
    }
#endif
    if (!val.has_value()) {
      val = c10::utils::get_env("PYTORCH_ALLOC_CONF");
    }
    if (val.has_value()) {
      const std::string& config = val.value();

      std::regex exp("[\\s,]+");
      std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
      std::sregex_token_iterator end;
      std::vector<std::string> options(it, end);

      for (auto option : options) {
        std::regex exp2("[:]+");
        std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
        std::sregex_token_iterator end2;
        std::vector<std::string> kv(it2, end2);
        if (kv.size() >= 2) {
          if (kv[0] == "backend") {
      c10::CachingAllocator::ConfigTokenizer tokenizer(val.value());
      for (size_t i = 0; i < tokenizer.size(); i++) {
        const auto& key = tokenizer[i];
        if (key == "backend") {
          tokenizer.checkToken(++i, ":");
          i++; // Move to the value after the colon
          if (tokenizer[i] == "cudaMallocAsync"
#ifdef USE_ROCM
            // convenience for ROCm users to allow either CUDA or HIP env var
            if (kv[1] == "cudaMallocAsync" || kv[1] == "hipMallocAsync")
#else
            if (kv[1] == "cudaMallocAsync")
              // convenience for ROCm users to allow either CUDA or HIP env var
              || tokenizer[i] == "hipMallocAsync"
#endif
              return CudaMallocAsync::allocator();
            if (kv[1] == "native")
              return &Native::allocator;
          ) {
            return CudaMallocAsync::allocator();
          }
          break;
        } else {
          // Skip the key and its value
          i = tokenizer.skipKey(i);
        }
        if (i + 1 < tokenizer.size()) {
          tokenizer.checkToken(++i, ",");
        }
      }
    }
    // Default fallback allocator.
    return &Native::allocator;
  }

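The rewritten parseEnvForBackend above swaps the ad-hoc std::regex splitting for the shared ConfigTokenizer. A hedged sketch of the same walk over a config string, using only the tokenizer calls that appear in the hunk; the free function and its input are illustrative, and it assumes ConfigTokenizer is reachable via c10/core/AllocatorConfig.h:

#include <c10/core/AllocatorConfig.h>
#include <string>

bool wants_malloc_async(const std::string& conf) {
  c10::CachingAllocator::ConfigTokenizer tokenizer(conf);
  for (size_t i = 0; i < tokenizer.size(); i++) {
    if (tokenizer[i] == "backend") {
      tokenizer.checkToken(++i, ":"); // key and value are ':'-separated
      return tokenizer[++i] == "cudaMallocAsync";
    }
    i = tokenizer.skipKey(i); // skip over a "key:value" pair we ignore
    if (i + 1 < tokenizer.size()) {
      tokenizer.checkToken(++i, ","); // options are ','-separated
    }
  }
  return false;
}

// e.g. PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128,backend:cudaMallocAsync"
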
@@ -118,7 +118,8 @@ struct TraceEntry {
      MempoolId_t mempool,
      approx_time_t time,
      std::shared_ptr<GatheredContext> context = nullptr,
      std::string compile_context = "")
      std::string compile_context = "",
      std::string user_metadata = "")
      : action_(action),
        device_(device),
        addr_(addr),
@@ -126,7 +127,8 @@ struct TraceEntry {
        stream_(stream),
        size_(size),
        mempool_(std::move(mempool)),
        compile_context_(std::move(compile_context)) {
        compile_context_(std::move(compile_context)),
        user_metadata_(std::move(user_metadata)) {
    time_.approx_t_ = time;
  }
  Action action_;
@@ -138,6 +140,7 @@ struct TraceEntry {
  MempoolId_t mempool_;
  trace_time_ time_{};
  std::string compile_context_;
  std::string user_metadata_;
};

// Calls made by record_function will save annotations
@@ -297,6 +300,10 @@ class CUDAAllocator : public DeviceAllocator {
      const std::vector<std::pair<std::string, std::string>>& /*md*/) {}
  virtual void pushCompileContext(std::string& md) {}
  virtual void popCompileContext() {}
  virtual void setUserMetadata(const std::string& metadata) {}
  virtual std::string getUserMetadata() {
    return "";
  }
  virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;

  // Attached AllocatorTraceTracker callbacks will be called while the
@@ -536,6 +543,14 @@ inline void enablePeerAccess(
  get()->enablePeerAccess(dev, dev_to_access);
}

inline void setUserMetadata(const std::string& metadata) {
  get()->setUserMetadata(metadata);
}

inline std::string getUserMetadata() {
  return get()->getUserMetadata();
}

} // namespace c10::cuda::CUDACachingAllocator

namespace c10::cuda {

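Tying the user-metadata plumbing above together: setUserMetadata() stores a thread_local string in the per-device allocator, and each subsequent TraceEntry captures it as user_metadata_. A hedged usage sketch (the tag value and the surrounding function are illustrative, and it assumes allocator history recording is enabled elsewhere):

#include <c10/cuda/CUDACachingAllocator.h>

void run_tagged_region() {
  namespace cca = c10::cuda::CUDACachingAllocator;
  cca::setUserMetadata("attention_forward");
  // ... allocations made on this thread now carry the tag in their
  // trace entries ...
  cca::setUserMetadata(""); // clear the tag on the way out
}
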
@@ -1,8 +1,6 @@
#include <c10/cuda/CUDADeviceAssertionHost.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <cuda_runtime.h>

@@ -4,7 +4,6 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>

#include <unordered_set>
#include <vector>

@@ -1,7 +1,6 @@
#include <c10/cuda/CUDAMiscFunctions.h>
#include <c10/util/env.h>
#include <cuda_runtime.h>
#include <cstring>
#include <string>

namespace c10::cuda {

@@ -1,7 +1,6 @@
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/driver_api.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <cuda_runtime.h>

@@ -328,5 +328,21 @@ struct pair {
  T2 second;
};

#define INSTANTIATE_FOR_ALL_TYPES(MACRO) \
  MACRO(float);                          \
  MACRO(half);                           \
  MACRO(bfloat);                         \
  MACRO(float2);                         \
  MACRO(long);                           \
  MACRO(char);                           \
  MACRO(uchar);                          \
  MACRO(short);                          \
  MACRO(int);

#define INSTANTIATE_FOR_FLOAT_TYPES(MACRO) \
  MACRO(float);                            \
  MACRO(half);                             \
  MACRO(bfloat);

} // namespace metal
} // namespace c10

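The INSTANTIATE_FOR_ALL_TYPES / INSTANTIATE_FOR_FLOAT_TYPES additions above are X-macros: each applies a caller-supplied macro once per supported Metal scalar type. A hedged sketch of the intended consumption pattern; the kernel name is made up for illustration:

// Expands to one explicit instantiation per listed type, e.g.
//   template void fill_kernel<float>(float*, float, size_t);
#define INSTANTIATE_FILL_KERNEL(T) \
  template void fill_kernel<T>(T* out, T value, size_t n)
INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_FILL_KERNEL);
#undef INSTANTIATE_FILL_KERNEL
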
@@ -67,8 +67,8 @@ TEST(AllocatorConfigTest, allocator_config_test) {
  EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 2);
  EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4);
  EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 2);
  // EXPECT_EQ(
  //     AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4);
  EXPECT_EQ(
      AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4);
  EXPECT_EQ(
      AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 1);
  EXPECT_EQ(
@@ -101,8 +101,8 @@ TEST(AllocatorConfigTest, allocator_config_test) {
  EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 1);
  EXPECT_EQ(
      AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 0);
  // EXPECT_EQ(
  //     AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8);
  EXPECT_EQ(
      AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8);
  EXPECT_EQ(
      AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 2);

@@ -1,7 +1,6 @@
#include <c10/util/ApproximateClock.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <fmt/format.h>

namespace c10 {

@@ -18,6 +18,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

#include <array>
#include <cstddef>
@@ -40,200 +41,106 @@ namespace c10 {
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
///
/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct
/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of
/// the underlying constexpr calls, we rely on apparent-type dispatch for
/// inheritance. This should be fine because their memory format is the same,
/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods.
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
class ArrayRef final {
class ArrayRef final : public HeaderOnlyArrayRef<T> {
 public:
  using iterator = const T*;
  using const_iterator = const T*;
  using size_type = size_t;
  using value_type = T;

  using reverse_iterator = std::reverse_iterator<iterator>;

 private:
  /// The start of the array, in an external buffer.
  const T* Data;

  /// The number of elements.
  size_type Length;

  void debugCheckNullptrInvariant() {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        Data != nullptr || Length == 0,
        "created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
  }

 public:
  /// @name Constructors
  /// @name Constructors, all inherited from HeaderOnlyArrayRef except for
  /// SmallVector.
  /// @{

  /// Construct an empty ArrayRef.
  /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
  using HeaderOnlyArrayRef<T>::HeaderOnlyArrayRef;

  /// Construct an ArrayRef from a single element.
  // TODO Make this explicit
  constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}

  /// Construct an ArrayRef from a pointer and length.
  constexpr ArrayRef(const T* data, size_t length)
      : Data(data), Length(length) {
    debugCheckNullptrInvariant();
  }

  /// Construct an ArrayRef from a range.
  constexpr ArrayRef(const T* begin, const T* end)
      : Data(begin), Length(end - begin) {
    debugCheckNullptrInvariant();
  }
  /// Construct an ArrayRef from a std::vector.
  /// This constructor is identical to the one in HeaderOnlyArrayRef, but we
  /// include it to help with Class Template Argument Deduction (CTAD).
  /// Without it, CTAD can fail sometimes due to the indirect constructor
  /// inheritance. So we explicitly include this constructor.
  template <typename A>
  /* implicit */ ArrayRef(const std::vector<T, A>& Vec)
      : HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}

  /// Construct an ArrayRef from a SmallVector. This is templated in order to
  /// avoid instantiating SmallVectorTemplateCommon<T> whenever we
  /// copy-construct an ArrayRef.
  /// NOTE: this is the only constructor that is not inherited from
  /// HeaderOnlyArrayRef.
  template <typename U>
  /* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
      : Data(Vec.data()), Length(Vec.size()) {
    debugCheckNullptrInvariant();
  }

  template <
      typename Container,
      typename U = decltype(std::declval<Container>().data()),
      typename = std::enable_if_t<
          (std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
  /* implicit */ ArrayRef(const Container& container)
      : Data(container.data()), Length(container.size()) {
    debugCheckNullptrInvariant();
  }

  /// Construct an ArrayRef from a std::vector.
  // The enable_if stuff here makes sure that this isn't used for
  // std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
  // bitfield.
  template <typename A>
  /* implicit */ ArrayRef(const std::vector<T, A>& Vec)
      : Data(Vec.data()), Length(Vec.size()) {
    static_assert(
        !std::is_same_v<T, bool>,
        "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
  }

  /// Construct an ArrayRef from a std::array
  template <size_t N>
  /* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
      : Data(Arr.data()), Length(N) {}

  /// Construct an ArrayRef from a C array.
  template <size_t N>
  // NOLINTNEXTLINE(*c-arrays*)
  /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}

  /// Construct an ArrayRef from a std::initializer_list.
  /* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
      : Data(
            std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
                                             : std::begin(Vec)),
        Length(Vec.size()) {}
      : HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}

  /// @}
  /// @name Simple Operations
  /// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef
  /// @{

  constexpr iterator begin() const {
    return Data;
  }
  constexpr iterator end() const {
    return Data + Length;
  }

  // These are actually the same as iterator, since ArrayRef only
  // gives you const iterators.
  constexpr const_iterator cbegin() const {
    return Data;
  }
  constexpr const_iterator cend() const {
    return Data + Length;
  }

  constexpr reverse_iterator rbegin() const {
    return reverse_iterator(end());
  }
  constexpr reverse_iterator rend() const {
    return reverse_iterator(begin());
  }

  /// Check if all elements in the array satisfy the given expression
  constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
    return std::all_of(cbegin(), cend(), pred);
  }

  /// empty - Check if the array is empty.
  constexpr bool empty() const {
    return Length == 0;
  }

  constexpr const T* data() const {
    return Data;
  }

  /// size - Get the array size.
  constexpr size_t size() const {
    return Length;
  }

  /// front - Get the first element.
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr const T& front() const {
    TORCH_CHECK(
        !empty(), "ArrayRef: attempted to access front() of empty list");
    return Data[0];
        !this->empty(), "ArrayRef: attempted to access front() of empty list");
    return this->Data[0];
  }

  /// back - Get the last element.
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr const T& back() const {
    TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
    return Data[Length - 1];
  }

  /// equals - Check for element-wise equality.
  constexpr bool equals(ArrayRef RHS) const {
    return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
    TORCH_CHECK(
        !this->empty(), "ArrayRef: attempted to access back() of empty list");
    return this->Data[this->Length - 1];
  }

  /// slice(n, m) - Take M elements of the array starting at element N
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr ArrayRef<T> slice(size_t N, size_t M) const {
    TORCH_CHECK(
        N + M <= size(),
        N + M <= this->size(),
        "ArrayRef: invalid slice, N = ",
        N,
        "; M = ",
        M,
        "; size = ",
        size());
    return ArrayRef<T>(data() + N, M);
        this->size());
    return ArrayRef<T>(this->data() + N, M);
  }

  /// slice(n) - Chop off the first N elements of the array.
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr ArrayRef<T> slice(size_t N) const {
    TORCH_CHECK(
        N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
    return slice(N, size() - N);
        N <= this->size(),
        "ArrayRef: invalid slice, N = ",
        N,
        "; size = ",
        this->size());
    return slice(N, this->size() - N); // should this slice be this->slice?
  }

  /// @}
  /// @name Operator Overloads
  /// @{
  constexpr const T& operator[](size_t Index) const {
    return Data[Index];
  }

  /// Vector compatibility
  /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  /// STD_TORCH_CHECK
  constexpr const T& at(size_t Index) const {
    TORCH_CHECK(
        Index < Length,
        Index < this->Length,
        "ArrayRef: invalid index Index = ",
        Index,
        "; Length = ",
        Length);
    return Data[Index];
        this->Length);
    return this->Data[Index];
  }

  /// Disallow accidental assignment from a temporary.
@@ -253,13 +160,6 @@ class ArrayRef final {
  std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
      std::initializer_list<U>) = delete;

  /// @}
  /// @name Expensive Operations
  /// @{
  std::vector<T> vec() const {
    return std::vector<T>(Data, Data + Length);
  }

  /// @}
};

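The NOTE in the hunk above is the heart of the refactor: ArrayRef now derives non-virtually from HeaderOnlyArrayRef, adds no data members, and shadows (rather than overrides) a few methods so the derived class can report errors via TORCH_CHECK while the header-only base uses STD_TORCH_CHECK. A minimal stand-alone sketch of that "apparent-type dispatch" pattern, with made-up class names:

#include <cstddef>
#include <stdexcept>

template <typename T>
class BaseRef {
 public:
  constexpr BaseRef(const T* data, size_t len) : Data(data), Length(len) {}
  constexpr const T& front() const {
    return Data[0]; // base flavor: no friendly error reporting
  }
 protected:
  const T* Data;
  size_t Length;
};

template <typename T>
class CheckedRef : public BaseRef<T> { // no new members -> identical layout
 public:
  using BaseRef<T>::BaseRef; // inherit constructors, as ArrayRef does
  constexpr const T& front() const { // shadows the base; no virtual dispatch
    if (this->Length == 0)
      throw std::out_of_range("CheckedRef: front() of empty list");
    return this->Data[0];
  }
};

Because dispatch goes by the apparent (static) type, a CheckedRef passed around as a BaseRef silently falls back to the base behavior -- the same reason the NOTE recommends preferring ArrayRef where possible.
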
@@ -1,7 +1,5 @@
#include <c10/util/complex.h>

#include <cmath>

// Note [ Complex Square root in libc++]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In libc++ complex square root is computed using polar form

@@ -11,7 +11,6 @@
#include <unistd.h>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdio>

@@ -74,7 +74,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
        )

    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")
    for i in range(0, uf):
    for i in range(uf):
        j = 8 * i
        code.append("      __m256 vop" + str(j) + " = _mm256_setzero_ps();")

@@ -158,7 +158,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
        "&input[idx_pref_T0 * fused_block_size];"
    )

    for i in range(0, uf):
    for i in range(uf):
        j = 8 * i
        cachelinesize = 64
        byteoffset = sizeof[InType] * j
@@ -170,7 +170,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
        code.append("      if (!normalize_by_lengths || length == 0) {")
    else:
        code.append("      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
    for i in range(0, uf):
    for i in range(uf):
        j = 8 * i
        code.append("        _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
    code.append("      } else {")
@@ -181,7 +181,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
        code.append(
            "        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);"
        )
    for i in range(0, uf):
    for i in range(uf):
        j = 8 * i
        code.append(
            "        _mm256_storeu_ps(&op["

@@ -159,8 +159,6 @@ ignore = [
    "EXE001",
    "F405",
    "FURB122", # writelines
    # these ignores are from flake8-logging-format; please fix!
    "G101",
    # these ignores are from ruff NPY; please fix!
    "NPY002",
    # these ignores are from ruff PERF; please fix!
@@ -204,14 +202,10 @@ select = [
    "NPY",
    "PERF",
    "PGH004",
    "PIE790",
    "PIE794",
    "PIE800",
    "PIE804",
    "PIE807",
    "PIE810",
    "PIE",
    "PLC0131", # type bivariance
    "PLC0132", # type param mismatch
    "PLC1802", # len({expression}) used as condition without comparison
    "PLC0205", # string as __slots__
    "PLC3002", # unnecessary-direct-lambda-call
    "PLE",

pyrefly.toml (20 lines changed):
@@ -5,6 +5,7 @@ python-version = "3.12"
project-includes = [
    "torch",
    "caffe2",
    "tools",
    "test/test_bundled_images.py",
    "test/test_bundled_inputs.py",
    "test/test_complex.py",
@@ -22,12 +23,13 @@ project-includes = [
project-excludes = [
  # ==== below will be enabled directory by directory ====
  # ==== to test Pyrefly on a specific directory, simply comment it out ====
  "torch/_inductor/runtime",
  "torch/_inductor/codegen/triton.py",
  "torch/_inductor/runtime/triton_helpers.py",
  "torch/_inductor/runtime/triton_heuristics.py",
  "torch/_inductor/runtime/halide_helpers.py",
  "tools/linter/adapters/test_device_bias_linter.py",
  "tools/code_analyzer/gen_operators_yaml.py",
  # formatting issues, will turn on after adjusting where suppressions can be
  # in import statements
  "tools/flight_recorder/components/types.py",
  "torch/linalg/__init__.py",
  "torch/package/importer.py",
  "torch/package/_package_pickler.py",
@@ -42,17 +44,6 @@ project-excludes = [
  "torch/distributed/elastic/metrics/__init__.py",
  "torch/_inductor/fx_passes/bucketing.py",
  # ====
  "benchmarks/instruction_counts/main.py",
  "benchmarks/instruction_counts/definitions/setup.py",
  "benchmarks/instruction_counts/applications/ci.py",
  "benchmarks/instruction_counts/core/api.py",
  "benchmarks/instruction_counts/core/expand.py",
  "benchmarks/instruction_counts/core/types.py",
  "benchmarks/instruction_counts/core/utils.py",
  "benchmarks/instruction_counts/definitions/standard.py",
  "benchmarks/instruction_counts/definitions/setup.py",
  "benchmarks/instruction_counts/execution/runner.py",
  "benchmarks/instruction_counts/execution/work.py",
  "torch/include/**",
  "torch/csrc/**",
  "torch/distributed/elastic/agent/server/api.py",
@@ -139,3 +130,4 @@ errors.bad-param-name-override = false
errors.implicit-import = false
permissive-ignores = true
replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
search-path = ["tools/experimental"]

@@ -190,7 +190,7 @@ class TestActivationSparsifier(TestCase):
                if features is None:
                    assert torch.all(mask * input_data == output)
                else:
                    for feature_idx in range(0, len(features)):
                    for feature_idx in range(len(features)):
                        feature = torch.Tensor(
                            [features[feature_idx]], device=input_data.device
                        ).long()
@@ -378,7 +378,7 @@ class TestActivationSparsifier(TestCase):
        # some dummy data
        data_list = []
        num_data_points = 5
        for _ in range(0, num_data_points):
        for _ in range(num_data_points):
            rand_data = torch.randn(16, 1, 28, 28)
            activation_sparsifier.model(rand_data)
            data_list.append(rand_data)

@@ -143,7 +143,7 @@ class TestBaseDataScheduler(TestCase):

        # checking step count
        step_cnt = 5
        for _ in range(0, step_cnt):
        for _ in range(step_cnt):
            sparsifier.step()
            scheduler.step()

@@ -123,7 +123,7 @@ class _BaseDataSparsiferTestCase(TestCase):

        step_count = 3

        for _ in range(0, step_count):
        for _ in range(step_count):
            sparsifier.step()
        for some_data in all_data:
            name, data, _ = self._get_name_data_config(some_data)

@@ -472,8 +472,8 @@ class TestNearlyDiagonalSparsifier(TestCase):
        else:
            height, width = mask.shape
            dist_to_diagonal = nearliness // 2
            for row in range(0, height):
                for col in range(0, width):
            for row in range(height):
                for col in range(width):
                    if abs(row - col) <= dist_to_diagonal:
                        assert mask[row, col] == 1
                    else:

@@ -7,6 +7,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp

test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp (new file, 52 lines):
@@ -0,0 +1,52 @@
#include <gtest/gtest.h>

#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

#include <vector>

using torch::headeronly::HeaderOnlyArrayRef;

TEST(TestHeaderOnlyArrayRef, TestEmpty) {
  HeaderOnlyArrayRef<float> arr;
  ASSERT_TRUE(arr.empty());
}

TEST(TestHeaderOnlyArrayRef, TestSingleton) {
  float val = 5.0f;
  HeaderOnlyArrayRef<float> arr(val);
  ASSERT_FALSE(arr.empty());
  EXPECT_EQ(arr.size(), 1);
  EXPECT_EQ(arr[0], val);
}

TEST(TestHeaderOnlyArrayRef, TestAPIs) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr(vec);
  ASSERT_FALSE(arr.empty());
  EXPECT_EQ(arr.size(), 7);
  for (size_t i = 0; i < arr.size(); i++) {
    EXPECT_EQ(arr[i], i + 1);
    EXPECT_EQ(arr.at(i), i + 1);
  }
  EXPECT_EQ(arr.front(), 1);
  EXPECT_EQ(arr.back(), 7);
  ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3)));
}

TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr({1, 2, 3, 4, 5, 6, 7});
  auto res_vec = arr.vec();
  for (size_t i = 0; i < vec.size(); i++) {
    EXPECT_EQ(vec[i], res_vec[i]);
  }
}

TEST(TestHeaderOnlyArrayRef, TestFromRange) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr(vec.data() + 3, vec.data() + 7);
  auto res_vec = arr.vec();
  for (size_t i = 0; i < res_vec.size(); i++) {
    EXPECT_EQ(vec[i + 3], res_vec[i]);
  }
}

@@ -311,10 +311,9 @@ void boxed_fill_infinity(
}

Tensor my_pad(Tensor t) {
  std::vector<int64_t> padding = {1, 2, 2, 1};
  std::string mode = "constant";
  double value = 0.0;
  return pad(t, padding, mode, value);
  return pad(t, {1, 2, 2, 1}, mode, value);
}

void boxed_my_pad(
@@ -342,6 +341,9 @@ void boxed_my_narrow(
}

Tensor my_new_empty_dtype_variant(Tensor t) {
  // Still using a std::vector below even though people can just pass in an
  // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
  // directly.
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
  return new_empty(t, sizes, dtype);
@@ -353,9 +355,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
}

Tensor my_new_zeros_dtype_variant(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(at::ScalarType::Float);
  return new_zeros(t, sizes, dtype);
  return new_zeros(t, {2, 5}, dtype);
}

void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@@ -429,8 +430,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
}

Tensor my_amax_vec(Tensor t) {
  std::vector<int64_t> v = {0,1};
  return amax(t, v, false);
  return amax(t, {0,1}, false);
}

void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {

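The simplifications above all exploit the same refactor: these stable-ABI helpers now take a HeaderOnlyArrayRef<int64_t>, so a braced initializer list converts implicitly and the intermediate std::vector can be dropped. A hedged sketch of the pattern, mirroring my_amax_vec above (the function name is illustrative):

Tensor reduce_over_first_two_dims(Tensor t) {
  // {0, 1} binds directly to the HeaderOnlyArrayRef<int64_t> parameter;
  // no std::vector<int64_t> temporary is needed.
  return amax(t, {0, 1}, /*keepdim=*/false);
}
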
@ -164,6 +164,9 @@ class TestIntTuple(TestCase):
 | 
			
		||||
            crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8
 | 
			
		||||
        )  # 4 -> (1,0,0) -> 1*8 = 8
 | 
			
		||||
 | 
			
		||||
        # Test with zero-length shape and strides
 | 
			
		||||
        self.assertEqual(crd2idx(0, (), ()), 0)  # 0 -> () -> sum([]) = 0
 | 
			
		||||
 | 
			
		||||
    def test_idx2crd_basic(self):
 | 
			
		||||
        # Test basic int/int case
 | 
			
		||||
        self.assertEqual(idx2crd(2, 5, 1), 2)
 | 
			
		||||
 | 
			
		||||
@ -79,7 +79,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
 | 
			
		||||
            dist.init_process_group(
 | 
			
		||||
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
 | 
			
		||||
            )
 | 
			
		||||
            group = list(range(0, self.world_size))
 | 
			
		||||
            group = list(range(self.world_size))
 | 
			
		||||
            group_id = dist.group.WORLD
 | 
			
		||||
            self._test_all_gather(
 | 
			
		||||
                group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.FP16
 | 
			
		||||
@ -94,7 +94,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
 | 
			
		||||
            dist.init_process_group(
 | 
			
		||||
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
 | 
			
		||||
            )
 | 
			
		||||
            group = list(range(0, self.world_size))
 | 
			
		||||
            group = list(range(self.world_size))
 | 
			
		||||
            group_id = dist.group.WORLD
 | 
			
		||||
            self._test_all_gather(
 | 
			
		||||
                group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.BFP16
 | 
			
		||||
@ -111,7 +111,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
 | 
			
		||||
            dist.init_process_group(
 | 
			
		||||
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
 | 
			
		||||
            )
 | 
			
		||||
            group = list(range(0, self.world_size))
 | 
			
		||||
            group = list(range(self.world_size))
 | 
			
		||||
            group_id = dist.new_group(range(self.world_size))
 | 
			
		||||
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
 | 
			
		||||
            self._test_all_to_all(
 | 
			
		||||
@ -135,7 +135,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
 | 
			
		||||
            dist.init_process_group(
 | 
			
		||||
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
 | 
			
		||||
            )
 | 
			
		||||
            group = list(range(0, self.world_size))
 | 
			
		||||
            group = list(range(self.world_size))
 | 
			
		||||
            group_id = dist.new_group(range(self.world_size))
 | 
			
		||||
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
 | 
			
		||||
            self._test_all_to_all(
 | 
			
		||||
@@ -158,7 +158,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
             dist.init_process_group(
                 store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
             )
-            group = list(range(0, self.world_size))
+            group = list(range(self.world_size))
             group_id = dist.new_group(range(self.world_size))
             rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
             self._test_all_to_all_single(
@@ -181,7 +181,7 @@ if BACKEND == "gloo" or BACKEND == "nccl":
             dist.init_process_group(
                 store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
             )
-            group = list(range(0, self.world_size))
+            group = list(range(self.world_size))
             group_id = dist.new_group(range(self.world_size))
             rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
             self._test_all_to_all_single(
@@ -66,7 +66,7 @@ if TEST_WITH_DEV_DBG_ASAN:
 def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
     shards_metadata = []
     local_shards = []
-    for idx in range(0, world_size * shards_per_rank):
+    for idx in range(world_size * shards_per_rank):
         shard_rank = idx // shards_per_rank
         shard_md = ShardMetadata(
             shard_offsets=[idx * shard_size],
@@ -45,7 +45,7 @@ if TEST_WITH_DEV_DBG_ASAN:
 def create_sharded_tensor(rank, world_size, shards_per_rank):
     shards_metadata = []
     local_shards = []
-    for idx in range(0, world_size * shards_per_rank):
+    for idx in range(world_size * shards_per_rank):
         shard_rank = idx // shards_per_rank
         shard_md = ShardMetadata(
             shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu"
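
Both create_sharded_tensor helpers place shards the same way: shard idx belongs to rank idx // shards_per_rank at offset idx * shard_size. A standalone sketch of that placement arithmetic (hypothetical function name, no torch dependencies):

    def shard_layout(world_size, shards_per_rank, shard_size=8):
        # (owning rank, shard_offsets, shard_sizes) for each shard index
        return [
            (idx // shards_per_rank, [idx * shard_size], [shard_size])
            for idx in range(world_size * shards_per_rank)
        ]

    # Two ranks, two shards per rank:
    assert shard_layout(2, 2) == [
        (0, [0], [8]), (0, [8], [8]), (1, [16], [8]), (1, [24], [8])
    ]
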
@@ -633,7 +633,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
         worker_group = agent.get_worker_group()
 
         num_restarts = 3
-        for _ in range(0, num_restarts):
+        for _ in range(num_restarts):
             agent._restart_workers(worker_group)
             self.assertEqual(WorkerState.HEALTHY, worker_group.state)
 
@@ -146,7 +146,7 @@ def echo_large(size: int) -> dict[int, str]:
     returns a large output ({0: "test0", 1: "test1", ..., (size-1): f"test{size-1}"})
     """
     out = {}
-    for idx in range(0, size):
+    for idx in range(size):
         out[idx] = f"test{idx}"
     return out
 
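
Incidentally, the loop body reduces to a dict comprehension — an equivalent rewrite shown for reference, not the change made in this hunk:

    def echo_large(size: int) -> dict[int, str]:
        """returns {0: "test0", 1: "test1", ..., size-1: f"test{size-1}"}"""
        return {idx: f"test{idx}" for idx in range(size)}

    assert echo_large(3) == {0: "test0", 1: "test1", 2: "test2"}
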
@@ -191,7 +191,7 @@ if not (IS_WINDOWS or IS_MACOS or IS_ARM64):
         """
         client = timer.FileTimerClient(file_path)
         sem.release()
-        for _ in range(0, n):
+        for _ in range(n):
             client.acquire("test_scope", 0)
             time.sleep(interval)
 