Mirror of https://github.com/pytorch/pytorch.git

Compare commits: annotate_b...gh/aorenst (8 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | d825116f62 |  |
|  | c2d1314604 |  |
|  | 0723299640 |  |
|  | 287570c1aa |  |
|  | 803557ac40 |  |
|  | 9fabe06fa8 |  |
|  | 897f09de0b |  |
|  | f4fcc05fbd |  |
@@ -143,7 +143,7 @@ def sample_vllm_test_library():
                "pytest -v -s compile/test_decorator.py",
            ],
        },
        "vllm_language_model_test_extended_generation_28_failure_test": {
        "vllm_languagde_model_test_extended_generation_28_failure_test": {
            "title": "Language Models Test (Extended Generation) 2.8 release failure",
            "id": "vllm_languagde_model_test_extended_generation_28_failure_test",
            "package_install": [
@@ -63,7 +63,7 @@ class VllmBuildParameters:
    # DOCKERFILE_PATH: path to Dockerfile used when use_local_dockerfile is True"
    use_local_dockerfile: bool = env_bool_field("USE_LOCAL_DOCKERFILE", True)
    dockerfile_path: Path = env_path_field(
        "DOCKERFILE_PATH", ".github/ci_configs/vllm/Dockerfile"
        "DOCKERFILE_PATH", ".github/ci_configs/vllm/Dockerfile.tmp_vllm"
    )

    # the cleaning script to remove torch dependencies from pip
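Note: the `env_bool_field`/`env_path_field` helpers referenced in this hunk are defined elsewhere in the repository and are not part of this diff. A minimal sketch of what such dataclass-field helpers might look like, purely for orientation (names and behaviour assumed, not taken from the source):

```python
import os
from dataclasses import dataclass, field
from pathlib import Path


def env_bool_field(var: str, default: bool):
    """Hypothetical helper: dataclass field backed by a boolean env var."""
    def _read() -> bool:
        raw = os.environ.get(var)
        return default if raw is None else raw.strip().lower() in ("1", "true", "yes", "on")
    return field(default_factory=_read)


def env_path_field(var: str, default: str):
    """Hypothetical helper: dataclass field backed by a path env var."""
    return field(default_factory=lambda: Path(os.environ.get(var, default)))


@dataclass
class VllmBuildParametersSketch:
    # Mirrors the two fields visible in the hunk above, with assumed semantics.
    use_local_dockerfile: bool = env_bool_field("USE_LOCAL_DOCKERFILE", True)
    dockerfile_path: Path = env_path_field(
        "DOCKERFILE_PATH", ".github/ci_configs/vllm/Dockerfile.tmp_vllm"
    )
```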
@@ -187,22 +187,19 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
            export USE_CUFILE=0
        else
            DEPS_LIST+=(
                "/usr/local/cuda/lib64/libnvToolsExt.so.1"
                "/usr/local/cuda/lib64/libcublas.so.12"
                "/usr/local/cuda/lib64/libcublasLt.so.12"
                "/usr/local/cuda/lib64/libcudart.so.12"
                "/usr/local/cuda/lib64/libnvrtc.so.12"
                "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
            DEPS_SONAME+=(
                "libnvToolsExt.so.1"
                "libcublas.so.12"
                "libcublasLt.so.12"
                "libcudart.so.12"
                "libnvrtc.so.12"
                "libcupti.so.12")

            if [[ $CUDA_VERSION != 12.9* ]]; then
                DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
                DEPS_SONAME+=("libnvToolsExt.so.1")
            fi
        fi
    else
        echo "Using nvidia libs from pypi."
@@ -1615,7 +1615,6 @@ test_operator_benchmark() {
  TEST_REPORTS_DIR=$(pwd)/test/test-reports
  mkdir -p "$TEST_REPORTS_DIR"
  TEST_DIR=$(pwd)
  ARCH=$(uname -m)

  test_inductor_set_cpu_affinity
@@ -1630,7 +1629,7 @@ test_operator_benchmark() {
  pip_install pandas
  python check_perf_csv.py \
      --actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \
      --expected "${ARCH}_expected_ci_operator_benchmark_eager_float32_cpu.csv"
      --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
}

test_operator_microbenchmark() {
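The hunk above switches the expected-results file to an architecture-specific name (prefixed with `${ARCH}` from `uname -m`). The real `check_perf_csv.py` is not shown in this diff; the following is only a rough, hypothetical sketch of an actual-vs-expected CSV comparison of that shape, assuming a simple two-column, header-less CSV:

```python
import argparse
import csv


def load(path: str) -> dict[str, float]:
    # Assumption for this sketch: two columns (benchmark name, numeric value), no header row.
    with open(path, newline="") as f:
        return {row[0]: float(row[1]) for row in csv.reader(f) if row}


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--actual", required=True)
    parser.add_argument("--expected", required=True)
    parser.add_argument("--tolerance", type=float, default=0.1)
    args = parser.parse_args()

    actual, expected = load(args.actual), load(args.expected)
    for name, want in expected.items():
        got = actual.get(name)
        if got is None or abs(got - want) > args.tolerance * abs(want):
            raise SystemExit(f"{name}: expected ~{want}, got {got}")
    print("All benchmarks within tolerance.")


if __name__ == "__main__":
    main()
```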
.github/ISSUE_TEMPLATE/ci-sev.md (1 change, vendored)
@@ -8,7 +8,6 @@ assignees: ''
---

> NOTE: Remember to label this issue with "`ci: sev`"
>       If you want autorevert to be disabled, keep the ci: disable-autorevert label

<!-- Add the `merge blocking` label to this PR to prevent PRs from being merged while this issue is open -->

.github/ISSUE_TEMPLATE/disable-autorevert.md (4 changes, vendored)
@@ -1,7 +1,7 @@
---
name: "D❌\U0001F519 ISABLE AUTOREVERT"
name: DISABLE AUTOREVERT
about: Disables autorevert when open
title: "[DISABLE AUTOREVERT]"
title: "❌\U0001F519 [DISABLE AUTOREVERT]"
labels: 'ci: disable-autorevert'
assignees: ''

@@ -65,7 +65,7 @@ runs:
          cd .ci/lumen_cli
          python3 -m pip install -e .
        )
        MAX_JOBS="$(nproc --ignore=10)"
        MAX_JOBS="$(nproc --ignore=6)"
        export MAX_JOBS

        # Split the comma-separated list and build each target

.github/ci_commit_pins/audio.txt (2 changes, vendored)
@@ -1 +1 @@
1b013f5b5a87a1882eb143c26d79d091150d6a37
87ff22e49ed0e92576c4935ccb8c143daac4a3cd

.github/ci_commit_pins/vision.txt (2 changes, vendored)
@@ -1 +1 @@
faffd5cf673615583da6517275e361cb3dbc77e6
966da7e46f65d6d49df3e31214470a4fe5cc8e66

.github/ci_commit_pins/vllm.txt (2 changes, vendored)
@@ -1 +1 @@
e5192819208c4d68194844b7dfafbc00020d0dea
0ad9951c416d33c5da4f7a504fb162cbe62386f5

.github/ci_commit_pins/xla.txt (2 changes, vendored)
@@ -1 +1 @@
0fa6e3129e61143224663e1ec67980d12b7ec4eb
2a9138a26ee257fef05310ad3fecf7c55fe80d73

@@ -1,41 +1,59 @@
# TODO(elainwy): remove this file after the torch nightly dockerfile is in sync in vllm repo
# The vLLM Dockerfile is used to construct vLLM image against torch nightly and torch main that can be directly used for testing

ARG CUDA_VERSION=12.8.1
ARG PYTHON_VERSION=3.12

# BUILD_BASE_IMAGE: used to setup python build xformers, and vllm wheels, It can be replaced with a different base image from local machine,
# by default, it uses the torch-nightly-base stage from this docker image
ARG BUILD_BASE_IMAGE=torch-nightly-base

# FINAL_BASE_IMAGE: used to set up vllm-instaled environment and build flashinfer,
# by default, it uses devel-ubuntu22.04 official image.
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04

# The logic is copied from https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile
ARG GET_PIP_URL="https://bootstrap.pypa.io/get-pip.py"


#################### TORCH NIGHTLY BASE IMAGE ####################
# A base image for building vLLM with devel ubuntu 22.04, this is mainly used to build vllm in vllm builtkite ci
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 as torch-nightly-base

ARG CUDA_VERSION
ARG PYTHON_VERSION
ARG GET_PIP_URL

# Install system dependencies and uv, then create Python virtual environment
# Install Python and other dependencies
RUN apt-get update -y \
    && apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
    && curl -LsSf https://astral.sh/uv/install.sh | sh \
    && $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
    && rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
    && ln -s /opt/venv/bin/python3 /usr/bin/python3 \
    && ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
    && ln -s /opt/venv/bin/pip /usr/bin/pip \
    && apt-get install -y ccache software-properties-common git curl wget sudo vim \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
    && python3 --version && python3 -m pip --version

# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
# as it was causing spam when compiling the CUTLASS kernels
RUN apt-get install -y gcc-10 g++-10
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10
RUN <<EOF
gcc --version
EOF
# Ensure gcc >= 10 to avoid CUTLASS issues (bug 92519)
RUN current_gcc_version=$(gcc -dumpversion | cut -f1 -d.) && \
    if command -v apt-get >/dev/null; then \
        if [ "$current_gcc_version" -lt 10 ]; then \
            echo "GCC version is $current_gcc_version, installing gcc-10..."; \
            apt-get update \
            && apt-get install -y gcc-10 g++-10 \
            && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 \
            && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 100; \
        else \
            echo "GCC version is $current_gcc_version, no need to install gcc-10."; \
        fi \
    fi \
    && gcc --version && g++ --version

# Install uv for faster pip installs
# install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
    python3 -m pip install uv==0.8.4

@@ -43,32 +61,36 @@ ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy

#################### TORCH NIGHTLY  BASE IMAGE ####################


#################### BASE BUILD IMAGE ####################
# A base image for building vLLM with torch nightly or torch wheels
# prepare basic build environment
FROM ${BUILD_BASE_IMAGE} AS base
USER root

ARG CUDA_VERSION
ARG PYTHON_VERSION

# Only work with PyTorch manylinux builder
# TODO (huydhn): Only work with PyTorch manylinux builder
ENV PATH="/opt/python/cp312-cp312/bin:${PATH}"

# Install some system dependencies and double check python version
RUN if command -v apt-get >/dev/null; then \
        apt-get update -y \
        && apt-get install -y ccache software-properties-common git wget sudo vim; \
        && apt-get install -y ccache software-properties-common git curl wget sudo vim; \
    else \
        dnf install -y git wget sudo; \
        dnf install -y git curl wget sudo; \
    fi \
    && python3 --version && python3 -m pip --version

# Install uv for faster pip installs if not existed
RUN --mount=type=cache,target=/root/.cache/uv \
    python3 -m pip install uv==0.8.4

    if ! python3 -m uv --version >/dev/null 2>&1; then \
        python3 -m pip install uv==0.8.4; \
    fi
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
@@ -76,15 +98,15 @@ ENV UV_LINK_MODE=copy

WORKDIR /workspace

# Install build and runtime dependencies
# install build and runtime dependencies
COPY requirements/common.txt requirements/common.txt
COPY use_existing_torch.py use_existing_torch.py
COPY pyproject.toml pyproject.toml

# Install build and runtime dependencies without stable torch version
# install build and runtime dependencies without stable torch version
RUN python3 use_existing_torch.py

# Default mount file as placeholder, this just avoid the mount error
# default mount file as placeholder, this just avoid the mount error
# change to a different vllm folder if this does not exist anymore
ARG TORCH_WHEELS_PATH="./requirements"
ARG PINNED_TORCH_VERSION
@@ -116,36 +138,56 @@ RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/common.txt

# Must put before installing xformers, so it can install the correct version of xfomrers.
ARG xformers_cuda_arch_list='7.5;8.0+PTX;9.0a'
ENV TORCH_CUDA_ARCH_LIST=${xformers_cuda_arch_list}

ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}

RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
    export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
    git clone https://github.com/facebookresearch/xformers.git
RUN echo ${TORCH_CUDA_ARCH_LIST}
RUN echo ${MAX_JOBS}
RUN pip freeze | grep -E 'ninja'

    pushd xformers
    git checkout v0.0.32.post2
    git submodule update --init --recursive
    python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose
    popd
# Build xformers with cuda and torch nightly/wheel
# following official xformers guidance: https://github.com/facebookresearch/xformers#build
# sha for https://github.com/facebookresearch/xformers/tree/v0.0.32.post2
ARG XFORMERS_COMMIT=5d4b92a5e5a9c6c6d4878283f47d82e17995b468
ENV CCACHE_DIR=/root/.cache/ccache

    rm -rf xformers
BASH
RUN --mount=type=cache,target=/root/.cache/ccache \
    --mount=type=cache,target=/root/.cache/uv \
    echo 'git clone xformers...' \
    && git clone https://github.com/facebookresearch/xformers.git --recursive \
    && cd xformers \
    && git checkout ${XFORMERS_COMMIT} \
    && git submodule update --init --recursive \
    && echo 'finish git clone xformers...' \
    && rm -rf build \
    && python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
    && cd .. \
    && rm -rf xformers

RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system xformers-dist/*.whl
    uv pip install --system xformers-dist/*.whl --verbose

# Build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
# track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt

RUN cat torch_build_versions.txt
RUN pip freeze | grep -E 'torch|xformers|torchvision|torchaudio'

#################### BASE BUILD IMAGE ####################


#################### WHEEL BUILD IMAGE ####################
# Image used to build vllm wheel
FROM base AS build
ARG TARGETPLATFORM

COPY . .

RUN python3 use_existing_torch.py

RUN --mount=type=cache,target=/root/.cache/uv \
@@ -155,17 +197,20 @@ ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
    if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi

# Max jobs used by Ninja to build extensions
ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG nvcc_threads=8
ARG nvcc_threads=4
ENV NVCC_THREADS=$nvcc_threads
ARG torch_cuda_arch_list='8.0 8.6 8.9 9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

ARG USE_SCCACHE
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0

# Use sccache to speed up compilation
# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=.git,target=.git \
    if [ "$USE_SCCACHE" = "1" ]; then \
@@ -190,9 +235,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
        && sccache --show-stats; \
    fi

ARG torch_cuda_arch_list='8.0 8.6 8.9 9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

ARG vllm_target_device="cuda"
ENV VLLM_TARGET_DEVICE=${vllm_target_device}
ENV CCACHE_DIR=/root/.cache/ccache
@@ -206,10 +248,17 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
        export VLLM_DOCKER_BUILD_CONTEXT=1 && \
        python3 setup.py bdist_wheel --dist-dir=vllm-dist --py-limited-api=cp38; \
    fi

RUN echo "[INFO] Listing current directory:" && \
    ls -al && \
    echo "[INFO] Showing torch_build_versions.txt content:" && \
    cat torch_build_versions.txt

#################### WHEEL BUILD IMAGE ####################


################### VLLM INSTALLED IMAGE ####################
# Setup clean environment for vLLM for test and api server using ubuntu22.04 with AOT flashinfer
FROM ${FINAL_BASE_IMAGE} AS vllm-base
USER root

@@ -217,7 +266,7 @@ ARG CUDA_VERSION
ARG PYTHON_VERSION
ARG GET_PIP_URL

# Only work with PyTorch manylinux builder
# TODO (huydhn): Only work with PyTorch manylinux builder
ENV PATH="/opt/python/cp312-cp312/bin:${PATH}"

# prepare for environment starts
@@ -226,19 +275,20 @@ WORKDIR /workspace
# Install Python and other dependencies
RUN if command -v apt-get >/dev/null; then \
        apt-get update -y \
        && apt-get install -y ccache software-properties-common git sudo vim python3-pip; \
        && apt-get install -y ccache software-properties-common git curl wget sudo vim \
        && add-apt-repository -y ppa:deadsnakes/ppa \
        && apt-get update -y \
        && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
        && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
        && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
        && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
        && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION}; \
    else \
        dnf install -y git wget sudo; \
        dnf install -y git curl wget sudo; \
    fi \
    && curl -LsSf https://astral.sh/uv/install.sh | sh \
    && $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
    && rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
    && ln -s /opt/venv/bin/python3 /usr/bin/python3 \
    && ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
    && ln -s /opt/venv/bin/pip /usr/bin/pip \
    && python3 --version && python3 -m pip --version

# Get the torch versions, and whls used in previous stage
# Get the torch versions, and whls used in previous stagtes for consistency
COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt
COPY --from=base /workspace/xformers-dist /wheels/xformers
COPY --from=build /workspace/vllm-dist /wheels/vllm
@@ -247,29 +297,33 @@ RUN echo "[INFO] Listing current directory before torch install step:" && \
    echo "[INFO] Showing torch_build_versions.txt content:" && \
    cat torch_build_versions.txt

# Install uv for faster pip installs if not existed
RUN --mount=type=cache,target=/root/.cache/uv \
    python3 -m pip install uv==0.8.4

ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy

# Install build and runtime dependencies, this is needed for flashinfer install
COPY requirements/build.txt requirements/build.txt
COPY use_existing_torch.py use_existing_torch.py
RUN python3 use_existing_torch.py
RUN cat requirements/build.txt

# Install uv for faster pip installs if not existed
RUN --mount=type=cache,target=/root/.cache/uv \
    if ! python3 -m uv --version > /dev/null 2>&1; then \
        python3 -m pip install uv==0.8.4; \
    fi

ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy

RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/build.txt

# Default mount file as placeholder, this just avoid the mount error
ARG TORCH_WHEELS_PATH="./requirements"
# Install torch, torchaudio and torchvision. If TORCH_WHEELS_PATH is default
# to ./requirements, it will pull the nightly versions using pip. Otherwise,
# it will use the local wheels from TORCH_WHEELS_PATH
# Install torch, torchaudio and torchvision
# if TORCH_WHEELS_PATH is default "./requirements", it will pull the nightly versions using pip using torch_build_versions.txt
# otherwise, it will use the whls from TORCH_WHEELS_PATH from the host machine
RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
    --mount=type=cache,target=/root/.cache/uv \
    if [ -n "$TORCH_WHEELS_PATH" ] && [ "$TORCH_WHEELS_PATH" != "./requirements" ] && [ -d "/dist" ] && ls /dist/torch*.whl >/dev/null 2>&1; then \
@@ -290,14 +344,18 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Install xformers wheel from previous stage
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system /wheels/xformers/*.whl --verbose

# Build FlashInfer from source
# Build flashinfer from source.
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
# install package for build flashinfer
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738

RUN pip freeze | grep -E 'setuptools|packaging|build'

ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

# Build flashinfer for torch nightly from source around 10 mins
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
ARG FLASHINFER_GIT_REF="v0.2.14.post1"

RUN --mount=type=cache,target=/root/.cache/uv \
    git clone --depth 1 --recursive --shallow-submodules \
        --branch ${FLASHINFER_GIT_REF} \
@@ -309,7 +367,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
    && cd .. \
    && rm -rf flashinfer

# Install FlashInfer
# install flashinfer python
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system wheels/flashinfer/*.whl --verbose

@@ -319,6 +377,49 @@ RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm
################### VLLM INSTALLED IMAGE ####################


#################### UNITTEST IMAGE #############################
FROM vllm-base as test

ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy

COPY tests/ tests/
COPY examples examples
COPY benchmarks benchmarks
COPY ./vllm/collect_env.py .
COPY requirements/common.txt requirements/common.txt
COPY use_existing_torch.py use_existing_torch.py
COPY pyproject.toml pyproject.toml
# Install build and runtime dependencies without stable torch version
COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt

RUN python3 use_existing_torch.py

# install packages
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/common.txt
# enable fast downloads from hf (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER 1

# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -e tests/vllm_test_utils

RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/nightly_torch_test.txt

# Logging to confirm the torch versions
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'

# Logging to confirm all the packages are installed
RUN pip freeze

#################### UNITTEST IMAGE #############################

#################### EXPORT STAGE ####################
FROM scratch as export-wheels

.github/pytorch-probot.yml (3 changes, vendored)
@@ -15,8 +15,7 @@ ciflow_push_tags:
- ciflow/inductor-micro-benchmark
- ciflow/inductor-micro-benchmark-cpu-x86
- ciflow/inductor-perf-compare
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-rocm
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-periodic
- ciflow/inductor-rocm

.github/scripts/filter_test_configs.py (2 changes, vendored)
@@ -512,8 +512,6 @@ def perform_misc_tasks(
        "keep-going",
        branch == MAIN_BRANCH
        or bool(tag and re.match(r"^trunk/[a-f0-9]{40}$", tag))
        # Pattern for tags created via manual run on HUD
        or bool(tag and re.match(r"^ciflow/[^/]+/[a-f0-9]{40}$", tag))
        or check_for_setting(labels, pr_body, "keep-going"),
    )
    set_output(
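For orientation, the tag patterns visible in this hunk match `trunk/<40-hex-sha>` tags (kept) and `ciflow/<label>/<40-hex-sha>` tags (removed from the keep-going check). A small standalone illustration, not taken from the repository:

```python
import re

# Illustration only: mirrors the two tag patterns shown in the hunk above.
TRUNK_TAG = re.compile(r"^trunk/[a-f0-9]{40}$")
CIFLOW_TAG = re.compile(r"^ciflow/[^/]+/[a-f0-9]{40}$")  # pattern dropped by this change

sha = "0" * 40
assert TRUNK_TAG.match(f"trunk/{sha}")
assert CIFLOW_TAG.match(f"ciflow/periodic/{sha}")
assert not TRUNK_TAG.match("trunk/not-a-sha")
```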
.github/scripts/generate_binary_build_matrix.py (12 changes, vendored)
@@ -241,11 +241,7 @@ def generate_libtorch_matrix(
            arches += CUDA_ARCHES
            arches += ROCM_ARCHES
        elif os == "windows":
            # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
            # in 2.10
            windows_cuda_arches = CUDA_ARCHES.copy()
            windows_cuda_arches.remove("12.9")
            arches += windows_cuda_arches
            arches += CUDA_ARCHES
    if libtorch_variants is None:
        libtorch_variants = [
            "shared-with-deps",
@@ -309,11 +305,7 @@ generate_wheels_matrix(
        if os == "linux":
            arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES
        elif os == "windows":
            # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
            # in 2.10
            windows_cuda_arches = CUDA_ARCHES.copy()
            windows_cuda_arches.remove("12.9")
            arches += windows_cuda_arches + XPU_ARCHES
            arches += CUDA_ARCHES + XPU_ARCHES
        elif os == "linux-aarch64":
            # Separate new if as the CPU type is different and
            # uses different build/test scripts
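Both hunks above drop the Windows-specific exclusion of CUDA 12.9 from the generated build matrix, so Windows now uses `CUDA_ARCHES` unchanged. A toy before/after illustration (the `CUDA_ARCHES` value here is a stand-in, not the repository's actual list):

```python
# Stand-in value for illustration only; the real CUDA_ARCHES lives in
# .github/scripts/generate_binary_build_matrix.py and is not shown in this diff.
CUDA_ARCHES = ["12.6", "12.8", "12.9"]

# Old Windows behaviour shown in the hunks: copy the list and drop 12.9.
windows_cuda_arches = CUDA_ARCHES.copy()
windows_cuda_arches.remove("12.9")
assert windows_cuda_arches == ["12.6", "12.8"]

# New behaviour shown in the hunks: Windows consumes CUDA_ARCHES unchanged.
assert CUDA_ARCHES == ["12.6", "12.8", "12.9"]
```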
.github/scripts/trymerge.py (4 changes, vendored)
@@ -2042,6 +2042,10 @@ def validate_revert(
            f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
        )

    # Raises exception if matching rule is not found, but ignores all status checks
    find_matching_merge_rule(
        pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
    )
    commit_sha = get_pr_commit_sha(repo, pr)
    return (author_login, commit_sha)

.github/workflows/_linux-build.yml (2 changes, vendored)
@@ -37,7 +37,7 @@ on:
      runner:
        required: false
        type: string
        default: "linux.c7i.2xlarge"
        default: "linux.2xlarge"
        description: |
          Label of the runner this job should run on.
      test-matrix:

.github/workflows/build-vllm-wheel.yml (19 changes, vendored)
@@ -27,8 +27,9 @@ jobs:
      fail-fast: false
      matrix:
        python-version: [ '3.12' ]
        # TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
        platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
        device: [ 'cu128', 'cu129', 'cu130' ]
        device: [ 'cu128', 'cu129' ]
        include:
          - platform: manylinux_2_28_x86_64
            device: cu128
@@ -38,10 +39,6 @@
            device: cu129
            manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
            runner: linux.12xlarge.memory
          - platform: manylinux_2_28_x86_64
            device: cu130
            manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
            runner: linux.12xlarge.memory
          - platform: manylinux_2_28_aarch64
            device: cu128
            manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
@@ -50,11 +47,6 @@
            device: cu129
            manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
            runner: linux.arm64.r7g.12xlarge.memory
        exclude:
          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
          # xformers is update to support 13.0
          - platform: manylinux_2_28_aarch64
            device: cu130
    name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
    runs-on: ${{ matrix.runner }}
    timeout-minutes: 480
@@ -177,12 +169,7 @@
      fail-fast: false
      matrix:
        platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
        device: [ 'cu128', 'cu129', 'cu130' ]
        exclude:
          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
          # xformers is update to support 13.0
          - platform: manylinux_2_28_aarch64
            device: cu130
        device: [ 'cu128', 'cu129' ]
    env:
      PLATFORM: ${{ matrix.platform }}
      BUILD_DEVICE: ${{ matrix.device }}

.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml (250 changes, generated, vendored)
@@ -788,6 +788,256 @@ jobs:
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml
  libtorch-cuda12_9-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: debug
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Build PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
      - uses: actions/upload-artifact@v4.4.0
        if: always()
        with:
          name: libtorch-cuda12_9-shared-with-deps-debug
          retention-days: 14
          if-no-files-found: error
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1

  libtorch-cuda12_9-shared-with-deps-debug-test:  # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - libtorch-cuda12_9-shared-with-deps-debug-build
      - get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: debug
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - uses: actions/download-artifact@v4.1.7
        name: Download Build Artifacts
        with:
          name: libtorch-cuda12_9-shared-with-deps-debug
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Test PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1
  libtorch-cuda12_9-shared-with-deps-debug-upload:  # Uploading
    if: ${{ github.repository_owner == 'pytorch' }}
    permissions:
      id-token: write
      contents: read
    needs: libtorch-cuda12_9-shared-with-deps-debug-test
    with:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      LIBTORCH_CONFIG: debug
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
      build_name: libtorch-cuda12_9-shared-with-deps-debug
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml
  libtorch-cuda13_0-shared-with-deps-debug-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type

.github/workflows/generated-windows-binary-libtorch-release-nightly.yml (generated, vendored): 250 lines changed
@@ -788,6 +788,256 @@ jobs:
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml
  libtorch-cuda12_9-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Build PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
      - uses: actions/upload-artifact@v4.4.0
        if: always()
        with:
          name: libtorch-cuda12_9-shared-with-deps-release
          retention-days: 14
          if-no-files-found: error
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1

  libtorch-cuda12_9-shared-with-deps-release-test:  # Testing
    if: ${{ github.repository_owner == 'pytorch' }}
    needs:
      - libtorch-cuda12_9-shared-with-deps-release-build
      - get-label-type
    runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
    timeout-minutes: 360
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      SKIP_ALL_TESTS: 1
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
    steps:
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
          echo "system info $(uname -a)"
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        continue-on-error: true
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
      - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
        shell: bash
        run: |
          git config --global core.longpaths true
          git config --global core.symlinks true

          # https://git-scm.com/docs/git-fsmonitor--daemon.  The daemon could lock
          # the directory on Windows and prevent GHA from checking out as reported
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
      - name: Enable long paths on Windows
        shell: powershell
        run: |
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
      # removed once Windows Defender is removed from the AMI
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
        continue-on-error: true
        shell: powershell
        run: |
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
          # Let's both exclude the path and disable Windows Defender completely just to be sure
          # that it doesn't interfere
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
      - uses: actions/download-artifact@v4.1.7
        name: Download Build Artifacts
        with:
          name: libtorch-cuda12_9-shared-with-deps-release
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
      - name: Populate binary env
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Test PyTorch binary
        shell: bash
        run: |
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
      - name: Wait until all sessions have drained
        shell: powershell
        working-directory: pytorch
        if: always()
        timeout-minutes: 120
        run: |
          .github\scripts\wait_for_ssh_to_drain.ps1
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
        shell: powershell
        working-directory: pytorch
        if: always()
        run: |
          .github\scripts\kill_active_ssh_sessions.ps1
  libtorch-cuda12_9-shared-with-deps-release-upload:  # Uploading
    if: ${{ github.repository_owner == 'pytorch' }}
    permissions:
      id-token: write
      contents: read
    needs: libtorch-cuda12_9-shared-with-deps-release-test
    with:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: libtorch
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cu129
      GPU_ARCH_VERSION: "12.9"
      GPU_ARCH_TYPE: cuda
      LIBTORCH_CONFIG: release
      LIBTORCH_VARIANT: shared-with-deps
      # This is a dummy value for libtorch to work correctly with our batch scripts
      # without this value pip does not get installed for some reason
      DESIRED_PYTHON: "3.10"
      build_name: libtorch-cuda12_9-shared-with-deps-release
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml
  libtorch-cuda13_0-shared-with-deps-release-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    needs: get-label-type
.github/workflows/generated-windows-binary-wheel-nightly.yml (generated, vendored): 1666 lines changed (file diff suppressed because it is too large)
@@ -1,132 +0,0 @@
name: inductor-perf-nightly-rocm-mi300

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm-mi300/*
  schedule:
    - cron: 15 0 * * *
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
    inputs:
      training:
        description: Run training (on by default)?
        required: false
        type: boolean
        default: true
      inference:
        description: Run inference (on by default)?
        required: false
        type: boolean
        default: true
      default:
        description: Run inductor_default?
        required: false
        type: boolean
        default: false
      dynamic:
        description: Run inductor_dynamic_shapes?
        required: false
        type: boolean
        default: false
      cppwrapper:
        description: Run inductor_cpp_wrapper?
        required: false
        type: boolean
        default: false
      cudagraphs:
        description: Run inductor_cudagraphs?
        required: false
        type: boolean
        default: true
      freezing_cudagraphs:
        description: Run inductor_cudagraphs with freezing for inference?
        required: false
        type: boolean
        default: false
      aotinductor:
        description: Run aot_inductor for inference?
        required: false
        type: boolean
        default: false
      maxautotune:
        description: Run inductor_max_autotune?
        required: false
        type: boolean
        default: false
      benchmark_configs:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm_mi300,inductor_timm_perf_rocm_mi300,inductor_torchbench_perf_rocm_mi300

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
      opt_out_experiments: lf

  linux-jammy-rocm-py3_10-inductor-benchmark-build:
    if: github.repository_owner == 'pytorch'
    name: rocm-py3_10-inductor-benchmark-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-rocm-py3_10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-inductor-benchmark-test:
    permissions:
      id-token: write
      contents: read
    name: rocm-py3_10-inductor-benchmark-test
    uses: ./.github/workflows/_rocm-test.yml
    needs: linux-jammy-rocm-py3_10-inductor-benchmark-build
    with:
      build-environment: linux-jammy-rocm-py3_10
      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }}
      timeout-minutes: 720
      # Disable monitor in perf tests for more investigation
      disable-monitor: true
      monitor-log-interval: 10
      monitor-data-collect-interval: 2
    secrets: inherit
@@ -1,11 +1,11 @@
name: inductor-perf-nightly-rocm-mi355
name: inductor-perf-nightly-rocm

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm-mi355/*
      - ciflow/inductor-perf-test-nightly-rocm/*
  schedule:
    - cron: 15 0 * * *
    - cron: 0 7 * * 0,3
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
@@ -59,7 +59,7 @@ on:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm_mi355,inductor_timm_perf_rocm_mi355,inductor_torchbench_perf_rocm_mi355
        default: inductor_huggingface_perf_rocm,inductor_timm_perf_rocm,inductor_torchbench_perf_rocm

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
@@ -88,27 +88,23 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit
.github/workflows/lint.yml (vendored): 4 lines changed
@@ -118,9 +118,9 @@ jobs:
        CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
        echo "Running all other linters"
        if [ "$CHANGED_FILES" = '*' ]; then
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
        else
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
        fi

  quick-checks:
.github/workflows/operator_benchmark.yml (vendored): 39 lines changed
@@ -7,11 +7,9 @@ on:
  workflow_dispatch:
    inputs:
      test_mode:
        type: choice
        options:
          - 'short'
          - 'long'
          - 'all'
        required: false
        type: string
        default: 'short'
        description: tag filter for operator benchmarks, options from long, short, all
  schedule:
    # Run at 07:00 UTC every Sunday
@@ -30,25 +28,38 @@ permissions:
  contents: read

jobs:
  x86-opbenchmark-build:
  opbenchmark-build:
    if: github.repository_owner == 'pytorch'
    name: x86-opbenchmark-build
    name: opbenchmark-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  x86-opbenchmark-test:
    name: x86-opbenchmark-test
    uses: ./.github/workflows/_linux-test.yml
    needs: x86-opbenchmark-build
  opbenchmark-on-demand-build:
    if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
    name: opbenchmark-on-demand-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  opbenchmark-test:
    name: opbenchmark-test
    uses: ./.github/workflows/_linux-test.yml
    needs: opbenchmark-build
    with:
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image: ${{ needs.opbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.opbenchmark-build.outputs.test-matrix }}
    secrets: inherit
.github/workflows/trunk.yml (vendored): 8 lines changed
@@ -180,13 +180,13 @@ jobs:
      disable-monitor: false
    secrets: inherit

  win-vs2022-cuda12_8-py3-build:
    name: win-vs2022-cuda12.8-py3
  win-vs2022-cuda12_6-py3-build:
    name: win-vs2022-cuda12.6-py3
    uses: ./.github/workflows/_win-build.yml
    needs: get-label-type
    with:
      build-environment: win-vs2022-cuda12.8-py3
      cuda-version: "12.8"
      build-environment: win-vs2022-cuda12.6-py3
      cuda-version: "12.6"
      runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    secrets: inherit

.github/workflows/vllm.yml (vendored): 4 lines changed
@@ -46,7 +46,7 @@ jobs:
      runner: linux.24xlarge.memory
      test-matrix: |
        { include: [
          { config: "vllm_basic_correctness_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
          { config:  "vllm_basic_correctness_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
          { config: "vllm_basic_models_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
          { config: "vllm_entrypoints_test", shard: 1, num_shards: 1,runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
          { config: "vllm_regression_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
@@ -54,7 +54,7 @@ jobs:
          { config: "vllm_pytorch_compilation_unit_tests", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
          { config: "vllm_lora_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
          { config: "vllm_multi_model_test_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu"},
          { config: "vllm_language_model_test_extended_generation_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu"},
          { config: "vllm_languagde_model_test_extended_generation_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu"},
          { config: "vllm_distributed_test_2_gpu_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
          { config: "vllm_lora_test", shard: 0, num_shards: 4, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
          { config: "vllm_lora_test", shard: 1, num_shards: 4, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
.gitignore (vendored): 1 line changed
@@ -395,4 +395,3 @@ android/pytorch_android_torchvision/.cxx
CLAUDE.local.md
/test_*.py
/debug_*.py
CLAUDE_CONTEXT/
@@ -209,46 +209,6 @@ command = [
    '@{{PATHSFILE}}'
]


[[linter]]
code = 'PYREFLY'
include_patterns = [
    'torch/**/*.py',
    'torch/**/*.pyi',
    'torchgen/**/*.py',
    'torchgen/**/*.pyi',
    'functorch/**/*.py',
    'functorch/**/*.pyi',
]
exclude_patterns = []
command = [
    'python3',
    'tools/linter/adapters/pyrefly_linter.py',
    '--config=pyrefly.toml',
]
init_command = [
    'python3',
    'tools/linter/adapters/pip_init.py',
    '--dry-run={{DRYRUN}}',
    'numpy==2.1.0 ; python_version >= "3.12"',
    'expecttest==0.3.0',
    'pyrefly==0.36.2',
    'sympy==1.13.3',
    'types-requests==2.27.25',
    'types-pyyaml==6.0.2',
    'types-tabulate==0.8.8',
    'types-protobuf==5.29.1.20250403',
    'types-setuptools==79.0.0.20250422',
    'types-jinja2==2.11.9',
    'types-colorama==0.4.6',
    'filelock==3.18.0',
    'junitparser==2.1.1',
    'rich==14.1.0',
    'optree==0.17.0',
    'types-openpyxl==3.1.5.20250919',
    'types-python-dateutil==2.9.0.20251008'
]

[[linter]]
code = 'CLANGTIDY'
include_patterns = [
@@ -256,7 +256,6 @@ endif()
IF(USE_FBGEMM_GENAI)
  set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
  set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)

  if(USE_CUDA)
    # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
    # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@@ -293,64 +292,58 @@ IF(USE_FBGEMM_GENAI)
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
    )

    target_include_directories(fbgemm_genai PRIVATE
    target_include_directories(fbgemm_genai PUBLIC
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${fbgemm_genai_mx8mx8bf16_grouped}
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )
  else()
    if(USE_ROCM)
      # Only include the kernels we want to build to avoid increasing binary size.
      file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
      set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  elseif(USE_ROCM)
    # Only include the kernels we want to build to avoid increasing binary size.
    file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
    set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
      # Add additional HIPCC compiler flags for performance
      set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
        -mllvm
        -amdgpu-coerce-illegal-types=1
        -mllvm
        -enable-post-misched=0
        -mllvm
        -greedy-reverse-local-assignment=1
        -fhip-new-launch-api)

    # Add additional HIPCC compiler flags for performance
    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
      -mllvm
      -amdgpu-coerce-illegal-types=1
      -mllvm
      -enable-post-misched=0
      -mllvm
      -greedy-reverse-local-assignment=1
      -fhip-new-launch-api)
      # Only compile for gfx942 for now.
      # This is rather hacky, I could not figure out a clean solution :(
      set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
      string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
      if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
        list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
      endif()
      set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    # Only compile for gfx942 for now.
    # This is rather hacky, I could not figure out a clean solution :(
    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
    if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
      list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
      hip_add_library(
        fbgemm_genai STATIC
        ${fbgemm_genai_native_rocm_hip}
        HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
      set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
      set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
      target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

      target_include_directories(fbgemm_genai PUBLIC
        # FBGEMM version of Composable Kernel is used due to some customizations
        ${FBGEMM_THIRD_PARTY}/composable_kernel/include
        ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
        ${FBGEMM_THIRD_PARTY}/cutlass/include
        ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
        ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
        ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
      )
    endif()
    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    hip_add_library(
      fbgemm_genai STATIC
      ${fbgemm_genai_native_rocm_hip}
      HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

    target_include_directories(fbgemm_genai PRIVATE
      # FBGEMM version of Composable Kernel is used due to some customizations
      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )

    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()
endif()

@@ -699,6 +692,12 @@ if(USE_CUDA AND NOT USE_ROCM)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)

  # Add FBGEMM_GENAI include directories for torch_ops.h
  if(USE_FBGEMM_GENAI)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()

  if($ENV{ATEN_STATIC_CUDA})
    if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
      list(APPEND ATen_CUDA_DEPENDENCY_LIBS
@@ -389,16 +389,37 @@ void fillVersion<DLManagedTensorVersioned>(
// constructed out of ATen tensor
template <class T>
T* toDLPackImpl(const Tensor& src) {
  auto view = src;

  // Detect whether there is need to normalize the strides
  // Background: gh-83069
  //
  // However, normalizing strides can come at a high-cost
  // to slow down toDLPack conversion 3x, so we
  // only normalize if needed.
  //
  // The following code detects whether the src follows
  // a continuous pattern. If the src follows such pattern (common-case)
  // then we do not need to normalize the strides.
  bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
  // less common case, try normalizing the strides
  if (need_normalize_strides) {
    // create a new tensor with possibly normalized strides
    // gh-83069
    auto shape = src.sizes();
    view = src.as_strided(shape, {1}, src.storage_offset());
  }

  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
  atDLMTensor->handle = src;
  atDLMTensor->handle = view;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter<T>;
  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
  atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  fillVersion(&atDLMTensor->tensor);
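A minimal Python-side sketch (not part of the diff; it only uses the public torch.utils.dlpack API) of the gh-83069 case the comment above describes: a one-element tensor whose stride is not 1, which the normalization branch rewrites to {1} before handing the data to DLPack.

    import torch
    from torch.utils import dlpack

    t = torch.arange(6.0)[::2][:1]             # one element, but t.stride() == (2,)
    assert t.size(0) == 1 and t.stride(0) != 1

    capsule = dlpack.to_dlpack(t)              # with the normalization above, the exported strides should read (1,)
    back = dlpack.from_dlpack(capsule)
    print(back, back.stride())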
@@ -52,16 +52,16 @@ struct DLPackTraits {};

template <>
struct DLPackTraits<DLManagedTensor> {
  inline static constexpr const char* capsule = "dltensor";
  inline static constexpr const char* used = "used_dltensor";
  inline static const char* capsule = "dltensor";
  inline static const char* used = "used_dltensor";
  inline static auto toDLPack = at::toDLPack;
  inline static auto fromDLPack = at::fromDLPack;
};

template <>
struct DLPackTraits<DLManagedTensorVersioned> {
  inline static constexpr const char* capsule = "dltensor_versioned";
  inline static constexpr const char* used = "used_dltensor_versioned";
  inline static const char* capsule = "dltensor_versioned";
  inline static const char* used = "used_dltensor_versioned";
  inline static auto toDLPack = at::toDLPackVersioned;
  inline static auto fromDLPack = at::fromDLPackVersioned;
};
@@ -42,14 +42,8 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
}

bool torch_function_mode_enabled() {
  // Manually flatten because gcc is refusing to inline here.  Note
  // that we are still calling __tls_get_addr twice here with GCC,
  // presumably because of
  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81501 (which says
  // the fix ships in GCC 16), but forcing inlining still improves
  // performance.
  const auto& ptfs = pythonTorchFunctionState;
  return ptfs.disabled_state_ != TorchFunctionDisabledState::ALL_DISABLED && !ptfs.stack_.empty();
  return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
         PythonTorchFunctionTLS::stack_len() > 0;
}

// This is needed to disambiguate the ternary torch function disabled states

@@ -27,7 +27,6 @@ struct TORCH_API PythonTorchFunctionTLS {
  TorchFunctionDisabledState disabled_state_ =
      TorchFunctionDisabledState::ENABLED;
  std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
  friend TORCH_API bool torch_function_mode_enabled();
};

TORCH_API bool torch_function_mode_enabled();
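For context, a small hypothetical Python sketch of what populates the mode stack that torch_function_mode_enabled() inspects; TorchFunctionMode and the override pattern are the documented public API, while the LoggingMode name is made up for illustration.

    import torch
    from torch.overrides import TorchFunctionMode

    class LoggingMode(TorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print("intercepted:", func.__name__)   # every torch call is routed here while the mode is active
            return func(*args, **kwargs)

    with LoggingMode():                             # pushes a mode onto the TLS stack checked above
        torch.add(torch.ones(2), torch.ones(2))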
@ -2,7 +2,6 @@
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
#include <ATen/CachedTensorUtils.h>
 | 
			
		||||
#include <c10/core/GradMode.h>
 | 
			
		||||
#include <c10/util/flat_hash_map.h>
 | 
			
		||||
 | 
			
		||||
namespace at::autocast {
 | 
			
		||||
@ -37,29 +36,10 @@ namespace {
 | 
			
		||||
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 | 
			
		||||
using val_type = std::tuple<weakref_type, Tensor>;
 | 
			
		||||
 | 
			
		||||
// We maintain separate caches for gradient-enabled and gradient-disabled modes.
 | 
			
		||||
// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
 | 
			
		||||
// are not incorrectly reused in gradient-enabled contexts.
 | 
			
		||||
// This fixes issue #158232 while maintaining optimal performance for both modes.
 | 
			
		||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
 | 
			
		||||
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
 | 
			
		||||
  return cached_casts_grad_enabled;
 | 
			
		||||
ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
 | 
			
		||||
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
 | 
			
		||||
  return cached_casts;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
 | 
			
		||||
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
 | 
			
		||||
  return cached_casts_grad_disabled;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helper function to get the appropriate cache based on current gradient mode.
 | 
			
		||||
// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
 | 
			
		||||
// preventing incorrect cache hits when gradient mode changes.
 | 
			
		||||
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
 | 
			
		||||
  return at::GradMode::is_enabled() ?
 | 
			
		||||
    get_cached_casts_grad_enabled() :
 | 
			
		||||
    get_cached_casts_grad_disabled();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::mutex cached_casts_mutex;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
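The comments in this hunk explain why the autocast cast cache is keyed on gradient mode: a cast cached inside torch.no_grad() carries no autograd history, so it must not be reused once gradients are enabled again (issue #158232). A minimal Python sketch of that scenario, assuming a CUDA build; tensor names and shapes are illustrative only:

```python
# Sketch of the situation the per-gradient-mode cache addresses (issue #158232):
# the same fp32 leaf weight is autocast both inside and outside torch.no_grad().
import torch

weight = torch.randn(16, 16, device="cuda", requires_grad=True)  # fp32 leaf, cache-eligible
x = torch.randn(8, 16, device="cuda")

with torch.autocast("cuda"):
    with torch.no_grad():
        _ = x @ weight.t()   # a cached half-precision cast is created while grad is disabled

    out = x @ weight.t()     # grad-enabled pass must not reuse the no_grad cache entry
    out.sum().backward()     # with per-mode caching, this cast participates in autograd
```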
@ -106,9 +86,7 @@ thread_local bool cache_enabled = true;
 | 
			
		||||
 | 
			
		||||
void clear_cache() {
 | 
			
		||||
  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
 | 
			
		||||
  // Clear both caches to ensure consistent behavior regardless of current gradient mode
 | 
			
		||||
  get_cached_casts_grad_enabled().clear();
 | 
			
		||||
  get_cached_casts_grad_disabled().clear();
 | 
			
		||||
  get_cached_casts().clear();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int increment_nesting() {
 | 
			
		||||
@ -143,11 +121,6 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
 | 
			
		||||
  if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
 | 
			
		||||
    // Heuristic:  Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
 | 
			
		||||
    // See cached_casts declaration above for detailed strategy.
 | 
			
		||||
    //
 | 
			
		||||
    // We maintain separate caches for gradient-enabled and gradient-disabled modes
 | 
			
		||||
    // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
 | 
			
		||||
    // with torch.autocast(), while maintaining optimal performance for both training and inference.
 | 
			
		||||
    // This fixes issue #158232 without any performance regression.
 | 
			
		||||
    bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
 | 
			
		||||
                         arg.scalar_type() == at::kFloat && arg.requires_grad() &&
 | 
			
		||||
                         arg.is_leaf() && !arg.is_view() && cache_enabled &&
 | 
			
		||||
 | 
			
		||||
@ -624,14 +624,7 @@ struct TORCH_API IValue final {
 | 
			
		||||
  IValue(const c10::SymBool& i) {
 | 
			
		||||
    if (auto mi = i.maybe_as_bool()) {
 | 
			
		||||
      tag = Tag::Bool;
 | 
			
		||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
 | 
			
		||||
      payload.u.as_int = *mi;
 | 
			
		||||
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
 | 
			
		||||
      /* due to byte order, if the value were assigned via as_int, as_bool would not be set correctly */
 | 
			
		||||
      payload.u.as_bool = *mi;
 | 
			
		||||
#else
 | 
			
		||||
#error Unexpected or undefined __BYTE_ORDER__
 | 
			
		||||
#endif
 | 
			
		||||
    } else {
 | 
			
		||||
      tag = Tag::SymBool;
 | 
			
		||||
      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,6 @@
 | 
			
		||||
#include <c10/core/ScalarType.h>
 | 
			
		||||
 | 
			
		||||
#include <ATen/cuda/tunable/TunableOp.h>
 | 
			
		||||
#include <ATen/cuda/tunable/Tunable.h>
 | 
			
		||||
#include <ATen/cuda/CUDABlas.h>
 | 
			
		||||
#include <ATen/cuda/Exceptions.h>
 | 
			
		||||
#include <c10/util/StringUtil.h>
 | 
			
		||||
@ -151,7 +150,6 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
 | 
			
		||||
      BLASType = "unknown";
 | 
			
		||||
  }
 | 
			
		||||
  return BLASType;
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Similar to Compute Type in GemmRocblas.h
 | 
			
		||||
@ -246,25 +244,33 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
 | 
			
		||||
 | 
			
		||||
  if (!config.enabled) {
 | 
			
		||||
    return true; // skip when disabled
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
 | 
			
		||||
  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
 | 
			
		||||
  // comparison done as 1D tensor
 | 
			
		||||
  at::Tensor ref = at::from_blob(c,       {size}, options);
 | 
			
		||||
  at::Tensor oth = at::from_blob(other_c, {size}, options);
 | 
			
		||||
  at::Tensor ref_float = ref.to(at::kFloat);
 | 
			
		||||
  at::Tensor oth_float = oth.to(at::kFloat);
 | 
			
		||||
 | 
			
		||||
  const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
 | 
			
		||||
  if (ok) {
 | 
			
		||||
    TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
 | 
			
		||||
  } else {
 | 
			
		||||
    TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
 | 
			
		||||
  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
 | 
			
		||||
  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
 | 
			
		||||
  double last_succeed_atol = 1;
 | 
			
		||||
  double last_succeed_rtol = 1;
 | 
			
		||||
  for (auto& atol : atols) {
 | 
			
		||||
    for (auto& rtol : rtols) {
 | 
			
		||||
      if (at::allclose(ref_float, oth_float, rtol, atol)) {
 | 
			
		||||
        last_succeed_atol = atol;
 | 
			
		||||
        last_succeed_rtol = rtol;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return ok;
 | 
			
		||||
  if (last_succeed_atol == 1) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  else {
 | 
			
		||||
    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -349,10 +355,8 @@ struct GemmParams : OpParams {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TuningStatus NumericalCheck(GemmParams<T> *other) {
 | 
			
		||||
    auto* ctx = getTuningContext();
 | 
			
		||||
    auto cfg = ctx->GetNumericalCheckConfig();
 | 
			
		||||
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
 | 
			
		||||
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
 | 
			
		||||
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  char transa{};
 | 
			
		||||
@ -445,10 +449,8 @@ struct GemmAndBiasParams : OpParams {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
 | 
			
		||||
    auto* ctx = getTuningContext();
 | 
			
		||||
    auto cfg = ctx->GetNumericalCheckConfig();
 | 
			
		||||
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
 | 
			
		||||
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
 | 
			
		||||
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  char transa{};
 | 
			
		||||
@ -544,10 +546,8 @@ struct GemmStridedBatchedParams : OpParams {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
 | 
			
		||||
    auto* ctx = getTuningContext();
 | 
			
		||||
    auto cfg = ctx->GetNumericalCheckConfig();
 | 
			
		||||
    auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
 | 
			
		||||
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
 | 
			
		||||
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  char transa{};
 | 
			
		||||
@ -663,9 +663,7 @@ struct ScaledGemmParams : OpParams {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
 | 
			
		||||
    auto* ctx = getTuningContext();
 | 
			
		||||
    auto cfg = ctx->GetNumericalCheckConfig();
 | 
			
		||||
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
 | 
			
		||||
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  char transa{};
 | 
			
		||||
 | 
			
		||||
@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
 | 
			
		||||
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
 | 
			
		||||
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
 | 
			
		||||
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
 | 
			
		||||
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
 | 
			
		||||
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
 | 
			
		||||
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
 | 
			
		||||
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
 | 
			
		||||
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
 | 
			
		||||
@ -173,9 +173,10 @@ All python APIs exist in the `torch.cuda.tunable` module.
 | 
			
		||||
| get_max_tuning_iterations() -> int | |
 | 
			
		||||
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
 | 
			
		||||
| get_filename() -> str | |
 | 
			
		||||
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
 | 
			
		||||
| get_results() -> Tuple[str, str, str, float] | |
 | 
			
		||||
| get_validators() -> Tuple[str, str] | |
 | 
			
		||||
| write_file_on_exit(val: bool) -> None | Default is True. |
 | 
			
		||||
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
 | 
			
		||||
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
 | 
			
		||||
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
 | 
			
		||||
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
 | 
			
		||||
 | 
			
		||||
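One side of this comparison documents a tolerance-aware numerical check: the environment variable takes an "atol_rtol" string and the Python API gains set_numerical_check_tolerances. A hedged usage sketch, assuming a build that includes that API; the call names and tolerances come from the table above and are not verified against a released version:

```python
import os
import torch

# Option 1: environment variable, parsed as "atol_rtol"; must be set before tuning starts.
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-4_1e-4"

# Option 2: Python API with explicit tolerances (per the table above).
torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_numerical_check_tolerances(True, 1e-4, 1e-4)

a = torch.randn(512, 512, device="cuda", dtype=torch.half)
b = torch.randn(512, 512, device="cuda", dtype=torch.half)
c = a @ b  # each candidate GEMM solution is compared against a reference within atol/rtol
```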
@ -107,30 +107,14 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
 | 
			
		||||
  bool is_new = false;
 | 
			
		||||
  ResultEntry inserted = ResultEntry::Null();
 | 
			
		||||
  std::scoped_lock l{lock_};
 | 
			
		||||
 | 
			
		||||
  // ---- mutate maps under results lock ----
 | 
			
		||||
  {
 | 
			
		||||
    std::scoped_lock l{lock_};
 | 
			
		||||
    auto& km = results_[op_signature];  // creates if missing
 | 
			
		||||
    is_new = (km.find(params_signature) == km.end());
 | 
			
		||||
    AddImpl(op_signature, params_signature, std::move(best), km);
 | 
			
		||||
    if (is_new) {
 | 
			
		||||
      inserted = km.at(params_signature);  // snapshot for I/O after unlocking
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
   if (!is_new) return;  // only write once per unique (op, params)
 | 
			
		||||
 | 
			
		||||
   TuningContext* ctx = getTuningContext();
 | 
			
		||||
  if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
 | 
			
		||||
    InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());
 | 
			
		||||
 | 
			
		||||
    if (is_new && realtime_out_ && realtime_out_->good()) {
 | 
			
		||||
      AppendResultLine(op_signature, params_signature, inserted);
 | 
			
		||||
    }
 | 
			
		||||
  auto it = results_.find(op_signature);
 | 
			
		||||
  if (it == results_.end()) {
 | 
			
		||||
    it = results_.insert({op_signature, {}}).first;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  AddImpl(op_signature, params_signature, std::move(best), it->second);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
 | 
			
		||||
@ -166,77 +150,6 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
 | 
			
		||||
  std::scoped_lock fl{realtime_file_mutex_};
 | 
			
		||||
 | 
			
		||||
  if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (realtime_out_ && realtime_filename_ != filename) {
 | 
			
		||||
    realtime_out_->flush();
 | 
			
		||||
    realtime_out_->close();
 | 
			
		||||
    realtime_out_.reset();
 | 
			
		||||
    validators_written_ = false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool file_exists = false;
 | 
			
		||||
  bool file_empty = true;
 | 
			
		||||
 | 
			
		||||
  {
 | 
			
		||||
    std::ifstream check_file(filename);
 | 
			
		||||
    if (check_file.good()) {
 | 
			
		||||
      file_exists = true;
 | 
			
		||||
      file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
 | 
			
		||||
 | 
			
		||||
  if (!realtime_out_->good()) {
 | 
			
		||||
    TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
 | 
			
		||||
    realtime_out_.reset();
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if(!file_exists || file_empty) {
 | 
			
		||||
    for(const auto& [key, val] : validators) {
 | 
			
		||||
      (*realtime_out_) << "Validator," << key << "," << val << std::endl;
 | 
			
		||||
      realtime_out_->flush();
 | 
			
		||||
    }
 | 
			
		||||
    validators_written_ = true;
 | 
			
		||||
 | 
			
		||||
    TUNABLE_LOG2("Wrote validators to realtime output file");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  realtime_filename_ = filename;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
 | 
			
		||||
  std::scoped_lock fl{realtime_file_mutex_};
 | 
			
		||||
 | 
			
		||||
  if(!realtime_out_ || !realtime_out_->good()) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  (*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
 | 
			
		||||
  realtime_out_->flush(); //ensure immediate write to disk
 | 
			
		||||
 | 
			
		||||
  TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningResultsManager::CloseRealtimeAppend() {
 | 
			
		||||
  std::scoped_lock fl{realtime_file_mutex_};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  if(realtime_out_) {
 | 
			
		||||
    realtime_out_->flush();
 | 
			
		||||
    realtime_out_->close();
 | 
			
		||||
    realtime_out_.reset();
 | 
			
		||||
    TUNABLE_LOG2("Closed realtime output file");
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
 | 
			
		||||
  std::scoped_lock l{lock_};
 | 
			
		||||
 | 
			
		||||
@ -483,6 +396,7 @@ TuningContext::TuningContext() :
 | 
			
		||||
    tuning_enable_{true},
 | 
			
		||||
    record_untuned_enable_{false},
 | 
			
		||||
    manager_initialized_{false},
 | 
			
		||||
    write_file_on_exit_{true},
 | 
			
		||||
    numerics_check_enable_{false},
 | 
			
		||||
    max_tuning_duration_ms_{30},
 | 
			
		||||
    max_tuning_iterations_{100},
 | 
			
		||||
@ -503,8 +417,20 @@ TuningContext::~TuningContext() {
 | 
			
		||||
    // but doesn't do any computation itself.
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  TUNABLE_LOG1("Closing File");
 | 
			
		||||
  GetTuningResultsManager().CloseRealtimeAppend(); // results are appended to the file in realtime by default now
 | 
			
		||||
  auto filename = GetFilename();
 | 
			
		||||
  if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
 | 
			
		||||
    if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
 | 
			
		||||
      if (results_count_from_input_file_ > 0) {
 | 
			
		||||
        TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
 | 
			
		||||
      }
 | 
			
		||||
      else {
 | 
			
		||||
        TUNABLE_LOG1("writing file ", filename);
 | 
			
		||||
      }
 | 
			
		||||
      if (!WriteFile(filename)) {
 | 
			
		||||
        TUNABLE_LOG1("failed to write file ", filename);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (untuned_file_.good()) {
 | 
			
		||||
    untuned_file_.close();
 | 
			
		||||
@ -585,54 +511,20 @@ std::ofstream& TuningContext::GetUntunedFile(){
 | 
			
		||||
  return untuned_file_;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningContext::WriteFileOnExit(bool value) {
 | 
			
		||||
  write_file_on_exit_ = value;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningContext::EnableNumericsCheck(bool value) {
 | 
			
		||||
  numerics_check_enable_ = value;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
 | 
			
		||||
  const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
 | 
			
		||||
 | 
			
		||||
  if (!env_opt.has_value()) {
 | 
			
		||||
    return numerics_cfg_;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const std::string& env = env_opt.value();
 | 
			
		||||
 | 
			
		||||
  if (env == "0") {
 | 
			
		||||
    return NumericalCheckConfig(false, 1e-5, 1e-5);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const size_t underscore = env.find('_');
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      underscore != std::string::npos,
 | 
			
		||||
      "Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
 | 
			
		||||
      "Expected 'atol_rtol', got: ",
 | 
			
		||||
      env);
 | 
			
		||||
 | 
			
		||||
  double atol = 0.0;
 | 
			
		||||
  double rtol = 0.0;
 | 
			
		||||
 | 
			
		||||
  try {
 | 
			
		||||
    atol = std::stod(env.substr(0, underscore));
 | 
			
		||||
    rtol = std::stod(env.substr(underscore + 1));
 | 
			
		||||
  } catch (const std::exception& e) {
 | 
			
		||||
    TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
 | 
			
		||||
  return NumericalCheckConfig(true, atol, rtol);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
 | 
			
		||||
  TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
 | 
			
		||||
  numerics_cfg_ = {enabled, atol, rtol};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool TuningContext::IsNumericsCheckEnabled() const {
 | 
			
		||||
  const auto cfg = GetNumericalCheckConfig();
 | 
			
		||||
  return cfg.enabled || numerics_check_enable_;
 | 
			
		||||
  const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
 | 
			
		||||
  if (env == "1") {
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
  return numerics_check_enable_;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
 | 
			
		||||
@ -742,6 +634,11 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
 | 
			
		||||
    auto filename = GetFilename();
 | 
			
		||||
    if (!filename.empty() && !IsRecordUntunedEnabled()) {
 | 
			
		||||
      ReadFile(filename);
 | 
			
		||||
      // attempt immediately to open file for writing to catch errors early
 | 
			
		||||
      std::ofstream file(filename, std::ios::out | std::ios::app);
 | 
			
		||||
      if (!file.good()) {
 | 
			
		||||
        TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  });
 | 
			
		||||
  return manager_;
 | 
			
		||||
@ -847,6 +744,27 @@ bool TuningContext::ReadFile(const std::string& filename_) {
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool TuningContext::WriteFile(const std::string& filename_) {
 | 
			
		||||
  std::string filename = filename_.empty() ? GetFilename() : filename_;
 | 
			
		||||
  std::ofstream file(filename, std::ios::out | std::ios::trunc);
 | 
			
		||||
  if (!file.good()) {
 | 
			
		||||
    TUNABLE_LOG1("error opening tuning results file for writing ", filename);
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  auto validators = GetTuningResultsValidator().GetAllValidators();
 | 
			
		||||
  for (const auto& [key, val] : validators) {
 | 
			
		||||
    file << "Validator," << key << "," << val << std::endl;
 | 
			
		||||
  }
 | 
			
		||||
  auto results = GetTuningResultsManager().Dump();
 | 
			
		||||
  for (const auto& [op_sig, kernelmap] : results) {
 | 
			
		||||
    for (const auto& [param_sig, result] : kernelmap) {
 | 
			
		||||
      file << op_sig << "," << param_sig << "," << result << std::endl;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  file.close();
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
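WriteFile above (and the realtime-append path) emit a plain CSV layout: Validator rows first, then one row per tuned (op, params) pair. A small parsing sketch under that assumption; the filename and helper name are illustrative, not part of the library API:

```python
# Sketch: read a tunableop results file of the form written by WriteFile above.
#   Validator,<key>,<value>
#   <op_signature>,<params_signature>,<result>
def read_tunableop_results(path="tunableop_results.csv"):
    validators, results = {}, {}
    with open(path) as f:
        for line in f:
            parts = line.rstrip("\n").split(",", 2)  # result field may itself contain commas
            if len(parts) != 3:
                continue  # skip blank or malformed rows
            first, second, rest = parts
            if first == "Validator":
                validators[second] = rest
            else:
                results.setdefault(first, {})[second] = rest
    return validators, results
```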
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
struct MaybeDelete {
 | 
			
		||||
 | 
			
		||||
@ -103,24 +103,10 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
 | 
			
		||||
 | 
			
		||||
    void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
 | 
			
		||||
      const std::string& params_signature, const std::string& blas_signature);
 | 
			
		||||
 | 
			
		||||
    void InitRealtimeAppend(
 | 
			
		||||
        const std::string& filename,
 | 
			
		||||
        const std::unordered_map<std::string, std::string>& validators);
 | 
			
		||||
 | 
			
		||||
    void AppendResultLine(const std::string& op_sig,
 | 
			
		||||
                         const std::string& param_sig,
 | 
			
		||||
                         const ResultEntry& result);
 | 
			
		||||
 | 
			
		||||
    void CloseRealtimeAppend();  // For clean shutdown
 | 
			
		||||
  private:
 | 
			
		||||
    std::mutex lock_;
 | 
			
		||||
    std::mutex realtime_file_mutex_;
 | 
			
		||||
    std::unique_ptr<std::ofstream> realtime_out_;
 | 
			
		||||
    std::string realtime_filename_;
 | 
			
		||||
    ResultsMap results_;
 | 
			
		||||
    UntunedMap untuned_results_;
 | 
			
		||||
    bool validators_written_ = false;
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@ -148,16 +134,6 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
 | 
			
		||||
    GetValidateFuncs validators_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct NumericalCheckConfig {
 | 
			
		||||
  bool   enabled{false};
 | 
			
		||||
  double atol{1e-5};
 | 
			
		||||
  double rtol{1e-5};
 | 
			
		||||
 | 
			
		||||
  NumericalCheckConfig() = default;
 | 
			
		||||
  NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
  public:
 | 
			
		||||
    TuningContext();
 | 
			
		||||
@ -179,8 +155,6 @@ class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
 | 
			
		||||
    void EnableNumericsCheck(bool value);
 | 
			
		||||
    bool IsNumericsCheckEnabled() const;
 | 
			
		||||
    void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
 | 
			
		||||
    NumericalCheckConfig GetNumericalCheckConfig() const;
 | 
			
		||||
 | 
			
		||||
    void SetMaxTuningDurationMs(int max_duration_ms);
 | 
			
		||||
    int GetMaxTuningDurationMs() const;
 | 
			
		||||
@ -211,7 +185,10 @@ class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
    void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
 | 
			
		||||
    std::string GetFilename() const;
 | 
			
		||||
 | 
			
		||||
    void WriteFileOnExit(bool value);
 | 
			
		||||
 | 
			
		||||
    bool ReadFile(const std::string& filename={});
 | 
			
		||||
    bool WriteFile(const std::string& filename={});
 | 
			
		||||
 | 
			
		||||
    template<class... Types>
 | 
			
		||||
    void Log(int level, Types... args) {
 | 
			
		||||
@ -230,6 +207,7 @@ class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
    bool tuning_enable_;
 | 
			
		||||
    bool record_untuned_enable_;
 | 
			
		||||
    bool manager_initialized_;
 | 
			
		||||
    bool write_file_on_exit_;
 | 
			
		||||
    bool numerics_check_enable_;
 | 
			
		||||
    int max_tuning_duration_ms_;
 | 
			
		||||
    int max_tuning_iterations_;
 | 
			
		||||
@ -244,8 +222,6 @@ class TORCH_CUDA_CPP_API TuningContext {
 | 
			
		||||
    std::ofstream untuned_file_;
 | 
			
		||||
    size_t results_count_from_input_file_;
 | 
			
		||||
    bool is_shutting_down_;
 | 
			
		||||
 | 
			
		||||
    NumericalCheckConfig numerics_cfg_{};
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
 | 
			
		||||
 | 
			
		||||
@ -267,10 +267,27 @@ class TunableOp {
 | 
			
		||||
      for (size_t i = 0; i < op_names_.size(); i++) {
 | 
			
		||||
        auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
 | 
			
		||||
 | 
			
		||||
        auto status = candidate->Call(reusable_params[0]);
 | 
			
		||||
        if (status != OK) {
 | 
			
		||||
          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
          continue;
 | 
			
		||||
        if (do_numerics_check) {
 | 
			
		||||
          ParamsT* numerical_params = params->DeepCopy(false);
 | 
			
		||||
          auto status = candidate->Call(numerical_params);
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            numerical_params->Delete();
 | 
			
		||||
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
          status = reference_params->NumericalCheck(numerical_params);
 | 
			
		||||
          numerical_params->Delete();
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
          auto status = candidate->Call(reusable_params[0]);
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // collect a small profile
 | 
			
		||||
@ -293,22 +310,6 @@ class TunableOp {
 | 
			
		||||
          continue;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (do_numerics_check) {
 | 
			
		||||
          ParamsT* numerical_params = params->DeepCopy(false);
 | 
			
		||||
          auto status = candidate->Call(numerical_params);
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            numerical_params->Delete();
 | 
			
		||||
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
          status = reference_params->NumericalCheck(numerical_params);
 | 
			
		||||
          numerical_params->Delete();
 | 
			
		||||
          if (status != OK) {
 | 
			
		||||
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // for warmup does user set max duration, max iters, or both?
 | 
			
		||||
        // warmup is skipped by default, i.e. warmup_iter = 0
 | 
			
		||||
        // warmup will be set to the non-zero value of max_warmup_duration
 | 
			
		||||
 | 
			
		||||
@ -213,22 +213,40 @@ static cudnn_grid_sample_backward_batch_rule(
 | 
			
		||||
  return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// uses functional formulation for one_hot under vmap to be compatible with
 | 
			
		||||
// fakeTensor/dynamic shapes and compiled functorch transforms.
 | 
			
		||||
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
 | 
			
		||||
// but requires explicit positive num_classes under vmap to avoid
 | 
			
		||||
// data-dependent output shapes.
 | 
			
		||||
// TODO: replace with targetable functionalization
 | 
			
		||||
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
 | 
			
		||||
    TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
 | 
			
		||||
    auto shape = self.sym_sizes().vec();
 | 
			
		||||
 | 
			
		||||
    // empty tensor could be converted to one hot representation,
 | 
			
		||||
    // but shape inference is not possible.
 | 
			
		||||
    if (self.sym_numel() == 0) {
 | 
			
		||||
        if (num_classes <= 0) {
 | 
			
		||||
            TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
 | 
			
		||||
        } else {
 | 
			
		||||
            shape.emplace_back(num_classes);
 | 
			
		||||
            return at::empty_symint(shape, self.options());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // disallow implicit inference under vmap; this would be data-dependent
 | 
			
		||||
    // and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
 | 
			
		||||
    TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
 | 
			
		||||
        "provide an explicit positive num_classes argument.");
 | 
			
		||||
 | 
			
		||||
    const auto options = self.options();
 | 
			
		||||
    at::Tensor index = at::arange(num_classes, options);
 | 
			
		||||
    return at::eq(self.unsqueeze(-1), index).to(at::kLong);
 | 
			
		||||
    // Disabling all of the following checks. This is OK because scatter has checks too.
 | 
			
		||||
    // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
 | 
			
		||||
    // // non-empty tensor
 | 
			
		||||
    // if (self.device().type() != at::kCUDA) {
 | 
			
		||||
    //   //for cuda, rely on device assert thrown by scatter
 | 
			
		||||
    //   TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
 | 
			
		||||
    // }
 | 
			
		||||
    // if (self.device().type() != at::kCUDA) {
 | 
			
		||||
    //   //rely on device asserts from scatter to avoid sync here
 | 
			
		||||
    //   TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
 | 
			
		||||
    // }
 | 
			
		||||
 | 
			
		||||
    shape.emplace_back(num_classes);
 | 
			
		||||
    Tensor ret = at::zeros_symint(shape, self.options());
 | 
			
		||||
    return ret.scatter(-1, self.unsqueeze(-1), 1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename A, A a, typename C>
 | 
			
		||||
 | 
			
		||||
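The batching rule above lowers one_hot to a comparison against arange(num_classes), which keeps the output shape independent of the data but requires the caller to pass an explicit positive num_classes under vmap. A short sketch of both the call and the equivalent functional formulation:

```python
import torch
from torch.func import vmap

idx = torch.tensor([[0, 2, 1], [1, 1, 0]])  # indices to one-hot encode, per batch row

# Explicit num_classes is required when vmap-ing; inferring it from the data
# would make the output shape data-dependent and is rejected.
hot = vmap(lambda t: torch.nn.functional.one_hot(t, num_classes=3))(idx)

# Equivalent comparison-based formulation used by the decomposition.
classes = torch.arange(3)
hot_ref = (idx.unsqueeze(-1) == classes).to(torch.long)
assert torch.equal(hot, hot_ref)
```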
@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto shape = self.sym_sizes().vec();
 | 
			
		||||
    auto shape = self.sizes().vec();
 | 
			
		||||
 | 
			
		||||
    // empty tensor could be converted to one hot representation,
 | 
			
		||||
    // but shape inference is not possible.
 | 
			
		||||
    if (self.sym_numel() == 0) {
 | 
			
		||||
    if (self.numel() == 0) {
 | 
			
		||||
        if (num_classes <= 0) {
 | 
			
		||||
            TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
 | 
			
		||||
        } else {
 | 
			
		||||
            shape.emplace_back(num_classes);
 | 
			
		||||
            return at::empty_symint(shape, self.options());
 | 
			
		||||
            shape.push_back(num_classes);
 | 
			
		||||
            return at::empty(shape, self.options());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    shape.emplace_back(num_classes);
 | 
			
		||||
    Tensor ret = at::zeros_symint(shape, self.options());
 | 
			
		||||
    shape.push_back(num_classes);
 | 
			
		||||
    Tensor ret = at::zeros(shape, self.options());
 | 
			
		||||
    ret.scatter_(-1, self.unsqueeze(-1), 1);
 | 
			
		||||
    return ret;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1906,9 +1906,11 @@ Tensor& index_fill_(
 | 
			
		||||
        "This also applies to advanced indexing e.g. tensor[mask] = scalar");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      self.is_complex() || !source.isComplex(),
 | 
			
		||||
      "index_fill_(): Converting complex Scalar to non-complex type is not supported");
 | 
			
		||||
  if (!self.is_complex() && source.isComplex()) {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        false,
 | 
			
		||||
        "index_fill_(): Converting complex Scalar to non-complex type is not supported");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Handle the case when `self` is 0-dim
 | 
			
		||||
  Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;
 | 
			
		||||
 | 
			
		||||
@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
 | 
			
		||||
  } else if (dtype == ScalarType::Half) {
 | 
			
		||||
    [&]() {
 | 
			
		||||
      using scalar_t =
 | 
			
		||||
          c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
 | 
			
		||||
      const auto exp = exp_scalar.to<scalar_t>();
 | 
			
		||||
      using Vec = Vectorized<scalar_t>;
 | 
			
		||||
      cpu_kernel_vec(iter,
 | 
			
		||||
 | 
			
		||||
										
											
(File diff suppressed because it is too large.)
@ -856,13 +856,9 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
      out_calc_t output_offset_calculator,
 | 
			
		||||
      loader_t loader,
 | 
			
		||||
      storer_t storer) {
 | 
			
		||||
    constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
 | 
			
		||||
    constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
 | 
			
		||||
    constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
 | 
			
		||||
    if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
 | 
			
		||||
      using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
 | 
			
		||||
      using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
 | 
			
		||||
      using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
 | 
			
		||||
    if (ret_t == rt_binary_specializations[arg_index][0] &&
 | 
			
		||||
        arg0_t == rt_binary_specializations[arg_index][1] &&
 | 
			
		||||
        arg1_t == rt_binary_specializations[arg_index][2])
 | 
			
		||||
      launch_vectorized_templated_kernel<
 | 
			
		||||
          func_t,
 | 
			
		||||
          array_t,
 | 
			
		||||
@ -870,9 +866,12 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
          out_calc_t,
 | 
			
		||||
          loader_t,
 | 
			
		||||
          storer_t,
 | 
			
		||||
          cret_t,
 | 
			
		||||
          carg0_t,
 | 
			
		||||
          carg1_t>(
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][0]>::t),
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][1]>::t),
 | 
			
		||||
          decltype(c10::impl::ScalarTypeToCPPType<
 | 
			
		||||
                   rt_binary_specializations[arg_index][2]>::t)>(
 | 
			
		||||
          numel,
 | 
			
		||||
          f,
 | 
			
		||||
          data,
 | 
			
		||||
@ -880,7 +879,6 @@ struct type_specialized_kernel_launcher {
 | 
			
		||||
          output_offset_calculator,
 | 
			
		||||
          loader,
 | 
			
		||||
          storer);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -38,41 +38,12 @@ __device__ inline int min(int a, int b) {
 | 
			
		||||
#define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template <typename index_t>
 | 
			
		||||
static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) {
 | 
			
		||||
  const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);
 | 
			
		||||
  return (size + pad < kernel_extent) ? index_t(0) : (size + pad - kernel_extent) / stride + 1;
 | 
			
		||||
static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
 | 
			
		||||
  return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename index_t>
 | 
			
		||||
static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) {
 | 
			
		||||
  return std::min((size + pad) / stride + 1, pooled_size);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static inline bool can_use_int32_nhwc(
 | 
			
		||||
    int64_t nbatch, int64_t channels,
 | 
			
		||||
    int64_t height, int64_t width,
 | 
			
		||||
    int64_t pooled_height, int64_t pooled_width,
 | 
			
		||||
    int64_t in_stride_n, int64_t in_stride_c,
 | 
			
		||||
    int64_t in_stride_h, int64_t in_stride_w)
 | 
			
		||||
{
 | 
			
		||||
  constexpr int64_t int_max = std::numeric_limits<int>::max();
 | 
			
		||||
 | 
			
		||||
  int64_t max_intra_batch =
 | 
			
		||||
      (height ? (height - 1) * in_stride_h : 0) +
 | 
			
		||||
      (width ? (width - 1) * in_stride_w : 0) +
 | 
			
		||||
      (channels? (channels - 1) * in_stride_c : 0);
 | 
			
		||||
 | 
			
		||||
  int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch;
 | 
			
		||||
 | 
			
		||||
  if (max_input_offset > int_max) return false;
 | 
			
		||||
 | 
			
		||||
  int64_t out_batch_stride = pooled_height * pooled_width * channels;
 | 
			
		||||
  if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false;
 | 
			
		||||
 | 
			
		||||
  if (height * width > int_max) return false;
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) {
 | 
			
		||||
  return min((size + pad) / stride + 1, pooled_size);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// kernels borrowed from Caffe
 | 
			
		||||
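The can_use_int32_nhwc helper in this hunk decides whether 32-bit indexing is safe by bounding the largest input and output offsets the NHWC pooling kernel can form. A Python restatement of that bound check, kept deliberately close to the C++ above; parameter names mirror the helper and it is an illustration, not an API:

```python
# Sketch: int32 indexing is only safe if the largest offsets fit a signed 32-bit integer.
INT32_MAX = 2**31 - 1

def can_use_int32_nhwc(nbatch, channels, height, width,
                       pooled_height, pooled_width,
                       in_stride_n, in_stride_c, in_stride_h, in_stride_w):
    # Largest offset reachable within one batch of the (possibly strided) input.
    max_intra_batch = ((height - 1) * in_stride_h if height else 0) \
                    + ((width - 1) * in_stride_w if width else 0) \
                    + ((channels - 1) * in_stride_c if channels else 0)
    max_input_offset = ((nbatch - 1) * in_stride_n if nbatch else 0) + max_intra_batch
    if max_input_offset > INT32_MAX:
        return False
    # Output is contiguous NHWC, so the batch stride is pooled_h * pooled_w * channels.
    out_batch_stride = pooled_height * pooled_width * channels
    if ((nbatch - 1) * out_batch_stride if nbatch else 0) > INT32_MAX:
        return False
    return height * width <= INT32_MAX
```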
@ -114,25 +85,21 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, typename index_t>
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
 | 
			
		||||
__global__ void max_pool_forward_nhwc(
 | 
			
		||||
    const scalar_t* bottom_data,
 | 
			
		||||
    const int nbatch,
 | 
			
		||||
    const index_t channels, const index_t height, const index_t width,
 | 
			
		||||
    const index_t pooled_height, const index_t pooled_width,
 | 
			
		||||
    const int kernel_h, const int kernel_w, const int stride_h,
 | 
			
		||||
    const int stride_w, const int pad_h, const int pad_w,
 | 
			
		||||
    const int dilation_h, const int dilation_w,
 | 
			
		||||
    const index_t in_stride_n, const index_t in_stride_c,
 | 
			
		||||
    const index_t in_stride_h, const index_t in_stride_w,
 | 
			
		||||
    const int kernel_stride_C, const int kernel_size_C,
 | 
			
		||||
    scalar_t* top_data, int64_t* top_mask) {
 | 
			
		||||
 | 
			
		||||
  extern __shared__ unsigned char smem_raw[];
 | 
			
		||||
  index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw);
 | 
			
		||||
  scalar_t *out_cached = reinterpret_cast<scalar_t*>(
 | 
			
		||||
      out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z);
 | 
			
		||||
__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch,
 | 
			
		||||
                                   const int64_t channels, const int64_t height,
 | 
			
		||||
                                   const int64_t width, const int pooled_height, const int pooled_width,
 | 
			
		||||
                                   const int kernel_h, const int kernel_w, const int stride_h,
 | 
			
		||||
                                   const int stride_w, const int pad_h, const int pad_w,
 | 
			
		||||
                                   const int dilation_h, const int dilation_w,
 | 
			
		||||
                                   const int in_stride_n, const int in_stride_c,
 | 
			
		||||
                                   const int in_stride_h, const int in_stride_w,
 | 
			
		||||
                                   const int kernel_stride_C, const int kernel_size_C,
 | 
			
		||||
                                   scalar_t* top_data, int64_t* top_mask) {
 | 
			
		||||
  extern __shared__ int smem[];
 | 
			
		||||
  int *out_mask_cached = smem;
 | 
			
		||||
  scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]);
 | 
			
		||||
 | 
			
		||||
  // flattening cta for pre-computation & smem initialization;
 | 
			
		||||
  int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
 | 
			
		||||
@ -151,26 +118,26 @@ __global__ void max_pool_forward_nhwc(
 | 
			
		||||
  int channel_id = blockIdx.x / nbatch;
 | 
			
		||||
  int channel_offset = threadIdx.x + channel_id * blockDim.x;
 | 
			
		||||
 | 
			
		||||
  top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
 | 
			
		||||
  top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
 | 
			
		||||
  bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n;
 | 
			
		||||
  top_data = top_data + batch_id * pooled_height * pooled_width * channels;
 | 
			
		||||
  top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
 | 
			
		||||
  bottom_data = bottom_data + batch_id * in_stride_n;
 | 
			
		||||
 | 
			
		||||
  out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
 | 
			
		||||
  out_mask_cached  += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
 | 
			
		||||
  out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
 | 
			
		||||
  out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
 | 
			
		||||
 | 
			
		||||
  int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z;
 | 
			
		||||
  int oW = (static_cast<int>(pooled_width)  + gridDim.y - 1) / gridDim.y;
 | 
			
		||||
  int oH = (pooled_height + gridDim.z-1) / gridDim.z;
 | 
			
		||||
  int oW = (pooled_width + gridDim.y-1) / gridDim.y;
 | 
			
		||||
  int ostartH = threadIdx.z + blockIdx.z*oH;
 | 
			
		||||
  int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height));
 | 
			
		||||
  int oendH = ::min(ostartH+oH, pooled_height);
 | 
			
		||||
  int ostartW = threadIdx.y + blockIdx.y*oW;
 | 
			
		||||
  int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width));
 | 
			
		||||
  int oendW = ::min(ostartW+oW, pooled_width);
 | 
			
		||||
 | 
			
		||||
  for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
 | 
			
		||||
    index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h;
 | 
			
		||||
    index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height);
 | 
			
		||||
    int hstart = oh * stride_h - pad_h;
 | 
			
		||||
    int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
 | 
			
		||||
    for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
 | 
			
		||||
      index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w;
 | 
			
		||||
      index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width);
 | 
			
		||||
      int wstart = ow * stride_w - pad_w;
 | 
			
		||||
      int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
 | 
			
		||||
      while(hstart < 0)
 | 
			
		||||
        hstart += dilation_h;
 | 
			
		||||
      while(wstart < 0)
 | 
			
		||||
@ -218,12 +185,12 @@ __global__ void max_pool_forward_nhwc(
 | 
			
		||||
      // Else do it Non-Prefetch...
 | 
			
		||||
      else
 | 
			
		||||
#endif
 | 
			
		||||
      for (index_t ih = hstart; ih < hend; ih += dilation_h) {
 | 
			
		||||
        for (index_t iw = wstart; iw < wend; iw += dilation_w) {
 | 
			
		||||
      for (int ih = hstart; ih < hend; ih += dilation_h) {
 | 
			
		||||
        for (int iw = wstart; iw < wend; iw += dilation_w) {
 | 
			
		||||
          int cached_index = threadIdx.x;
 | 
			
		||||
          const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
 | 
			
		||||
          for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
 | 
			
		||||
            scalar_t val = ptr_input[c * in_stride_c];
 | 
			
		||||
          for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
 | 
			
		||||
            scalar_t val = ptr_input[c*in_stride_c];
 | 
			
		||||
            if ((val > out_cached[cached_index]) || at::_isnan(val)) {
 | 
			
		||||
              out_cached[cached_index] = val;
 | 
			
		||||
              out_mask_cached[cached_index] = ih * width + iw;
 | 
			
		||||
@ -233,15 +200,15 @@ __global__ void max_pool_forward_nhwc(
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
 | 
			
		||||
      int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
 | 
			
		||||
      scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels;
 | 
			
		||||
      int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels;
 | 
			
		||||
 | 
			
		||||
      int cached_index = threadIdx.x;
 | 
			
		||||
      for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
 | 
			
		||||
      for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
 | 
			
		||||
        ptr_output_data[c] = out_cached[cached_index];
 | 
			
		||||
        ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]);
 | 
			
		||||
        ptr_output_mask[c] = out_mask_cached[cached_index];
 | 
			
		||||
        out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
 | 
			
		||||
        out_mask_cached[cached_index] = index_t(0);
 | 
			
		||||
        out_mask_cached[cached_index] = 0;
 | 
			
		||||
        cached_index += blockDim.x;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
@ -495,11 +462,6 @@ const Tensor& indices) {
 | 
			
		||||
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
 | 
			
		||||
          const dim3 block(block_x, block_y, block_z);
 | 
			
		||||
 | 
			
		||||
          bool use_int32 = can_use_int32_nhwc(
 | 
			
		||||
              nbatch, nInputPlane, inputHeight, inputWidth,
 | 
			
		||||
              outputHeight, outputWidth,
 | 
			
		||||
              in_stride_n, in_stride_c, in_stride_h, in_stride_w);
 | 
			
		||||
 | 
			
		||||
          int kernel_stride_C = ceil_div(
 | 
			
		||||
              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
 | 
			
		||||
          int kernel_size_C = ceil_div(
 | 
			
		||||
@ -514,41 +476,18 @@ const Tensor& indices) {
 | 
			
		||||
              ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD));
 | 
			
		||||
          const dim3 grid(grid_x, grid_y, grid_z);
 | 
			
		||||
 | 
			
		||||
          size_t shmem_size;
 | 
			
		||||
          size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z;
 | 
			
		||||
          size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
 | 
			
		||||
          AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
 | 
			
		||||
 | 
			
		||||
          if (use_int32) {
 | 
			
		||||
            shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t));
 | 
			
		||||
            TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
 | 
			
		||||
                        "shared memory too small");
 | 
			
		||||
            max_pool_forward_nhwc<scalar_t, int32_t>
 | 
			
		||||
              <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
 | 
			
		||||
                input_data, static_cast<int>(nbatch),
 | 
			
		||||
                static_cast<int32_t>(nInputPlane),
 | 
			
		||||
                static_cast<int32_t>(inputHeight),
 | 
			
		||||
                static_cast<int32_t>(inputWidth),
 | 
			
		||||
                static_cast<int32_t>(outputHeight),
 | 
			
		||||
                static_cast<int32_t>(outputWidth),
 | 
			
		||||
                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
 | 
			
		||||
                static_cast<int32_t>(in_stride_n),
 | 
			
		||||
                static_cast<int32_t>(in_stride_c),
 | 
			
		||||
                static_cast<int32_t>(in_stride_h),
 | 
			
		||||
                static_cast<int32_t>(in_stride_w),
 | 
			
		||||
                kernel_stride_C, kernel_size_C,
 | 
			
		||||
                output_data, indices_data);
 | 
			
		||||
          } else {
 | 
			
		||||
            shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t));
 | 
			
		||||
            TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
 | 
			
		||||
                        "shared memory too small");
 | 
			
		||||
            max_pool_forward_nhwc<scalar_t, int64_t>
 | 
			
		||||
              <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
 | 
			
		||||
                input_data, static_cast<int>(nbatch),
 | 
			
		||||
                nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
 | 
			
		||||
                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
 | 
			
		||||
                in_stride_n, in_stride_c, in_stride_h, in_stride_w,
 | 
			
		||||
                kernel_stride_C, kernel_size_C,
 | 
			
		||||
                output_data, indices_data);
 | 
			
		||||
          }
 | 
			
		||||
          max_pool_forward_nhwc<scalar_t>
 | 
			
		||||
          <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
 | 
			
		||||
              input_data, nbatch,
 | 
			
		||||
                  nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
 | 
			
		||||
                  kH, kW, dH, dW, padH, padW, dilationH, dilationW,
 | 
			
		||||
                  in_stride_n, in_stride_c,
 | 
			
		||||
                  in_stride_h, in_stride_w,
 | 
			
		||||
                  kernel_stride_C, kernel_size_C,
 | 
			
		||||
                  output_data, indices_data);
 | 
			
		||||
          C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@ -655,14 +655,8 @@ struct ReduceOp {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
    // Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
 | 
			
		||||
    // matching Triton, etc.
 | 
			
		||||
    // todo for AMD
 | 
			
		||||
    #ifdef USE_ROCM
 | 
			
		||||
 | 
			
		||||
    for (int offset = 1; offset < dim_x; offset <<= 1) {
 | 
			
		||||
    #else
 | 
			
		||||
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
 | 
			
		||||
    #endif
 | 
			
		||||
      #pragma unroll
 | 
			
		||||
      for (int i = 0; i < output_vec_size; i++) {
 | 
			
		||||
        arg_t other = ops.warp_shfl_down(value[i], offset);
 | 
			
		||||
 | 
			
		||||
@ -77,8 +77,8 @@ struct nansum_functor_complex {
 | 
			
		||||
#if AT_USE_JITERATOR()
 | 
			
		||||
  void operator()(TensorIterator& iter) {
 | 
			
		||||
    std::string func = jiterator_stringify(
 | 
			
		||||
        arg_t combine(arg_t a, arg_t b) {
 | 
			
		||||
          return a + (std::isnan(b) ? arg_t{0.} : b);
 | 
			
		||||
        arg_t combine(arg_t a, scalar_t b) {
 | 
			
		||||
          return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
 | 
			
		||||
        }
 | 
			
		||||
    );
 | 
			
		||||
    jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(
 | 
			
		||||
 | 
			
		||||
@ -464,7 +464,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
    }
#endif
    int32_t trailingSize;
    int nDimsLocal = nDims;
    TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
    if (isInOutAligned) {
      // in this case we can and should flatten the tensors after the cat dim
@ -478,7 +477,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
      // and divide all strides except last by elems_per_vec (last stride is 1 always)
      // for input, we will fix up the sizes and strides in the kernel directly
      kernelOutputParam = outputParam;
      nDimsLocal = dimension + 1;
      nDims = dimension + 1;
      constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
      auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
      kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
@ -495,7 +494,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
      case 0:
        break;
      case 1:
        cat_dim = nDimsLocal - cat_dim;
        cat_dim = nDims - cat_dim;
        break;
      default:
        cat_dim--;
@ -526,7 +525,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
              data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
    }\
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    switch (nDimsLocal) {
    switch (nDims) {
      case 1:
        HANDLE_CASE(1);
        break;

@ -21,15 +21,9 @@ namespace {
struct offset_t {
  int stride;
  int begin;
  __device__ int operator[](int i) const {
  __device__ int operator[](int i) {
    return stride * (begin + i);
  }
#if CCCL_VERSION >= 3001000
  __device__ offset_t& operator+=(int i) {
    begin += i;
    return *this;
  }
#endif
};
// Segmented sort by full sort algorithm:.
// Say we are sorting a (2, 3) tensor. We have in flattened form:

@ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
  }
}

#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
    int input_pos,
    accscalar_t scale,
    int output_size,
    bool align_corners,
    int& min_output,
    int& max_output) {
  accscalar_t lo, hi;
  if (align_corners) {
      lo = static_cast<accscalar_t>(input_pos - 1) / scale;
      hi = static_cast<accscalar_t>(input_pos + 1) / scale;
  } else {
      lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
      hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
  }
  min_output = max(0, static_cast<int>(ceil(lo)));
  max_output = min(output_size - 1, static_cast<int>(floor(hi)));
}
#endif

// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame(
    const bool align_corners,
    scalar_t* __restrict__ idata,
    const scalar_t* __restrict__ odata) {
  // In C++, integer multiplication, like in standard arithmetic, is generally commutative.
  const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
       index += blockDim.x * gridDim.x) {
    // Decode input pixel coordinates
    size_t index_temp = index;
    const int w1 = index_temp % width1;
    index_temp /= width1;
    const int h1 = index_temp % height1;
    const size_t nc_idx = index_temp / height1;

    accscalar_t grad_sum = 0;

    // Find range of output pixels that could interpolate from this input pixel
    int h2_min, h2_max, w2_min, w2_max;
    compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
    compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);

    // Iterate over potential output pixels
    for (int h2 = h2_min; h2 <= h2_max; h2++) {
      for (int w2 = w2_min; w2 <= w2_max; w2++) {
        // Compute source coordinates for this output pixel
        const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
            rheight, h2, align_corners, /*cubic=*/false);
        const int h1_base = (int)h1r;
        const int h1p = (h1_base < height1 - 1) ? 1 : 0;
        const accscalar_t h1lambda = h1r - h1_base;
        const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;

        const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
            rwidth, w2, align_corners, /*cubic=*/false);
        const int w1_base = (int)w1r;
        const int w1p = (w1_base < width1 - 1) ? 1 : 0;
        const accscalar_t w1lambda = w1r - w1_base;
        const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;

        // Check if our input pixel participates in this interpolation and accumulate all weights
        // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
        // to the same pixel, so we need to accumulate weights from all matching positions
        accscalar_t weight = 0;

        // Check all four interpolation positions and accumulate weights
        if (h1 == h1_base && w1 == w1_base) {
          weight += h0lambda * w0lambda;  // top-left
        }
        if (h1 == h1_base && w1 == w1_base + w1p) {
          weight += h0lambda * w1lambda;  // top-right (may be same as top-left if w1p=0)
        }
        if (h1 == h1_base + h1p && w1 == w1_base) {
          weight += h1lambda * w0lambda;  // bottom-left (may be same as top-left if h1p=0)
        }
        if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
          weight += h1lambda * w1lambda;  // bottom-right (may collapse to other positions)
        }

        if (weight > 0) {
          const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
          grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
        }
      }
    }

    // Write accumulated gradient (no atomics needed)
    idata[index] = static_cast<scalar_t>(grad_sum);
  }
#else
  const size_t o_numel = nc * width2 * height2;
  const size_t i_numel = nc * width1 * height1;
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
       index += blockDim.x * gridDim.x) {
    size_t index_temp = index;
@ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame(
        static_cast<scalar_t>(h1lambda * w1lambda * d2val),
        true);
  }
#endif
}

template <typename scalar_t, typename accscalar_t>
@ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template(
  // threads are not covering the whole input tensor.
  grad_input.zero_();

  const size_t num_kernels = nbatch * channels * output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
    return;
  }

#ifdef USE_ROCM
  constexpr bool use_input = true;
#else
  constexpr bool use_input = false;
#endif

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      const size_t num_kernels = nbatch * channels * output_height * output_width;

      upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
              input_height,
@ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);

      upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
             num_threads,

@ -662,7 +662,7 @@ void svd_cusolver(const Tensor& A,
  const auto n = A.size(-1);
  const auto k = std::min(m, n);

  static constexpr const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
  static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";

  // The default heuristic is to use gesvdj driver
#ifdef USE_ROCM

@ -466,11 +466,7 @@ struct ReduceJitOp {

    __syncthreads();

    #ifdef USE_ROCM
    for (int offset = 1; offset < dim_x; offset <<= 1) {
    #else
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
    #endif
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = reducer::warp_shfl_down(value[i], offset);

@ -3,9 +3,6 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/accumulate.h>
#include <c10/core/SymBool.h>
#include <c10/util/StringUtil.h>


namespace at::native {

@ -22,30 +19,28 @@ C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint(
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
      "containing at least one element, but got normalized_shape = ",
      normalized_shape);
  if (weight.defined()) {
    TORCH_SYM_CHECK(
        sym_equals(weight.sym_sizes(), normalized_shape),
        "Expected weight to be of same shape as normalized_shape, but got ",
        "weight of shape ",
        weight.sym_sizes(),
        " and normalized_shape = ",
        normalized_shape);
  }
  TORCH_CHECK(
      !weight.defined() || weight.sym_sizes().equals(normalized_shape),
      "Expected weight to be of same shape as normalized_shape, but got ",
      "weight of shape ",
      weight.sym_sizes(),
      " and normalized_shape = ",
      normalized_shape);

  const auto input_ndim = input.dim();
  const auto input_shape = input.sym_sizes();
  TORCH_CHECK_VALUE(
      input_ndim >= normalized_ndim,
      "Input tensor must have at least ", normalized_ndim, " dimensions, but got ", input_ndim);

  auto expect_input_shape_msg = c10::str(
      "Given normalized_shape=", normalized_shape,
      ", expected input with shape [*", c10::Join(", ", normalized_shape),
      "], but got input of size", input_shape);

  TORCH_SYM_CHECK(
      sym_equals(input_shape.slice(input_ndim - normalized_ndim), normalized_shape),
      expect_input_shape_msg);
  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    TORCH_CHECK(false, ss.str());
  }
}

C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(

@ -160,12 +160,8 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
}

static bool mkldnn_conv_enabled_fpmath_mode_tf32(){
#if defined(__x86_64__) || defined(_M_X64)
    return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
        cpuinfo_has_x86_amx_fp16();
#else
    return false;   //TF32 not supported on power system
#endif
  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
      cpuinfo_has_x86_amx_fp16();
}

static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {

@ -74,12 +74,8 @@ static bool use_mkldnn_bf32_linear() {
}

static bool use_mkldnn_tf32_linear() {
#if defined(__x86_64__) || defined(_M_X64)
    return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
      cpuinfo_has_x86_amx_fp16();
#else
  return false;  // TF32 not supported on power system
#endif
}

Tensor mkldnn_linear(

@ -114,13 +114,8 @@ static bool use_mkldnn_bf32_matmul() {
  return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16;
}


static bool use_mkldnn_tf32_matmul() {
#if defined(__x86_64__) || defined(_M_X64)
    return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
#else
    return false;  // TF32 not supported on power system
#endif
  return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
}

// returns an ideep::tensor

@ -99,9 +99,6 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape);
MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);

// Determines whether a tensor is too large to use MPSGraph
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI = true);

static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

@ -439,22 +439,6 @@ static void check_mps_shape(MPSShape* shape) {
  }
}

bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI) {
  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  if ((!tensor.is_contiguous() || tensor.storage_offset()) && useMPSStridedAPI && is_macOS_15_0_or_newer) {
    auto storage_numel = tensor.storage().nbytes() / tensor.element_size() - tensor.storage_offset();
    if (storage_numel > std::numeric_limits<int32_t>::max()) {
      return true;
    }
  }
  for (auto size : tensor.sizes()) {
    if (size > std::numeric_limits<int32_t>::max()) {
      return true;
    }
  }
  return false;
}

MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) {
  id<MTLBuffer> srcBuf = getMTLBufferStorage(t);

@ -712,7 +696,7 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) {
  } else if (scalar.isBoolean()) {
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool));
  } else if (scalar.isComplex()) {
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexFloat));
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexDouble));
  } else {
    TORCH_INTERNAL_ASSERT(scalar.isIntegral(false));
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong));

@ -249,7 +249,7 @@ kernel void embedding_bag(

template <EmbeddingBagMode M, typename T>
struct MaybeDivBagSize {
  inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> /*bag_size*/) {
  inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
    return val;
  }
};

@ -1,18 +0,0 @@
#pragma once
#include <c10/metal/common.h>

template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeSharedParams {
  int32_t ndim;
  int32_t cat_dim;
  ::c10::metal::array<idx_type_t, N> output_strides;
  ::c10::metal::array<idx_type_t, N> output_sizes;
};

template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeInputParams {
  idx_type_t cat_dim_offset;
  idx_type_t input_element_offset;
  ::c10::metal::array<idx_type_t, N> input_strides;
  ::c10::metal::array<idx_type_t, N> input_sizes;
};
@ -1,82 +0,0 @@
#include <ATen/native/mps/kernels/Shape.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

template <typename T_in, typename T_out>
kernel void cat_large(
    constant T_in* input [[buffer(0)]],
    device T_out* output [[buffer(1)]],
    constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
    constant CatLargeInputParams<>& input_params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]) {
  auto ndim = shared_params.ndim;
  auto cat_dim = shared_params.cat_dim;
  constant auto& output_strides = shared_params.output_strides;
  constant auto& output_sizes = shared_params.output_sizes;

  auto cat_dim_offset = input_params.cat_dim_offset;
  auto input_element_offset = input_params.input_element_offset;
  constant auto& input_strides = input_params.input_strides;
  constant auto& input_sizes = input_params.input_sizes;

  auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
  int64_t input_offset = 0;
  int64_t output_offset = 0;

  for (auto dim = ndim - 1; dim >= 0; dim--) {
    auto dim_size = input_sizes[dim];
    auto input_dim_idx = input_element_idx % dim_size;
    auto output_dim_idx =
        input_dim_idx + ((dim == cat_dim) ? cat_dim_offset : 0);

    input_offset += input_strides[dim] * input_dim_idx;
    output_offset += output_strides[dim] * output_dim_idx;

    input_element_idx = input_element_idx / dim_size;
  }

  output[output_offset] = static_cast<T_out>(input[input_offset]);
}

#define REGISTER_CAT_LARGE_OP(T_in, T_out)                           \
  template [[host_name("cat_large_" #T_in "_" #T_out)]]              \
  kernel void cat_large<T_in, T_out>(                                \
      constant T_in * input [[buffer(0)]],                           \
      device T_out * output [[buffer(1)]],                           \
      constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
      constant CatLargeInputParams<> & input_params [[buffer(3)]],   \
      uint tid [[thread_position_in_grid]]);

#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
  REGISTER_CAT_LARGE_OP(float, T_out);               \
  REGISTER_CAT_LARGE_OP(half, T_out);                \
  REGISTER_CAT_LARGE_OP(bfloat, T_out);              \
  REGISTER_CAT_LARGE_OP(int, T_out);                 \
  REGISTER_CAT_LARGE_OP(uint, T_out);                \
  REGISTER_CAT_LARGE_OP(long, T_out);                \
  REGISTER_CAT_LARGE_OP(ulong, T_out);               \
  REGISTER_CAT_LARGE_OP(short, T_out);               \
  REGISTER_CAT_LARGE_OP(ushort, T_out);              \
  REGISTER_CAT_LARGE_OP(char, T_out);                \
  REGISTER_CAT_LARGE_OP(uchar, T_out);               \
  REGISTER_CAT_LARGE_OP(bool, T_out);

REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);

REGISTER_CAT_LARGE_OP(float2, float2);
REGISTER_CAT_LARGE_OP(half2, half2);
@ -512,7 +512,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
}

static MPSGraphTensor* normcdf(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
  // (1.0f + erf(x*SQRT1_2)) * 0.5f;
  // (1.0f + erf(x*SQRT1_2)) * 0.5f * x;
  auto dataType = [inputTensor dataType];
  const float SQRT1_2 = 0.707106781186547524400844362104849039f;
  MPSGraphTensor* sqrt1_2 = [mpsGraph constantWithScalar:SQRT1_2 shape:@[ @1 ] dataType:dataType];

@ -54,10 +54,6 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
  using namespace mps;
  using CachedGraph = MPSBinaryCachedGraph;

  if (self.numel() == 0 & other.numel() == 0) {
    return zeros({}, self.options());
  }

  dot_check(self, other);

  auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);

@ -907,8 +907,6 @@ Tensor& index_fill_mps_(Tensor& self, int64_t dim, const Tensor& index, const Te
  TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
              "index_fill_(): Expected dtype int32 or int64 for index");
  TORCH_CHECK(dim == 0 || dim < self.dim(), "index_fill_(): Indexing dim ", dim, " is out of bounds of tensor");
  TORCH_CHECK(self.is_complex() || !source.is_complex(),
              "index_fill_(): Converting complex Scalar to non-complex type is not supported");
  // MPS.scatter crashes if used with complex dtypes
  TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "index_fill_(): Complex types are yet not supported");


@ -2,13 +2,9 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/Shape.h>

#include <fmt/format.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -20,13 +16,6 @@
#endif

namespace at::native {

#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Shape_metallib.h>
#endif

namespace mps {

// Produces a shape with the `dim` dimension set to 0.
@ -68,70 +57,6 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
                ")");
  }
}

// This implementation of cat is used only if one of the inputs or the output is
// too large to use MPSGraph.
// NOTE: `output` is expected to already have the correct size.
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
  CatLargeSharedParams shared_params;

  shared_params.ndim = output.dim();
  shared_params.cat_dim = dimension;

  for (const auto dim : c10::irange(output.dim())) {
    shared_params.output_strides[dim] = output.stride(dim);
    shared_params.output_sizes[dim] = output.size(dim);
  }

  int64_t cat_dim_offset = 0;
  size_t input_idx = 0;
  MPSStream* stream = getCurrentMPSStream();

  // Launch a separate kernels for each input. This will produce some overhead,
  // but that should be relatively minimal since at least one of the inputs is
  // very large. In order to launch only one kernel to process all inputs, we
  // would have to copy all the input tensor data into a packed buffer, which
  // would not be ideal.
  for (const Tensor& input : inputs) {
    if (input.numel() == 0) {
      continue;
    }

    // Metal can only launch up to MAX_INT threads at one time. If the input has
    // more than that number of elements, launch multiple kernels with different
    // offsets into the data.
    const int64_t max_num_threads = static_cast<int64_t>(std::numeric_limits<int32_t>::max());

    for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
      auto num_threads = std::min(max_num_threads, numel_remaining);
      CatLargeInputParams input_params;

      input_params.cat_dim_offset = cat_dim_offset;
      input_params.input_element_offset = input.numel() - numel_remaining;

      for (const auto dim : c10::irange(input.dim())) {
        input_params.input_strides[dim] = input.stride(dim);
        input_params.input_sizes[dim] = input.size(dim);
      }

      dispatch_sync_with_rethrow(stream->queue(), ^() {
        @autoreleasepool {
          id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
          auto pipeline_state = lib.getPipelineStateForFunc(
              fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
          getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
          [computeEncoder setComputePipelineState:pipeline_state];
          mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
          mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
          getMPSProfiler().endProfileKernel(pipeline_state);
        }
      });
    }

    cat_dim_offset += input.size(dimension);
    input_idx++;
  }
}
} // namespace mps

// topk

@ -306,11 +231,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
  // Compute size of the result in the cat dimension
  int64_t cat_dim_size = 0;
  idx = 0;
  bool has_large_tensor = false;
  for (const Tensor& tensor : materialized_inputs) {
    if (isTooLargeForMPSGraph(tensor)) {
      has_large_tensor |= true;
    }
    if (!should_skip(tensor)) {
      // TODO: Factor out `check_shape_except_dim`
      check_shape_except_dim(notSkippedTensor, tensor, dimension, idx);
@ -328,12 +249,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
    return;
  }

  has_large_tensor |= isTooLargeForMPSGraph(out);

  if (has_large_tensor) {
    return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
  }

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    std::vector<MPSGraphTensor*> inputTensors_;

@ -4545,7 +4545,6 @@
- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
  dispatch:
    CPU, CUDA: _cdist_forward
    MTIA: _cdist_forward_mtia
    MPS: _cdist_forward_mps
  autogen: _cdist_forward.out
  tags: core
@ -7183,12 +7182,6 @@
    CUDA: _scaled_grouped_mm_cuda
  tags: needs_exact_strides

- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
  variants: function
  dispatch:
    CUDA: _scaled_grouped_mm_cuda_v2
  tags: needs_exact_strides

- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
  variants: function
  dispatch:
@ -7384,7 +7377,7 @@
- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
  variants: method
  dispatch:
    SparseCPU, SparseCUDA, SparseMPS: sparse_mask
    SparseCPU, SparseCUDA: sparse_mask
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed
  autogen: sparse_mask.out


@ -178,30 +178,24 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_b
          0 & \text{ else }
        \end{cases}
  */
  bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
  at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
  at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
  at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
  at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;
  auto zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max);

  auto zero_point_rounded = _get_rounded_zero_point(zero_point_, quant_min, quant_max);
  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);

  TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
  TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);

  TORCH_CHECK(X_.sizes() == dY_.sizes(), "`X` and `dY` are not the same size");
  TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
  TORCH_CHECK(
      quant_min <= 0 && quant_max >= 0,
      "Expecting `quant_min` <= 0 and `quant_max` >= 0");
  TORCH_CHECK(scale_.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point_.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(
      scale_.numel() == zero_point_.numel(),
      scale.numel() == zero_point.numel(),
      "scale and zero-point need to have the same dimensions");
  TORCH_CHECK(
      scale_.numel() == X_.size(axis),
      scale.numel() == X.size(axis),
      "dimensions of scale and zero-point are not consistent with input tensor")

  TORCH_CHECK(
@ -210,42 +204,42 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_b
      "`zero_point` must be between `quant_min` and `quant_max`.");

  TORCH_CHECK(
      axis >= 0 && axis < X_.dim(),
      axis >= 0 && axis < X.dim(),
      "`axis` must be between 0 and number of dimensions of input");

  if (X_.numel() <= 0) {
  if (X.numel() <= 0) {
    return std::make_tuple(X, scale, zero_point);
  }

  auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
  auto numDimensions = X_.ndimension();
  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto numDimensions = X.ndimension();

  // Create an axis mask for vectorizing and reshaping the scale and zero point tensors
  // into the same shapes as X along the channel axis.
  c10::DimVector axis_mask(numDimensions);
  for (const auto i : c10::irange(numDimensions)) {
    axis_mask[i] = (i == axis) ? X_.size(axis) : 1;
    axis_mask[i] = (i == axis) ? X.size(axis) : 1;
  }
  auto X_shape = X_.sizes();
  auto scale_vectorized = scale_.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
  auto X_shape = X.sizes();
  auto scale_vectorized = scale.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
  auto zero_point_vectorized = zero_point_rounded.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);

  auto iter = TensorIteratorConfig()
    .add_output(dX)
    .add_output(dScale_vec)
    .add_output(dZeroPoint_vec)
    .add_input(X_)
    .add_input(dY_)
    .add_input(X)
    .add_input(dY)
    .add_input(scale_vectorized)
    .add_input(zero_point_vectorized)
    .build();

  fake_quant_grad_learnable_channel_stub(
    X_.device().type(), iter, quant_min, quant_max, grad_factor);
    X.device().type(), iter, quant_min, quant_max, grad_factor);

  auto numElements = X_.ndimension() - 1;
  auto numElements = X.ndimension() - 1;

  // Create a collection of axes that include all but the channel axis for
  // reduction when summing over the dScale and dZeroPoint tensors.

@ -184,23 +184,15 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
          0 & \text{ else }
        \end{cases}
  */

  bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);

  at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
  at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
  at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
  at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;

  float scale_val = scale_[0].item<float>();
  float scale_val = scale[0].item<float>();
  float inv_scale_val = 1.0f / scale_val;
  int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point_, quant_min, quant_max, false);
  int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false);

  TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
  TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X_.numel() == dY_.numel(), "`X` and `dY` are not the same size");
  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size");
  TORCH_CHECK(
      quant_min <= 0 && quant_max >= 0,
      "`quant_min` should be less than or \
@ -208,28 +200,28 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
  TORCH_CHECK(
      zero_point_val >= quant_min && zero_point_val <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");
  if (X_.numel() <= 0) {
  if (X.numel() <= 0) {
    return std::make_tuple(X, scale, zero_point);
  }

  auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);

  auto iter = TensorIteratorConfig()
    .add_output(dX)
    .add_output(dScale_vec)
    .add_output(dZeroPoint_vec)
    .add_input(X_)
    .add_input(dY_)
    .add_input(X)
    .add_input(dY)
    .build();

  fake_quant_grad_learnable_tensor_stub(
    X_.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);
    X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);

  // The total sums over the scale and zero point gradient vectors are what will be returned in the end.
  auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device());
  auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device());
  auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device());
  auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device());

  return std::make_tuple(dX, dScale, dZeroPoint);
}

@ -1,8 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -15,8 +13,6 @@
#include <ATen/ops/mul_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/argsort.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/copy_sparse_to_sparse.h>
#include <ATen/ops/mul.h>
@ -440,137 +436,4 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
  return out;
}

using OptTensor = std::optional<Tensor>;


static void sparse_mask_apply_out_mps_kernel(
    Tensor& result,
    const Tensor& src_in,
    const Tensor& mask_in,
    bool accumulate_matches,
    bool require_same_sizes,
    bool coalesce_mask) {
  TORCH_CHECK(src_in.is_sparse() && mask_in.is_sparse(),
              "sparse_mask: expected both inputs to be sparse COO");
  TORCH_CHECK(src_in.is_mps() && mask_in.is_mps(),
              "sparse_mask: expected tensors to be on MPS device");
  TORCH_CHECK(src_in.sparse_dim() == mask_in.sparse_dim(),
              "sparse_mask: sparse_dim mismatch: ", src_in.sparse_dim(), " vs ", mask_in.sparse_dim());
  if (require_same_sizes) {
    TORCH_CHECK(src_in.sizes().equals(mask_in.sizes()),
                "sparse_mask: sizes must match exactly (no broadcasting)");
  }
  auto src  = src_in.coalesce();
  auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;

  const int64_t src_nnz = src._nnz();
  const int64_t mask_nnz = mask._nnz();
  const int64_t sd = src.sparse_dim();
  result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());

  auto commonDtype = at::result_type(src, mask);
  TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
              "Can't convert result type ", commonDtype, " to output ", result.scalar_type());

  if (mask_nnz == 0) {
    alias_into_sparse(
        result,
        mask._indices().narrow(1, 0, 0),
        at::empty({0}, result.options().dtype(result.scalar_type())));
    result._coalesced_(mask.is_coalesced());
    return;
  }

  TORCH_CHECK(sd > 0 || (src_nnz <= 1 && mask_nnz <= 1),
              "sparse_mask: invalid sparse_dim or nnz");

  if (sd == 0) {
    auto out_indices = mask._indices().narrow(1, 0, 1);
    auto out_values = src_nnz
      ? src._values().narrow(0, 0, 1).to(commonDtype)
      : at::zeros({1}, at::device(result.device()).dtype(commonDtype));
    alias_into_sparse(result, out_indices, out_values);
    result._coalesced_(mask.is_coalesced());
    return;
  }

  if (src_nnz == 0) {
    auto out_indices = mask._indices().contiguous();
    auto src_values  = src._values().to(commonDtype);
    auto out_val_sizes = src_values.sizes().vec();
    out_val_sizes[0] = mask_nnz;
    auto out_values = at::zeros(out_val_sizes, src_values.options());
    alias_into_sparse(result, out_indices, out_values);
    result._coalesced_(mask.is_coalesced());
    return;
  }

  auto mask_indices = mask._indices().contiguous();
  auto src_indices = src._indices().contiguous();
  auto src_values = src._values().to(commonDtype).contiguous();

  auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
  auto src_keys  = flatten_indices(src_indices,  src.sizes().slice(0, sd)).contiguous();

  const bool A_is_src = (src_nnz <= mask_nnz);
  const int64_t lenA = A_is_src ? src_nnz  : mask_nnz;
  const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
  auto A_keys = A_is_src ? src_keys  : mask_keys;
  auto B_keys = A_is_src ? mask_keys : src_keys;

  const auto device = result.device();
  auto stream = getCurrentMPSStream();

  auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
  auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
  auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
      auto enc = stream->commandEncoder();
      [enc setComputePipelineState:pso];
      mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
                  static_cast<uint32_t>(lenB), A_is_src);
      mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
    }
  });

  const int64_t M = static_cast<int64_t>(counter.item<int32_t>());

  auto out_val_sizes = src_values.sizes().vec();
  out_val_sizes[0] = mask_nnz;
  auto out_values = at::zeros(out_val_sizes, src_values.options());

  if (M > 0) {
    auto src_match = outA_idx.narrow(0, 0, M);
    auto mask_match = outB_idx.narrow(0, 0, M);

    auto src_rows = src_values.index_select(0, src_match);
    if (accumulate_matches) {
      out_values.index_add_(0, mask_match, src_rows);
    } else {
      out_values.index_copy_(0, mask_match, src_rows);
    }
  }

  alias_into_sparse(result, mask_indices, out_values);
  result._coalesced_(mask.is_coalesced());
}

static void sparse_mask_intersection_out_mps_kernel(
    Tensor& result,
    const Tensor& lhs,
    const Tensor& rhs,
    const OptTensor& = std::nullopt) {
  sparse_mask_apply_out_mps_kernel(
      result,
      /*src_in=*/lhs,
      /*mask_in=*/rhs,
      /*accumulate_matches=*/false,
      /*require_same_sizes=*/false,
      /*coalesce_mask=*/false);
}

REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
} // namespace at::native
@ -3,9 +3,6 @@
using namespace metal;


template <typename T> struct MulAccum { using type = float; };
template <> struct MulAccum<float2> { using type = float2; };

template <typename T>
kernel void dense_sparse_mul_kernel(
    device const T* dense         [[buffer(0)]],
@ -32,9 +29,8 @@ kernel void dense_sparse_mul_kernel(
  ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
  ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;

  using accum_t = typename MulAccum<T>::type;
  const accum_t a = static_cast<accum_t>(values[val_idx]);
  const accum_t b = static_cast<accum_t>(dense[dense_idx]);
  const auto a = static_cast<float>(values[val_idx]);
  const auto b = static_cast<float>(dense[dense_idx]);
  out_values[val_idx] = static_cast<T>(a * b);
}

@ -134,8 +130,6 @@ kernel void fused_gather_mul_kernel(
INSTANTIATE_DENSE_SPARSE_MUL(float);
INSTANTIATE_DENSE_SPARSE_MUL(half);
INSTANTIATE_DENSE_SPARSE_MUL(bfloat);
INSTANTIATE_DENSE_SPARSE_MUL(long);
INSTANTIATE_DENSE_SPARSE_MUL(float2);

#define INSTANTIATE_FUSED_GATHER_MUL(DTYPE)                                  \
  template [[host_name("fused_gather_mul_kernel_" #DTYPE)]] kernel void      \

@ -80,19 +80,3 @@ TEST(XpuGeneratorTest, testMultithreadingGetSetCurrentSeed) {
  t2.join();
  EXPECT_EQ(gen1.current_seed(), initial_seed+3);
}

TEST(XpuGeneratorTest, testRNGForking) {
  // See Note [Acquire lock when using random generators]
  if (!at::xpu::is_available()) return;
  auto default_gen = at::xpu::detail::getDefaultXPUGenerator();
  auto current_gen = at::xpu::detail::createXPUGenerator();
  {
    std::lock_guard<std::mutex> lock(default_gen.mutex());
    current_gen = default_gen.clone(); // capture the current state of default generator
  }
  auto target_value = at::randn({1000}, at::kXPU);
  // Dramatically alter the internal state of the main generator
  auto x = at::randn({100000}, at::kXPU);
  auto forked_value = at::randn({1000}, current_gen, at::kXPU);
  ASSERT_EQ(target_value.sum().item<double>(), forked_value.sum().item<double>());
}

@ -1,45 +0,0 @@
#pragma once

namespace at {

struct PhiloxXpuState {
  PhiloxXpuState() = default;
  PhiloxXpuState(uint64_t seed, uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }
  // for graph capture
  PhiloxXpuState(
      int64_t* seed,
      int64_t* offset_extragraph,
      uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  union Payload {
    uint64_t val;
    int64_t* ptr;
  };

  Payload seed_{};
  Payload offset_{};
  uint32_t offset_intragraph_ = 0;
  bool captured_ = false;
};

namespace xpu::philox {
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxXpuState arg) {
  if (arg.captured_) {
    return std::make_tuple(
        static_cast<uint64_t>(*arg.seed_.ptr),
        static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
  } else {
    return std::make_tuple(arg.seed_.val, arg.offset_.val);
  }
}

} // namespace xpu::philox
} // namespace at
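A hedged sketch of how a caller would typically consume the state defined above: take the generator's lock, reserve a per-thread increment via philox_xpu_state, then resolve the union with unpack. The helper name and the increment parameter are illustrative, not part of the diff:

#include <ATen/xpu/XPUGeneratorImpl.h>
#include <cstdint>
#include <mutex>
#include <tuple>

// Illustrative helper: reserve `values_per_thread` Philox outputs from `gen`
// and turn the packed state into concrete (seed, offset) values.
inline std::tuple<uint64_t, uint64_t> reserve_and_unpack(
    at::XPUGeneratorImpl& gen, uint64_t values_per_thread) {
  at::PhiloxXpuState state;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex_);
    state = gen.philox_xpu_state(values_per_thread);
  }
  // Captured mode reads through the extragraph pointers; otherwise the plain
  // seed/offset values stored in the union are returned.
  return at::xpu::philox::unpack(state);
}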
@ -1,14 +1,9 @@
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/xpu/XPUGeneratorImpl.h>
#include <ATen/xpu/XPUGraphsUtils.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/CallOnce.h>
#include <c10/xpu/XPUFunctions.h>

constexpr uint64_t PHILOX_ROUND_SIZE = 4;

namespace at {
namespace xpu::detail {
namespace {
@ -63,82 +58,29 @@ Generator createXPUGenerator(DeviceIndex device) {

} // namespace xpu::detail

// Creates a clone of this XPU Generator State.
c10::intrusive_ptr<XPUGeneratorState> XPUGeneratorState::clone() {
  return make_intrusive<XPUGeneratorState>(
      seed_, philox_offset_per_thread_, offset_intragraph_);
}

// Function to increase the internal offset based on the specified increment.
void XPUGeneratorState::increase(uint64_t increment) {
  increment = ((increment + PHILOX_ROUND_SIZE - 1) / PHILOX_ROUND_SIZE) *
      PHILOX_ROUND_SIZE;
  if (at::xpu::currentStreamCaptureStatus() !=
      at::xpu::CaptureStatus::Executing) {
    TORCH_INTERNAL_ASSERT(
        capturing_,
        "Attempt to increase offset for a XPU generator not in capture mode.");
    TORCH_INTERNAL_ASSERT(
        offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
    TORCH_INTERNAL_ASSERT(
        offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
        "Increment causes overflow in the offset value.");
    offset_intragraph_ += increment;
  } else {
    TORCH_INTERNAL_ASSERT(
        !capturing_,
        "Offset increment outside graph capture encountered unexpectedly.");
    TORCH_INTERNAL_ASSERT(
        philox_offset_per_thread_ % 4 == 0,
        "RNG offset must be a multiple of 4.");
    philox_offset_per_thread_ += increment;
  }
}

XPUGeneratorImpl::XPUGeneratorImpl(DeviceIndex device_index)
    : GeneratorImpl{
          Device(DeviceType::XPU, device_index),
          DispatchKeySet(c10::DispatchKey::XPU)} {
  at::xpu::assertNotCapturing("Cannot construct a new XPUGeneratorImpl");
  state_ = make_intrusive<XPUGeneratorState>();
}

XPUGeneratorImpl::XPUGeneratorImpl(
    DeviceIndex device_index,
    intrusive_ptr<XPUGeneratorState> state)
    : GeneratorImpl{Device(DeviceType::XPU, device_index), DispatchKeySet(c10::DispatchKey::XPU)},
      state_(std::move(state)) {}
          DispatchKeySet(c10::DispatchKey::XPU)} {}

void XPUGeneratorImpl::set_current_seed(uint64_t seed) {
  if (C10_LIKELY(
          at::xpu::currentStreamCaptureStatus() ==
          at::xpu::CaptureStatus::Executing)) {
    state_->seed_ = seed;
    state_->philox_offset_per_thread_ = 0;
  } else {
    TORCH_CHECK(
        state_->seed_ == seed,
        "XPUGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.");
  }
  seed_ = seed;
  set_philox_offset_per_thread(0);
}

void XPUGeneratorImpl::set_offset(uint64_t offset) {
  at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::set_offset");
  set_philox_offset_per_thread(offset);
}

uint64_t XPUGeneratorImpl::get_offset() const {
  at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::get_offset");
  return state_->philox_offset_per_thread_;
  return philox_offset_per_thread_;
}

uint64_t XPUGeneratorImpl::current_seed() const {
  at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::current_seed");
  return state_->seed_;
  return seed_;
}

uint64_t XPUGeneratorImpl::seed() {
  at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::seed");
  auto random = c10::detail::getNonDeterministicRandom(true);
  this->set_current_seed(random);
  return random;
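The rounding at the top of XPUGeneratorState::increase (and the literal form ((increment + 3) / 4) * 4 in philox_engine_inputs further down) keeps the Philox offset aligned to whole rounds of four outputs. A standalone sketch of that arithmetic, with a hypothetical helper name:

#include <cstdint>

// Round an increment up to the next multiple of the Philox round size,
// mirroring ((increment + PHILOX_ROUND_SIZE - 1) / PHILOX_ROUND_SIZE) * PHILOX_ROUND_SIZE.
constexpr uint64_t round_up_to_philox_round(uint64_t increment) {
  constexpr uint64_t kRoundSize = 4;
  return ((increment + kRoundSize - 1) / kRoundSize) * kRoundSize;
}

static_assert(round_up_to_philox_round(10) == 12, "10 outputs still reserve 3 full rounds");
static_assert(round_up_to_philox_round(8) == 8, "multiples of 4 are unchanged");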
@ -168,65 +110,39 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
}

void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  at::xpu::assertNotCapturing(
      "Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(uint64_t);
  static const size_t total_size = seed_size + offset_size;

  at::detail::check_rng_state(new_state);

  bool no_philox_seed = false;
  auto new_state_size = new_state.numel();
  if (new_state_size == total_size - offset_size) {
    no_philox_seed = true;
  } else {
    TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
  }
  TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");

  uint64_t input_seed = 0;
  uint64_t input_seed;
  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
  memcpy(&input_seed, new_rng_state, seed_size);
  this->set_current_seed(input_seed);
  uint64_t philox_offset = 0;
  if (!no_philox_seed) {
    memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
  }
  uint64_t philox_offset;
  memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
  this->set_philox_offset_per_thread(philox_offset);
}

void XPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
  TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
  state_->philox_offset_per_thread_ = offset;
  philox_offset_per_thread_ = offset;
}

uint64_t XPUGeneratorImpl::philox_offset_per_thread() const {
  return state_->philox_offset_per_thread_;
}

PhiloxXpuState XPUGeneratorImpl::philox_xpu_state(uint64_t increment) {
  if (at::xpu::currentStreamCaptureStatus() !=
      at::xpu::CaptureStatus::Executing) {
    uint32_t offset = state_->offset_intragraph_;
    state_->increase(increment);
    return PhiloxXpuState(
        state_->seed_extragraph_.data_ptr<int64_t>(),
        state_->offset_extragraph_.data_ptr<int64_t>(),
        offset);
  } else {
    uint64_t offset = state_->philox_offset_per_thread_;
    state_->increase(increment);
    return PhiloxXpuState(state_->seed_, offset);
  }
  return philox_offset_per_thread_;
}

std::pair<uint64_t, uint64_t> XPUGeneratorImpl::philox_engine_inputs(
    uint64_t increment) {
  at::xpu::assertNotCapturing(
      "Refactor this op to use XPUGeneratorImpl::philox_xpu_state. Cannot call XPUGeneratorImpl::philox_engine_inputs");
  uint64_t offset = state_->philox_offset_per_thread_;
  state_->increase(increment);
  return std::make_pair(state_->seed_, offset);
  increment = ((increment + 3) / 4) * 4;
  TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
  uint64_t offset = this->philox_offset_per_thread_;
  this->philox_offset_per_thread_ += increment;
  return std::make_pair(this->seed_, offset);
}

DeviceType XPUGeneratorImpl::device_type() {
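set_state above expects the RNG state tensor to be a flat byte blob: an 8-byte seed followed by an 8-byte Philox offset (with the offset optionally absent in the more permissive branch). A hedged host-side sketch of packing such a blob; the helper name is hypothetical and not part of the diff:

#include <array>
#include <cstdint>
#include <cstring>

// Layout read back by XPUGeneratorImpl::set_state: bytes [0, 8) hold the seed
// and bytes [8, 16) hold the Philox offset, both copied in host byte order.
std::array<uint8_t, 16> pack_xpu_rng_state(uint64_t seed, uint64_t philox_offset) {
  std::array<uint8_t, 16> blob{};
  std::memcpy(blob.data(), &seed, sizeof(seed));
  std::memcpy(blob.data() + sizeof(seed), &philox_offset, sizeof(philox_offset));
  return blob;
}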
@ -238,8 +154,9 @@ std::shared_ptr<XPUGeneratorImpl> XPUGeneratorImpl::clone() const {
}

XPUGeneratorImpl* XPUGeneratorImpl::clone_impl() const {
  at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::clone_impl");
  auto gen = new XPUGeneratorImpl(this->device().index(), state_->clone());
  auto gen = new XPUGeneratorImpl(this->device().index());
  gen->set_current_seed(this->seed_);
  gen->set_philox_offset_per_thread(this->philox_offset_per_thread_);
  return gen;
}
@ -1,43 +1,12 @@
#pragma once

#include <ATen/core/Generator.h>
#include <ATen/core/TensorBase.h>
#include <ATen/xpu/PhiloxXpuState.h>
#include <unordered_set>

namespace at {

namespace xpu {
struct XPUGraph;
}

struct XPUGeneratorState : public c10::intrusive_ptr_target {
  uint64_t seed_;
  uint64_t philox_offset_per_thread_;
  uint32_t offset_intragraph_;
  bool capturing_{};
  at::TensorBase seed_extragraph_{};
  at::TensorBase offset_extragraph_{};

  XPUGeneratorState(
      uint64_t seed = default_rng_seed_val,
      uint64_t philox_offset_per_thread = 0,
      uint32_t offset_intragraph = 0)
      : seed_(seed),
        philox_offset_per_thread_(philox_offset_per_thread),
        offset_intragraph_(offset_intragraph) {}

  void increase(uint64_t increment);

  c10::intrusive_ptr<XPUGeneratorState> clone();
};

struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
  // Constructors
  XPUGeneratorImpl(DeviceIndex device_index = -1);
  XPUGeneratorImpl(
      DeviceIndex device_index,
      c10::intrusive_ptr<XPUGeneratorState> state_);
  ~XPUGeneratorImpl() override = default;

  // XPUGeneratorImpl methods
@ -49,18 +18,15 @@ struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
  uint64_t seed() override;
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;

  void set_philox_offset_per_thread(uint64_t offset);
  uint64_t philox_offset_per_thread() const;

  PhiloxXpuState philox_xpu_state(uint64_t increment);
  // will remove once all ops are refactored to use philox_xpu_state.
  std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
  static c10::DeviceType device_type();

 private:
  XPUGeneratorImpl* clone_impl() const override;
  c10::intrusive_ptr<XPUGeneratorState> state_;
  uint64_t seed_ = default_rng_seed_val;
  uint64_t philox_offset_per_thread_ = 0;
};

namespace xpu::detail {
@ -1,22 +0,0 @@
#pragma once

#include <c10/xpu/XPUGraphsC10Utils.h>

namespace at::xpu {

inline CaptureStatus currentStreamCaptureStatus() {
  return c10::xpu::currentStreamCaptureStatusMayInitCtx();
}

inline void assertNotCapturing(const std::string& attempt) {
  auto status = currentStreamCaptureStatus();
  TORCH_CHECK(
      status == CaptureStatus::Executing,
      attempt,
      " during XPU graph capture. If you need this call to be captured, "
      "please file an issue. "
      "Current xpuStreamCaptureStatus: ",
      status);
}

} // namespace at::xpu
@ -15,8 +15,6 @@ flaky_models = {
    "moondream",  # discovered in https://github.com/pytorch/pytorch/pull/159291
    # discovered in https://github.com/pytorch/pytorch/issues/161419. Its not flaky but really hard to repro, so skipping it
    "mobilenetv3_large_100",
    # https://github.com/pytorch/pytorch/issues/163670
    "vision_maskrcnn",
}
@ -63,10 +61,6 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
                "swsl_resnext101_32x16d",
                "torchrec_dlrm",
                "vgg16",
                "BERT_pytorch",
                "coat_lite_mini",
                "mobilenet_v3_large",
                "vision_maskrcnn",
                # LLM
                "meta-llama/Llama-3.2-1B",
                "google/gemma-2-2b",
@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -290,7 +290,7 @@ vgg16,pass,0
vision_maskrcnn,pass,21
vision_maskrcnn,pass,20

@ -206,7 +206,7 @@ vgg16,pass,6
vision_maskrcnn,pass,40
vision_maskrcnn,pass,39

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -130,7 +130,7 @@ maml_omniglot,pass,0
microbench_unbacked_tolist_sum,pass,0
microbench_unbacked_tolist_sum,pass,1

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -130,7 +130,7 @@ maml_omniglot,pass,0
microbench_unbacked_tolist_sum,pass,0
microbench_unbacked_tolist_sum,pass,1

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -254,7 +254,7 @@ vgg16,pass,0
vision_maskrcnn,fail_accuracy,30
vision_maskrcnn,fail_accuracy,29

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -290,7 +290,7 @@ vgg16,pass,0
vision_maskrcnn,pass,21
vision_maskrcnn,pass,20

@ -202,7 +202,7 @@ vgg16,pass,6
vision_maskrcnn,pass,40
vision_maskrcnn,pass,39

@ -114,7 +114,7 @@ maml_omniglot,pass,0
microbench_unbacked_tolist_sum,pass,0
microbench_unbacked_tolist_sum,pass,1

@ -114,7 +114,7 @@ maml_omniglot,pass,0
microbench_unbacked_tolist_sum,pass,0
microbench_unbacked_tolist_sum,pass,1

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -242,7 +242,7 @@ stable_diffusion_unet,pass_due_to_skip,0
torch_multimodal_clip,pass,0
torch_multimodal_clip,pass,3

@ -254,7 +254,7 @@ vgg16,pass,0
vision_maskrcnn,fail_accuracy,30
vision_maskrcnn,fail_accuracy,29
Some files were not shown because too many files have changed in this diff.