Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-27 00:54:52 +08:00)

Compare commits: main-enabl... and msaroufim-..., 1 commit

Commit: 4b1d8047c9
@@ -143,7 +143,7 @@ def sample_vllm_test_library():
"pytest -v -s compile/test_decorator.py",
],
},
"vllm_language_model_test_extended_generation_28_failure_test": {
"vllm_languagde_model_test_extended_generation_28_failure_test": {
"title": "Language Models Test (Extended Generation) 2.8 release failure",
"id": "vllm_languagde_model_test_extended_generation_28_failure_test",
"package_install": [
@@ -63,7 +63,7 @@ class VllmBuildParameters:
# DOCKERFILE_PATH: path to Dockerfile used when use_local_dockerfile is True"
use_local_dockerfile: bool = env_bool_field("USE_LOCAL_DOCKERFILE", True)
dockerfile_path: Path = env_path_field(
"DOCKERFILE_PATH", ".github/ci_configs/vllm/Dockerfile"
"DOCKERFILE_PATH", ".github/ci_configs/vllm/Dockerfile.tmp_vllm"
)

# the cleaning script to remove torch dependencies from pip
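For context, a minimal sketch of how environment-driven dataclass fields such as env_bool_field/env_path_field can work. This is an assumption-labelled illustration only; the real helpers in the CI tooling may differ.

import os
from dataclasses import dataclass, field
from pathlib import Path

def env_bool_field(name: str, default: bool):
    # Hypothetical helper: read a boolean flag from the environment, falling back to a default.
    return field(default_factory=lambda: os.environ.get(name, str(default)).lower() in ("1", "true", "yes"))

def env_path_field(name: str, default: str):
    # Hypothetical helper: read a filesystem path from the environment, falling back to a default.
    return field(default_factory=lambda: Path(os.environ.get(name, default)))

@dataclass
class VllmBuildParameters:
    use_local_dockerfile: bool = env_bool_field("USE_LOCAL_DOCKERFILE", True)
    dockerfile_path: Path = env_path_field("DOCKERFILE_PATH", ".github/ci_configs/vllm/Dockerfile.tmp_vllm")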
@@ -187,22 +187,19 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
export USE_CUFILE=0
else
DEPS_LIST+=(
"/usr/local/cuda/lib64/libnvToolsExt.so.1"
"/usr/local/cuda/lib64/libcublas.so.12"
"/usr/local/cuda/lib64/libcublasLt.so.12"
"/usr/local/cuda/lib64/libcudart.so.12"
"/usr/local/cuda/lib64/libnvrtc.so.12"
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
DEPS_SONAME+=(
"libnvToolsExt.so.1"
"libcublas.so.12"
"libcublasLt.so.12"
"libcudart.so.12"
"libnvrtc.so.12"
"libcupti.so.12")

if [[ $CUDA_VERSION != 12.9* ]]; then
DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
DEPS_SONAME+=("libnvToolsExt.so.1")
fi
fi
else
echo "Using nvidia libs from pypi."
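The hunk above stops bundling libnvToolsExt.so.1 unconditionally and only adds it when the CUDA version is not 12.9.x. A small, illustrative Python sketch of that packaging rule (logic taken from the shell code above; names are assumptions):

def cuda_deps(cuda_version: str) -> list[str]:
    # Always bundle the core CUDA 12 libraries...
    deps = [
        "libcublas.so.12", "libcublasLt.so.12", "libcudart.so.12",
        "libnvrtc.so.12", "libcupti.so.12",
    ]
    # ...but only add libnvToolsExt.so.1 when not building against CUDA 12.9.
    if not cuda_version.startswith("12.9"):
        deps.append("libnvToolsExt.so.1")
    return deps

assert "libnvToolsExt.so.1" in cuda_deps("12.8.1")
assert "libnvToolsExt.so.1" not in cuda_deps("12.9.1")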
1 .github/ISSUE_TEMPLATE/ci-sev.md vendored
@@ -8,7 +8,6 @@ assignees: ''
---

> NOTE: Remember to label this issue with "`ci: sev`"
> If you want autorevert to be disabled, keep the ci: disable-autorevert label

<!-- Add the `merge blocking` label to this PR to prevent PRs from being merged while this issue is open -->
4 .github/ISSUE_TEMPLATE/disable-autorevert.md vendored
@@ -1,7 +1,7 @@
---
name: "D❌\U0001F519 ISABLE AUTOREVERT"
name: DISABLE AUTOREVERT
about: Disables autorevert when open
title: "[DISABLE AUTOREVERT]"
title: "❌\U0001F519 [DISABLE AUTOREVERT]"
labels: 'ci: disable-autorevert'
assignees: ''
@@ -65,7 +65,7 @@ runs:
cd .ci/lumen_cli
python3 -m pip install -e .
)
MAX_JOBS="$(nproc --ignore=10)"
MAX_JOBS="$(nproc --ignore=6)"
export MAX_JOBS

# Split the comma-separated list and build each target
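nproc --ignore=N reports the number of available CPUs minus N (never dropping below 1), so lowering the ignore count from 10 to 6 gives the build more parallel jobs on the same runner. A rough Python equivalent:

import os

def max_jobs(ignore: int = 6) -> int:
    # Approximation of MAX_JOBS="$(nproc --ignore=6)": leave `ignore` CPUs free, keep at least one job.
    return max(1, (os.cpu_count() or 1) - ignore)

print(max_jobs())    # e.g. 26 on a 32-CPU runner
print(max_jobs(10))  # the previous setting would give 22 on the same machine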
2 .github/ci_commit_pins/audio.txt vendored
@@ -1 +1 @@
1b013f5b5a87a1882eb143c26d79d091150d6a37
87ff22e49ed0e92576c4935ccb8c143daac4a3cd
2 .github/ci_commit_pins/vision.txt vendored
@@ -1 +1 @@
faffd5cf673615583da6517275e361cb3dbc77e6
966da7e46f65d6d49df3e31214470a4fe5cc8e66
2 .github/ci_commit_pins/vllm.txt vendored
@@ -1 +1 @@
e5192819208c4d68194844b7dfafbc00020d0dea
0ad9951c416d33c5da4f7a504fb162cbe62386f5
2 .github/ci_commit_pins/xla.txt vendored
@@ -1 +1 @@
0fa6e3129e61143224663e1ec67980d12b7ec4eb
2a9138a26ee257fef05310ad3fecf7c55fe80d73
@ -1,41 +1,59 @@
|
||||
# TODO(elainwy): remove this file after the torch nightly dockerfile is in sync in vllm repo
|
||||
# The vLLM Dockerfile is used to construct vLLM image against torch nightly and torch main that can be directly used for testing
|
||||
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
# BUILD_BASE_IMAGE: used to setup python build xformers, and vllm wheels, It can be replaced with a different base image from local machine,
|
||||
# by default, it uses the torch-nightly-base stage from this docker image
|
||||
ARG BUILD_BASE_IMAGE=torch-nightly-base
|
||||
|
||||
# FINAL_BASE_IMAGE: used to set up vllm-instaled environment and build flashinfer,
|
||||
# by default, it uses devel-ubuntu22.04 official image.
|
||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
||||
|
||||
# The logic is copied from https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile
|
||||
ARG GET_PIP_URL="https://bootstrap.pypa.io/get-pip.py"
|
||||
|
||||
|
||||
#################### TORCH NIGHTLY BASE IMAGE ####################
|
||||
# A base image for building vLLM with devel ubuntu 22.04, this is mainly used to build vllm in vllm builtkite ci
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 as torch-nightly-base
|
||||
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG GET_PIP_URL
|
||||
|
||||
# Install system dependencies and uv, then create Python virtual environment
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
|
||||
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
|
||||
&& ln -s /opt/venv/bin/python3 /usr/bin/python3 \
|
||||
&& ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
|
||||
&& ln -s /opt/venv/bin/pip /usr/bin/pip \
|
||||
&& apt-get install -y ccache software-properties-common git curl wget sudo vim \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
|
||||
# as it was causing spam when compiling the CUTLASS kernels
|
||||
RUN apt-get install -y gcc-10 g++-10
|
||||
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10
|
||||
RUN <<EOF
|
||||
gcc --version
|
||||
EOF
|
||||
# Ensure gcc >= 10 to avoid CUTLASS issues (bug 92519)
|
||||
RUN current_gcc_version=$(gcc -dumpversion | cut -f1 -d.) && \
|
||||
if command -v apt-get >/dev/null; then \
|
||||
if [ "$current_gcc_version" -lt 10 ]; then \
|
||||
echo "GCC version is $current_gcc_version, installing gcc-10..."; \
|
||||
apt-get update \
|
||||
&& apt-get install -y gcc-10 g++-10 \
|
||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 \
|
||||
&& update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 100; \
|
||||
else \
|
||||
echo "GCC version is $current_gcc_version, no need to install gcc-10."; \
|
||||
fi \
|
||||
fi \
|
||||
&& gcc --version && g++ --version
|
||||
|
||||
# Install uv for faster pip installs
|
||||
# install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv==0.8.4
|
||||
|
||||
@ -43,32 +61,36 @@ ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
# Use copy mode to avoid hardlink failures with Docker cache mounts
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
#################### TORCH NIGHTLY BASE IMAGE ####################
|
||||
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# A base image for building vLLM with torch nightly or torch wheels
|
||||
# prepare basic build environment
|
||||
FROM ${BUILD_BASE_IMAGE} AS base
|
||||
USER root
|
||||
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
|
||||
# Only work with PyTorch manylinux builder
|
||||
# TODO (huydhn): Only work with PyTorch manylinux builder
|
||||
ENV PATH="/opt/python/cp312-cp312/bin:${PATH}"
|
||||
|
||||
# Install some system dependencies and double check python version
|
||||
RUN if command -v apt-get >/dev/null; then \
|
||||
apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git wget sudo vim; \
|
||||
&& apt-get install -y ccache software-properties-common git curl wget sudo vim; \
|
||||
else \
|
||||
dnf install -y git wget sudo; \
|
||||
dnf install -y git curl wget sudo; \
|
||||
fi \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
# Install uv for faster pip installs if not existed
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv==0.8.4
|
||||
|
||||
if ! python3 -m uv --version >/dev/null 2>&1; then \
|
||||
python3 -m pip install uv==0.8.4; \
|
||||
fi
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
# Use copy mode to avoid hardlink failures with Docker cache mounts
|
||||
@ -76,15 +98,15 @@ ENV UV_LINK_MODE=copy
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install build and runtime dependencies
|
||||
# install build and runtime dependencies
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY use_existing_torch.py use_existing_torch.py
|
||||
COPY pyproject.toml pyproject.toml
|
||||
|
||||
# Install build and runtime dependencies without stable torch version
|
||||
# install build and runtime dependencies without stable torch version
|
||||
RUN python3 use_existing_torch.py
|
||||
|
||||
# Default mount file as placeholder, this just avoid the mount error
|
||||
# default mount file as placeholder, this just avoid the mount error
|
||||
# change to a different vllm folder if this does not exist anymore
|
||||
ARG TORCH_WHEELS_PATH="./requirements"
|
||||
ARG PINNED_TORCH_VERSION
|
||||
@ -116,36 +138,56 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/common.txt
|
||||
|
||||
# Must put before installing xformers, so it can install the correct version of xfomrers.
|
||||
ARG xformers_cuda_arch_list='7.5;8.0+PTX;9.0a'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${xformers_cuda_arch_list}
|
||||
|
||||
ARG max_jobs=16
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
|
||||
git clone https://github.com/facebookresearch/xformers.git
|
||||
RUN echo ${TORCH_CUDA_ARCH_LIST}
|
||||
RUN echo ${MAX_JOBS}
|
||||
RUN pip freeze | grep -E 'ninja'
|
||||
|
||||
pushd xformers
|
||||
git checkout v0.0.32.post2
|
||||
git submodule update --init --recursive
|
||||
python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose
|
||||
popd
|
||||
# Build xformers with cuda and torch nightly/wheel
|
||||
# following official xformers guidance: https://github.com/facebookresearch/xformers#build
|
||||
# sha for https://github.com/facebookresearch/xformers/tree/v0.0.32.post2
|
||||
ARG XFORMERS_COMMIT=5d4b92a5e5a9c6c6d4878283f47d82e17995b468
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
|
||||
rm -rf xformers
|
||||
BASH
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
echo 'git clone xformers...' \
|
||||
&& git clone https://github.com/facebookresearch/xformers.git --recursive \
|
||||
&& cd xformers \
|
||||
&& git checkout ${XFORMERS_COMMIT} \
|
||||
&& git submodule update --init --recursive \
|
||||
&& echo 'finish git clone xformers...' \
|
||||
&& rm -rf build \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
|
||||
&& cd .. \
|
||||
&& rm -rf xformers
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system xformers-dist/*.whl
|
||||
uv pip install --system xformers-dist/*.whl --verbose
|
||||
|
||||
# Build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
|
||||
# track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
|
||||
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
|
||||
|
||||
RUN cat torch_build_versions.txt
|
||||
RUN pip freeze | grep -E 'torch|xformers|torchvision|torchaudio'
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
|
||||
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
# Image used to build vllm wheel
|
||||
FROM base AS build
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN python3 use_existing_torch.py
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
@ -155,17 +197,20 @@ ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
# Max jobs used by Ninja to build extensions
|
||||
ARG max_jobs=16
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
ARG nvcc_threads=8
|
||||
ARG nvcc_threads=4
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
ARG torch_cuda_arch_list='8.0 8.6 8.9 9.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
|
||||
ARG USE_SCCACHE
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
|
||||
# Use sccache to speed up compilation
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
@ -190,9 +235,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
&& sccache --show-stats; \
|
||||
fi
|
||||
|
||||
ARG torch_cuda_arch_list='8.0 8.6 8.9 9.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
|
||||
ARG vllm_target_device="cuda"
|
||||
ENV VLLM_TARGET_DEVICE=${vllm_target_device}
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
@ -206,10 +248,17 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
export VLLM_DOCKER_BUILD_CONTEXT=1 && \
|
||||
python3 setup.py bdist_wheel --dist-dir=vllm-dist --py-limited-api=cp38; \
|
||||
fi
|
||||
|
||||
RUN echo "[INFO] Listing current directory:" && \
|
||||
ls -al && \
|
||||
echo "[INFO] Showing torch_build_versions.txt content:" && \
|
||||
cat torch_build_versions.txt
|
||||
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
|
||||
|
||||
################### VLLM INSTALLED IMAGE ####################
|
||||
# Setup clean environment for vLLM for test and api server using ubuntu22.04 with AOT flashinfer
|
||||
FROM ${FINAL_BASE_IMAGE} AS vllm-base
|
||||
USER root
|
||||
|
||||
@ -217,7 +266,7 @@ ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG GET_PIP_URL
|
||||
|
||||
# Only work with PyTorch manylinux builder
|
||||
# TODO (huydhn): Only work with PyTorch manylinux builder
|
||||
ENV PATH="/opt/python/cp312-cp312/bin:${PATH}"
|
||||
|
||||
# prepare for environment starts
|
||||
@ -226,19 +275,20 @@ WORKDIR /workspace
|
||||
# Install Python and other dependencies
|
||||
RUN if command -v apt-get >/dev/null; then \
|
||||
apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git sudo vim python3-pip; \
|
||||
&& apt-get install -y ccache software-properties-common git curl wget sudo vim \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION}; \
|
||||
else \
|
||||
dnf install -y git wget sudo; \
|
||||
dnf install -y git curl wget sudo; \
|
||||
fi \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
|
||||
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
|
||||
&& ln -s /opt/venv/bin/python3 /usr/bin/python3 \
|
||||
&& ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
|
||||
&& ln -s /opt/venv/bin/pip /usr/bin/pip \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
# Get the torch versions, and whls used in previous stage
|
||||
# Get the torch versions, and whls used in previous stagtes for consistency
|
||||
COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt
|
||||
COPY --from=base /workspace/xformers-dist /wheels/xformers
|
||||
COPY --from=build /workspace/vllm-dist /wheels/vllm
|
||||
@ -247,29 +297,33 @@ RUN echo "[INFO] Listing current directory before torch install step:" && \
|
||||
echo "[INFO] Showing torch_build_versions.txt content:" && \
|
||||
cat torch_build_versions.txt
|
||||
|
||||
# Install uv for faster pip installs if not existed
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv==0.8.4
|
||||
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
# Use copy mode to avoid hardlink failures with Docker cache mounts
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Install build and runtime dependencies, this is needed for flashinfer install
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
COPY use_existing_torch.py use_existing_torch.py
|
||||
RUN python3 use_existing_torch.py
|
||||
RUN cat requirements/build.txt
|
||||
|
||||
# Install uv for faster pip installs if not existed
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if ! python3 -m uv --version > /dev/null 2>&1; then \
|
||||
python3 -m pip install uv==0.8.4; \
|
||||
fi
|
||||
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
# Use copy mode to avoid hardlink failures with Docker cache mounts
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt
|
||||
|
||||
|
||||
# Default mount file as placeholder, this just avoid the mount error
|
||||
ARG TORCH_WHEELS_PATH="./requirements"
|
||||
# Install torch, torchaudio and torchvision. If TORCH_WHEELS_PATH is default
|
||||
# to ./requirements, it will pull the nightly versions using pip. Otherwise,
|
||||
# it will use the local wheels from TORCH_WHEELS_PATH
|
||||
# Install torch, torchaudio and torchvision
|
||||
# if TORCH_WHEELS_PATH is default "./requirements", it will pull the nightly versions using pip using torch_build_versions.txt
|
||||
# otherwise, it will use the whls from TORCH_WHEELS_PATH from the host machine
|
||||
RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
if [ -n "$TORCH_WHEELS_PATH" ] && [ "$TORCH_WHEELS_PATH" != "./requirements" ] && [ -d "/dist" ] && ls /dist/torch*.whl >/dev/null 2>&1; then \
|
||||
@ -290,14 +344,18 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
# Install xformers wheel from previous stage
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system /wheels/xformers/*.whl --verbose
|
||||
|
||||
# Build FlashInfer from source
|
||||
# Build flashinfer from source.
|
||||
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
|
||||
# install package for build flashinfer
|
||||
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738
|
||||
|
||||
RUN pip freeze | grep -E 'setuptools|packaging|build'
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
|
||||
# Build flashinfer for torch nightly from source around 10 mins
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
|
||||
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
@ -309,7 +367,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
&& cd .. \
|
||||
&& rm -rf flashinfer
|
||||
|
||||
# Install FlashInfer
|
||||
# install flashinfer python
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system wheels/flashinfer/*.whl --verbose
|
||||
|
||||
@ -319,6 +377,49 @@ RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm
|
||||
################### VLLM INSTALLED IMAGE ####################
|
||||
|
||||
|
||||
#################### UNITTEST IMAGE #############################
|
||||
FROM vllm-base as test
|
||||
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
# Use copy mode to avoid hardlink failures with Docker cache mounts
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
COPY tests/ tests/
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY use_existing_torch.py use_existing_torch.py
|
||||
COPY pyproject.toml pyproject.toml
|
||||
# Install build and runtime dependencies without stable torch version
|
||||
COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt
|
||||
|
||||
RUN python3 use_existing_torch.py
|
||||
|
||||
# install packages
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/common.txt
|
||||
# enable fast downloads from hf (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system hf_transfer
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -e tests/vllm_test_utils
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/nightly_torch_test.txt
|
||||
|
||||
# Logging to confirm the torch versions
|
||||
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
|
||||
|
||||
# Logging to confirm all the packages are installed
|
||||
RUN pip freeze
|
||||
|
||||
#################### UNITTEST IMAGE #############################
|
||||
|
||||
#################### EXPORT STAGE ####################
|
||||
FROM scratch as export-wheels
|
||||
|
||||
4 .github/pytorch-probot.yml vendored
@@ -3,7 +3,6 @@ ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
@@ -16,8 +15,7 @@ ciflow_push_tags:
- ciflow/inductor-micro-benchmark
- ciflow/inductor-micro-benchmark-cpu-x86
- ciflow/inductor-perf-compare
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-rocm
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-periodic
- ciflow/inductor-rocm
2 .github/scripts/filter_test_configs.py vendored
@@ -512,8 +512,6 @@ def perform_misc_tasks(
"keep-going",
branch == MAIN_BRANCH
or bool(tag and re.match(r"^trunk/[a-f0-9]{40}$", tag))
# Pattern for tags created via manual run on HUD
or bool(tag and re.match(r"^ciflow/[^/]+/[a-f0-9]{40}$", tag))
or check_for_setting(labels, pr_body, "keep-going"),
)
set_output(
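Both tag patterns above end in a 40-character commit SHA; a quick illustration of what they accept (regexes copied from the hunk, sample values are placeholders):

import re

TRUNK_TAG = re.compile(r"^trunk/[a-f0-9]{40}$")
CIFLOW_TAG = re.compile(r"^ciflow/[^/]+/[a-f0-9]{40}$")  # e.g. tags created via manual run on HUD

sha = "0" * 40  # placeholder 40-character SHA
assert TRUNK_TAG.match(f"trunk/{sha}")
assert CIFLOW_TAG.match(f"ciflow/periodic/{sha}")
assert not CIFLOW_TAG.match("ciflow/periodic/not-a-sha")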
4 .github/scripts/trymerge.py vendored
@@ -2042,6 +2042,10 @@ def validate_revert(
f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
)

# Raises exception if matching rule is not found, but ignores all status checks
find_matching_merge_rule(
pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
)
commit_sha = get_pr_commit_sha(repo, pr)
return (author_login, commit_sha)
62 .github/workflows/b200-distributed.yml vendored
@ -1,62 +0,0 @@
|
||||
name: CI for distributed tests on B200
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- .github/workflows/b200-distributed.yml
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/b200-distributed/*
|
||||
schedule:
|
||||
- cron: 46 8 * * * # about 1:46am PDT
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
|
||||
get-label-type:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
|
||||
name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
|
||||
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
|
||||
name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs:
|
||||
- linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
|
||||
with:
|
||||
timeout-minutes: 1200
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
|
||||
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
secrets: inherit
|
||||
19 .github/workflows/build-vllm-wheel.yml vendored
@ -27,8 +27,9 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: [ '3.12' ]
|
||||
# TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
|
||||
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
|
||||
device: [ 'cu128', 'cu129', 'cu130' ]
|
||||
device: [ 'cu128', 'cu129' ]
|
||||
include:
|
||||
- platform: manylinux_2_28_x86_64
|
||||
device: cu128
|
||||
@ -38,10 +39,6 @@ jobs:
|
||||
device: cu129
|
||||
manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
|
||||
runner: linux.12xlarge.memory
|
||||
- platform: manylinux_2_28_x86_64
|
||||
device: cu130
|
||||
manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
|
||||
runner: linux.12xlarge.memory
|
||||
- platform: manylinux_2_28_aarch64
|
||||
device: cu128
|
||||
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
|
||||
@ -50,11 +47,6 @@ jobs:
|
||||
device: cu129
|
||||
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
|
||||
runner: linux.arm64.r7g.12xlarge.memory
|
||||
exclude:
|
||||
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
|
||||
# xformers is update to support 13.0
|
||||
- platform: manylinux_2_28_aarch64
|
||||
device: cu130
|
||||
name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
|
||||
runs-on: ${{ matrix.runner }}
|
||||
timeout-minutes: 480
|
||||
@ -177,12 +169,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
|
||||
device: [ 'cu128', 'cu129', 'cu130' ]
|
||||
exclude:
|
||||
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
|
||||
# xformers is update to support 13.0
|
||||
- platform: manylinux_2_28_aarch64
|
||||
device: cu130
|
||||
device: [ 'cu128', 'cu129' ]
|
||||
env:
|
||||
PLATFORM: ${{ matrix.platform }}
|
||||
BUILD_DEVICE: ${{ matrix.device }}
|
||||
|
||||
@ -1,132 +0,0 @@
|
||||
name: inductor-perf-nightly-rocm-mi300
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/inductor-perf-test-nightly-rocm-mi300/*
|
||||
schedule:
|
||||
- cron: 15 0 * * *
|
||||
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
|
||||
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
training:
|
||||
description: Run training (on by default)?
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
inference:
|
||||
description: Run inference (on by default)?
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
default:
|
||||
description: Run inductor_default?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
dynamic:
|
||||
description: Run inductor_dynamic_shapes?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
cppwrapper:
|
||||
description: Run inductor_cpp_wrapper?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
cudagraphs:
|
||||
description: Run inductor_cudagraphs?
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
freezing_cudagraphs:
|
||||
description: Run inductor_cudagraphs with freezing for inference?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
aotinductor:
|
||||
description: Run aot_inductor for inference?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
maxautotune:
|
||||
description: Run inductor_max_autotune?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
benchmark_configs:
|
||||
description: The list of configs used the benchmark
|
||||
required: false
|
||||
type: string
|
||||
default: inductor_huggingface_perf_rocm_mi300,inductor_timm_perf_rocm_mi300,inductor_torchbench_perf_rocm_mi300
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
opt_out_experiments: lf
|
||||
|
||||
linux-jammy-rocm-py3_10-inductor-benchmark-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: rocm-py3_10-inductor-benchmark-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3_10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi300", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi300", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi300", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi300", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi300", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi300", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi300", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-inductor-benchmark-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: rocm-py3_10-inductor-benchmark-test
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs: linux-jammy-rocm-py3_10-inductor-benchmark-build
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3_10
|
||||
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
# Disable monitor in perf tests for more investigation
|
||||
disable-monitor: true
|
||||
monitor-log-interval: 10
|
||||
monitor-data-collect-interval: 2
|
||||
secrets: inherit
|
||||
@ -1,11 +1,11 @@
|
||||
name: inductor-perf-nightly-rocm-mi355
|
||||
name: inductor-perf-nightly-rocm
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/inductor-perf-test-nightly-rocm-mi355/*
|
||||
- ciflow/inductor-perf-test-nightly-rocm/*
|
||||
schedule:
|
||||
- cron: 15 0 * * *
|
||||
- cron: 0 7 * * 0,3
|
||||
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
|
||||
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
|
||||
workflow_dispatch:
|
||||
@ -59,7 +59,7 @@ on:
|
||||
description: The list of configs used the benchmark
|
||||
required: false
|
||||
type: string
|
||||
default: inductor_huggingface_perf_rocm_mi355,inductor_timm_perf_rocm_mi355,inductor_torchbench_perf_rocm_mi355
|
||||
default: inductor_huggingface_perf_rocm,inductor_timm_perf_rocm,inductor_torchbench_perf_rocm
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
@ -88,27 +88,23 @@ jobs:
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
23 .github/workflows/operator_benchmark.yml vendored
@ -7,11 +7,9 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
test_mode:
|
||||
type: choice
|
||||
options:
|
||||
- 'short'
|
||||
- 'long'
|
||||
- 'all'
|
||||
required: false
|
||||
type: string
|
||||
default: 'short'
|
||||
description: tag filter for operator benchmarks, options from long, short, all
|
||||
schedule:
|
||||
# Run at 07:00 UTC every Sunday
|
||||
@ -39,7 +37,20 @@ jobs:
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
|
||||
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
opbenchmark-on-demand-build:
|
||||
if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
|
||||
name: opbenchmark-on-demand-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
build-environment: linux-jammy-py3.10-gcc11-build
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
8 .github/workflows/trunk.yml vendored
@@ -180,13 +180,13 @@ jobs:
disable-monitor: false
secrets: inherit

win-vs2022-cuda12_8-py3-build:
name: win-vs2022-cuda12.8-py3
win-vs2022-cuda12_6-py3-build:
name: win-vs2022-cuda12.6-py3
uses: ./.github/workflows/_win-build.yml
needs: get-label-type
with:
build-environment: win-vs2022-cuda12.8-py3
cuda-version: "12.8"
build-environment: win-vs2022-cuda12.6-py3
cuda-version: "12.6"
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
secrets: inherit
4 .github/workflows/vllm.yml vendored
@@ -46,7 +46,7 @@ jobs:
runner: linux.24xlarge.memory
test-matrix: |
{ include: [
{ config: "vllm_basic_correctness_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "vllm_basic_correctness_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "vllm_basic_models_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "vllm_entrypoints_test", shard: 1, num_shards: 1,runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "vllm_regression_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
@@ -54,7 +54,7 @@ jobs:
{ config: "vllm_pytorch_compilation_unit_tests", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "vllm_lora_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "vllm_multi_model_test_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu"},
{ config: "vllm_language_model_test_extended_generation_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu"},
{ config: "vllm_languagde_model_test_extended_generation_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu"},
{ config: "vllm_distributed_test_2_gpu_28_failure_test", shard: 1, num_shards: 1, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "vllm_lora_test", shard: 0, num_shards: 4, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "vllm_lora_test", shard: 1, num_shards: 4, runner: "linux.g6.4xlarge.experimental.nvidia.gpu" },
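The renamed config in this workflow has to stay in sync with the key defined in sample_vllm_test_library() (see the first hunk of this compare). A small consistency check along these lines could catch such typos; this is an illustrative sketch, not part of the change:

import json

test_matrix = json.loads("""
{ "include": [
    { "config": "vllm_language_model_test_extended_generation_28_failure_test",
      "shard": 1, "num_shards": 1, "runner": "linux.g6.4xlarge.experimental.nvidia.gpu" }
] }
""")
known_configs = {"vllm_language_model_test_extended_generation_28_failure_test"}  # keys from the test library
unknown = [e["config"] for e in test_matrix["include"] if e["config"] not in known_configs]
assert not unknown, f"unknown vLLM test configs: {unknown}"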
1 .gitignore vendored
@@ -395,4 +395,3 @@ android/pytorch_android_torchvision/.cxx
CLAUDE.local.md
/test_*.py
/debug_*.py
CLAUDE_CONTEXT/
@ -256,7 +256,6 @@ endif()
|
||||
IF(USE_FBGEMM_GENAI)
|
||||
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
|
||||
set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
|
||||
|
||||
if(USE_CUDA)
|
||||
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
|
||||
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
|
||||
@ -293,64 +292,58 @@ IF(USE_FBGEMM_GENAI)
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
|
||||
)
|
||||
|
||||
target_include_directories(fbgemm_genai PRIVATE
|
||||
target_include_directories(fbgemm_genai PUBLIC
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${fbgemm_genai_mx8mx8bf16_grouped}
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
else()
|
||||
if(USE_ROCM)
|
||||
# Only include the kernels we want to build to avoid increasing binary size.
|
||||
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
|
||||
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
|
||||
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
|
||||
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
||||
|
||||
# Add FBGEMM_GENAI include directories for torch_ops.h
|
||||
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
|
||||
elseif(USE_ROCM)
|
||||
# Only include the kernels we want to build to avoid increasing binary size.
|
||||
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
|
||||
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
|
||||
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
|
||||
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
||||
# Add additional HIPCC compiler flags for performance
|
||||
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
|
||||
-mllvm
|
||||
-amdgpu-coerce-illegal-types=1
|
||||
-mllvm
|
||||
-enable-post-misched=0
|
||||
-mllvm
|
||||
-greedy-reverse-local-assignment=1
|
||||
-fhip-new-launch-api)
|
||||
|
||||
# Add additional HIPCC compiler flags for performance
|
||||
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
|
||||
-mllvm
|
||||
-amdgpu-coerce-illegal-types=1
|
||||
-mllvm
|
||||
-enable-post-misched=0
|
||||
-mllvm
|
||||
-greedy-reverse-local-assignment=1
|
||||
-fhip-new-launch-api)
|
||||
# Only compile for gfx942 for now.
|
||||
# This is rather hacky, I could not figure out a clean solution :(
|
||||
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
|
||||
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
|
||||
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
|
||||
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
|
||||
endif()
|
||||
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
|
||||
|
||||
# Only compile for gfx942 for now.
|
||||
# This is rather hacky, I could not figure out a clean solution :(
|
||||
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
|
||||
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
|
||||
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
|
||||
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
|
||||
hip_add_library(
|
||||
fbgemm_genai STATIC
|
||||
${fbgemm_genai_native_rocm_hip}
|
||||
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
|
||||
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
|
||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
|
||||
|
||||
target_include_directories(fbgemm_genai PUBLIC
|
||||
# FBGEMM version of Composable Kernel is used due to some customizations
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/include
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
endif()
|
||||
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
|
||||
|
||||
hip_add_library(
|
||||
fbgemm_genai STATIC
|
||||
${fbgemm_genai_native_rocm_hip}
|
||||
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
|
||||
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
|
||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
|
||||
|
||||
target_include_directories(fbgemm_genai PRIVATE
|
||||
# FBGEMM version of Composable Kernel is used due to some customizations
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/include
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||
)
|
||||
|
||||
# Add FBGEMM_GENAI include directories for torch_ops.h
|
||||
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -699,6 +692,12 @@ if(USE_CUDA AND NOT USE_ROCM)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
|
||||
|
||||
# Add FBGEMM_GENAI include directories for torch_ops.h
|
||||
if(USE_FBGEMM_GENAI)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
|
||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
|
||||
endif()
|
||||
|
||||
if($ENV{ATEN_STATIC_CUDA})
|
||||
if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
|
||||
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
|
||||
|
||||
@@ -389,16 +389,37 @@ void fillVersion<DLManagedTensorVersioned>(
// constructed out of ATen tensor
template <class T>
T* toDLPackImpl(const Tensor& src) {
auto view = src;

// Detect whether there is need to normalize the strides
// Background: gh-83069
//
// However, normalizing strides can come at a high-cost
// to slow down toDLPack conversion 3x, so we
// only normalize if needed.
//
// The following code detects whether the src follows
// a continuous pattern. If the src follows such pattern (common-case)
// then we do not need to normalize the strides.
bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
// less common case, try normalizing the strides
if (need_normalize_strides) {
// create a new tensor with possibly normalized strides
// gh-83069
auto shape = src.sizes();
view = src.as_strided(shape, {1}, src.storage_offset());
}

ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
atDLMTensor->handle = src;
atDLMTensor->handle = view;
atDLMTensor->tensor.manager_ctx = atDLMTensor;
atDLMTensor->tensor.deleter = &deleter<T>;
atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
fillVersion(&atDLMTensor->tensor);
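The case being normalized above is a one-element view whose stride is not 1. From Python the situation looks roughly like this (illustrative sketch using the public DLPack utilities; the exact exported stride depends on the exporter):

import torch
from torch.utils.dlpack import to_dlpack, from_dlpack

base = torch.arange(6.0).reshape(2, 3)
one = base.as_strided((1,), (3,))   # a single element, but stride() == (3,)
print(one.size(), one.stride())     # torch.Size([1]) (3,)

# With the normalization in the hunk above, the DLPack capsule advertises stride 1,
# which is what most consumers expect for a one-element tensor (see gh-83069).
back = from_dlpack(to_dlpack(one))
print(back.size(), back.stride())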
@@ -52,16 +52,16 @@ struct DLPackTraits {};

template <>
struct DLPackTraits<DLManagedTensor> {
inline static constexpr const char* capsule = "dltensor";
inline static constexpr const char* used = "used_dltensor";
inline static const char* capsule = "dltensor";
inline static const char* used = "used_dltensor";
inline static auto toDLPack = at::toDLPack;
inline static auto fromDLPack = at::fromDLPack;
};

template <>
struct DLPackTraits<DLManagedTensorVersioned> {
inline static constexpr const char* capsule = "dltensor_versioned";
inline static constexpr const char* used = "used_dltensor_versioned";
inline static const char* capsule = "dltensor_versioned";
inline static const char* used = "used_dltensor_versioned";
inline static auto toDLPack = at::toDLPackVersioned;
inline static auto fromDLPack = at::fromDLPackVersioned;
};
@@ -42,14 +42,8 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
}

bool torch_function_mode_enabled() {
// Manually flatten because gcc is refusing to inline here. Note
// that we are still calling __tls_get_addr twice here with GCC,
// presumably because of
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81501 (which says
// the fix ships in GCC 16), but forcing inlining still improves
// performance.
const auto& ptfs = pythonTorchFunctionState;
return ptfs.disabled_state_ != TorchFunctionDisabledState::ALL_DISABLED && !ptfs.stack_.empty();
return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
PythonTorchFunctionTLS::stack_len() > 0;
}

// This is needed to disambiguate the ternary torch function disabled states
@ -27,7 +27,6 @@ struct TORCH_API PythonTorchFunctionTLS {
|
||||
TorchFunctionDisabledState disabled_state_ =
|
||||
TorchFunctionDisabledState::ENABLED;
|
||||
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
|
||||
friend TORCH_API bool torch_function_mode_enabled();
|
||||
};
|
||||
|
||||
TORCH_API bool torch_function_mode_enabled();
|
||||
|
||||
@ -624,14 +624,7 @@ struct TORCH_API IValue final {
IValue(const c10::SymBool& i) {
if (auto mi = i.maybe_as_bool()) {
tag = Tag::Bool;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
payload.u.as_int = *mi;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
/* due to byteorder if value assigned as_int, as_bool actually is not set correctly */
payload.u.as_bool = *mi;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
} else {
tag = Tag::SymBool;
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
@ -13,7 +13,6 @@
#include <c10/core/ScalarType.h>

#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@ -151,7 +150,6 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
BLASType = "unknown";
}
return BLASType;

}

// Similar to Compute Type in GemmRocblas.h
@ -246,25 +244,33 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio
namespace detail {

static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {

if (!config.enabled) {
return true; // skip when disabled
}

static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {size}, options);
at::Tensor oth = at::from_blob(other_c, {size}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);

const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
if (ok) {
TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
} else {
TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
return ok;
if (last_succeed_atol == 1) {
return false;
}
else {
TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}

return true;
}

}
@ -349,10 +355,8 @@ struct GemmParams : OpParams {
}

TuningStatus NumericalCheck(GemmParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
}

char transa{};
@ -445,10 +449,8 @@ struct GemmAndBiasParams : OpParams {
}

TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
}

char transa{};
@ -544,10 +546,8 @@ struct GemmStridedBatchedParams : OpParams {
}

TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
}

char transa{};
@ -663,9 +663,7 @@ struct ScaledGemmParams : OpParams {
}

TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
}

char transa{};
@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
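A sketch of how the new 'atol_rtol' value is interpreted, mirroring the GetNumericalCheckConfig() parsing that appears later in this diff (the helper name is illustrative):

import os

def parse_numerical_check(value: str):
    # "0" keeps the check disabled; otherwise the value is "<atol>_<rtol>",
    # both strictly positive, e.g. "1e-5_1e-5".
    if value == "0":
        return (False, 1e-5, 1e-5)
    atol_str, sep, rtol_str = value.partition("_")
    assert sep == "_", "expected 'atol_rtol' format, e.g. 1e-5_1e-5"
    atol, rtol = float(atol_str), float(rtol_str)
    assert atol > 0.0 and rtol > 0.0, "tolerances must be positive"
    return (True, atol, rtol)

os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-4_1e-4"
print(parse_numerical_check(os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"]))
# -> (True, 0.0001, 0.0001)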
@ -173,9 +173,10 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| write_file_on_exit(val: bool) -> None | Default is True. |
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
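A hedged usage sketch of the Python entry points in the table above; set_numerical_check_tolerances is the call this change introduces, and the filename and matrix shapes are placeholders:

import torch
import torch.cuda.tunable as tunable

tunable.enable(True)
tunable.set_filename("tunableop_results.csv")
# Enable numerical verification of tuning candidates with explicit tolerances;
# atol and rtol default to 1e-5 per the table above.
tunable.set_numerical_check_tolerances(True, 1e-4, 1e-4)

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
c = a @ b             # GEMM is tuned; candidates are checked against the reference
tunable.write_file()  # writes to the same destination as get_filename()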
@ -107,30 +107,14 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
|
||||
}
|
||||
|
||||
void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
|
||||
bool is_new = false;
|
||||
ResultEntry inserted = ResultEntry::Null();
|
||||
std::scoped_lock l{lock_};
|
||||
|
||||
// ---- mutate maps under results lock ----
|
||||
{
|
||||
std::scoped_lock l{lock_};
|
||||
auto& km = results_[op_signature]; // creates if missing
|
||||
is_new = (km.find(params_signature) == km.end());
|
||||
AddImpl(op_signature, params_signature, std::move(best), km);
|
||||
if (is_new) {
|
||||
inserted = km.at(params_signature); // snapshot for I/O after unlocking
|
||||
}
|
||||
}
|
||||
if (!is_new) return; // only write once per unique (op, params)
|
||||
|
||||
TuningContext* ctx = getTuningContext();
|
||||
if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
|
||||
InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());
|
||||
|
||||
if (is_new && realtime_out_ && realtime_out_->good()) {
|
||||
AppendResultLine(op_signature, params_signature, inserted);
|
||||
}
|
||||
auto it = results_.find(op_signature);
|
||||
if (it == results_.end()) {
|
||||
it = results_.insert({op_signature, {}}).first;
|
||||
}
|
||||
|
||||
AddImpl(op_signature, params_signature, std::move(best), it->second);
|
||||
}
|
||||
|
||||
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
|
||||
@ -166,77 +150,6 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
|
||||
}
|
||||
}
|
||||
|
||||
void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
|
||||
std::scoped_lock fl{realtime_file_mutex_};
|
||||
|
||||
if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (realtime_out_ && realtime_filename_ != filename) {
|
||||
realtime_out_->flush();
|
||||
realtime_out_->close();
|
||||
realtime_out_.reset();
|
||||
validators_written_ = false;
|
||||
}
|
||||
|
||||
bool file_exists = false;
|
||||
bool file_empty = true;
|
||||
|
||||
{
|
||||
std::ifstream check_file(filename);
|
||||
if (check_file.good()) {
|
||||
file_exists = true;
|
||||
file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
|
||||
}
|
||||
}
|
||||
|
||||
realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
|
||||
|
||||
if (!realtime_out_->good()) {
|
||||
TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
|
||||
realtime_out_.reset();
|
||||
return;
|
||||
}
|
||||
|
||||
if(!file_exists || file_empty) {
|
||||
for(const auto& [key, val] : validators) {
|
||||
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
|
||||
realtime_out_->flush();
|
||||
}
|
||||
validators_written_ = true;
|
||||
|
||||
TUNABLE_LOG2("Wrote validators to realtime output file");
|
||||
}
|
||||
|
||||
realtime_filename_ = filename;
|
||||
}
|
||||
|
||||
void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
|
||||
std::scoped_lock fl{realtime_file_mutex_};
|
||||
|
||||
if(!realtime_out_ || !realtime_out_->good()) {
|
||||
return;
|
||||
}
|
||||
|
||||
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
|
||||
realtime_out_->flush(); //ensure immediate write to disk
|
||||
|
||||
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
|
||||
}
|
||||
|
||||
void TuningResultsManager::CloseRealtimeAppend() {
|
||||
std::scoped_lock fl{realtime_file_mutex_};
|
||||
|
||||
|
||||
if(realtime_out_) {
|
||||
realtime_out_->flush();
|
||||
realtime_out_->close();
|
||||
realtime_out_.reset();
|
||||
TUNABLE_LOG2("Closed realtime output file");
|
||||
}
|
||||
}
|
||||
|
||||
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
|
||||
std::scoped_lock l{lock_};
|
||||
|
||||
@ -483,6 +396,7 @@ TuningContext::TuningContext() :
|
||||
tuning_enable_{true},
|
||||
record_untuned_enable_{false},
|
||||
manager_initialized_{false},
|
||||
write_file_on_exit_{true},
|
||||
numerics_check_enable_{false},
|
||||
max_tuning_duration_ms_{30},
|
||||
max_tuning_iterations_{100},
|
||||
@ -503,8 +417,20 @@ TuningContext::~TuningContext() {
|
||||
// but doesn't do any computation itself.
|
||||
return;
|
||||
}
|
||||
TUNABLE_LOG1("Closing File");
|
||||
GetTuningResultsManager().CloseRealtimeAppend(); // Since, we do instant logging by default now.
|
||||
auto filename = GetFilename();
|
||||
if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
|
||||
if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
|
||||
if (results_count_from_input_file_ > 0) {
|
||||
TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG1("writing file ", filename);
|
||||
}
|
||||
if (!WriteFile(filename)) {
|
||||
TUNABLE_LOG1("failed to write file ", filename);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (untuned_file_.good()) {
|
||||
untuned_file_.close();
|
||||
@ -585,54 +511,20 @@ std::ofstream& TuningContext::GetUntunedFile(){
|
||||
return untuned_file_;
|
||||
}
|
||||
|
||||
void TuningContext::WriteFileOnExit(bool value) {
|
||||
write_file_on_exit_ = value;
|
||||
}
|
||||
|
||||
void TuningContext::EnableNumericsCheck(bool value) {
|
||||
numerics_check_enable_ = value;
|
||||
}
|
||||
|
||||
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
|
||||
const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
|
||||
|
||||
if (!env_opt.has_value()) {
|
||||
return numerics_cfg_;
|
||||
}
|
||||
|
||||
const std::string& env = env_opt.value();
|
||||
|
||||
if (env == "0") {
|
||||
return NumericalCheckConfig(false, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
const size_t underscore = env.find('_');
|
||||
|
||||
TORCH_CHECK(
|
||||
underscore != std::string::npos,
|
||||
"Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
|
||||
"Expected 'atol_rtol', got: ",
|
||||
env);
|
||||
|
||||
double atol = 0.0;
|
||||
double rtol = 0.0;
|
||||
|
||||
try {
|
||||
atol = std::stod(env.substr(0, underscore));
|
||||
rtol = std::stod(env.substr(underscore + 1));
|
||||
} catch (const std::exception& e) {
|
||||
TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
|
||||
}
|
||||
|
||||
TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
|
||||
return NumericalCheckConfig(true, atol, rtol);
|
||||
}
|
||||
|
||||
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
|
||||
TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
|
||||
numerics_cfg_ = {enabled, atol, rtol};
|
||||
}
|
||||
|
||||
bool TuningContext::IsNumericsCheckEnabled() const {
|
||||
const auto cfg = GetNumericalCheckConfig();
|
||||
return cfg.enabled || numerics_check_enable_;
|
||||
const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
|
||||
if (env == "1") {
|
||||
return true;
|
||||
}
|
||||
return numerics_check_enable_;
|
||||
}
|
||||
|
||||
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
|
||||
@ -742,6 +634,11 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
|
||||
auto filename = GetFilename();
|
||||
if (!filename.empty() && !IsRecordUntunedEnabled()) {
|
||||
ReadFile(filename);
|
||||
// attempt immediately to open file for writing to catch errors early
|
||||
std::ofstream file(filename, std::ios::out | std::ios::app);
|
||||
if (!file.good()) {
|
||||
TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
|
||||
}
|
||||
}
|
||||
});
|
||||
return manager_;
|
||||
@ -847,6 +744,27 @@ bool TuningContext::ReadFile(const std::string& filename_) {
return true;
}

bool TuningContext::WriteFile(const std::string& filename_) {
std::string filename = filename_.empty() ? GetFilename() : filename_;
std::ofstream file(filename, std::ios::out | std::ios::trunc);
if (!file.good()) {
TUNABLE_LOG1("error opening tuning results file for writing ", filename);
return false;
}
auto validators = GetTuningResultsValidator().GetAllValidators();
for (const auto& [key, val] : validators) {
file << "Validator," << key << "," << val << std::endl;
}
auto results = GetTuningResultsManager().Dump();
for (const auto& [op_sig, kernelmap] : results) {
for (const auto& [param_sig, result] : kernelmap) {
file << op_sig << "," << param_sig << "," << result << std::endl;
}
}
file.close();
return true;
}

namespace {

struct MaybeDelete {
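WriteFile() above and the realtime-append path earlier in this diff emit the same line-oriented layout: "Validator,<key>,<value>" rows followed by "<op_sig>,<params_sig>,<result>" rows. A small Python sketch of reading that file back (helper name and dict layout are illustrative; only the first two separators are split, since the result field may itself contain commas):

def load_tunableop_results(path="tunableop_results.csv"):
    validators, results = {}, {}
    with open(path) as f:
        for line in f:
            first, _, rest = line.rstrip("\n").partition(",")
            if first == "Validator":
                key, _, val = rest.partition(",")
                validators[key] = val
            else:
                params_sig, _, result = rest.partition(",")
                results.setdefault(first, {})[params_sig] = result
    return validators, results

validators, results = load_tunableop_results()
print(len(validators), "validators,", sum(len(v) for v in results.values()), "tuned entries")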
@ -103,24 +103,10 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
|
||||
|
||||
void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
|
||||
const std::string& params_signature, const std::string& blas_signature);
|
||||
|
||||
void InitRealtimeAppend(
|
||||
const std::string& filename,
|
||||
const std::unordered_map<std::string, std::string>& validators);
|
||||
|
||||
void AppendResultLine(const std::string& op_sig,
|
||||
const std::string& param_sig,
|
||||
const ResultEntry& result);
|
||||
|
||||
void CloseRealtimeAppend(); // For clean shutdown
|
||||
private:
|
||||
std::mutex lock_;
|
||||
std::mutex realtime_file_mutex_;
|
||||
std::unique_ptr<std::ofstream> realtime_out_;
|
||||
std::string realtime_filename_;
|
||||
ResultsMap results_;
|
||||
UntunedMap untuned_results_;
|
||||
bool validators_written_ = false;
|
||||
|
||||
};
|
||||
|
||||
@ -148,16 +134,6 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
|
||||
GetValidateFuncs validators_;
|
||||
};
|
||||
|
||||
struct NumericalCheckConfig {
|
||||
bool enabled{false};
|
||||
double atol{1e-5};
|
||||
double rtol{1e-5};
|
||||
|
||||
NumericalCheckConfig() = default;
|
||||
NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
|
||||
};
|
||||
|
||||
|
||||
class TORCH_CUDA_CPP_API TuningContext {
|
||||
public:
|
||||
TuningContext();
|
||||
@ -179,8 +155,6 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
|
||||
void EnableNumericsCheck(bool value);
|
||||
bool IsNumericsCheckEnabled() const;
|
||||
void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
|
||||
NumericalCheckConfig GetNumericalCheckConfig() const;
|
||||
|
||||
void SetMaxTuningDurationMs(int max_duration_ms);
|
||||
int GetMaxTuningDurationMs() const;
|
||||
@ -211,7 +185,10 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
|
||||
std::string GetFilename() const;
|
||||
|
||||
void WriteFileOnExit(bool value);
|
||||
|
||||
bool ReadFile(const std::string& filename={});
|
||||
bool WriteFile(const std::string& filename={});
|
||||
|
||||
template<class... Types>
|
||||
void Log(int level, Types... args) {
|
||||
@ -230,6 +207,7 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
bool tuning_enable_;
|
||||
bool record_untuned_enable_;
|
||||
bool manager_initialized_;
|
||||
bool write_file_on_exit_;
|
||||
bool numerics_check_enable_;
|
||||
int max_tuning_duration_ms_;
|
||||
int max_tuning_iterations_;
|
||||
@ -244,8 +222,6 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
std::ofstream untuned_file_;
|
||||
size_t results_count_from_input_file_;
|
||||
bool is_shutting_down_;
|
||||
|
||||
NumericalCheckConfig numerics_cfg_{};
|
||||
};
|
||||
|
||||
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
|
||||
|
||||
@ -267,10 +267,27 @@ class TunableOp {
|
||||
for (size_t i = 0; i < op_names_.size(); i++) {
|
||||
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
|
||||
|
||||
auto status = candidate->Call(reusable_params[0]);
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
if (do_numerics_check) {
|
||||
ParamsT* numerical_params = params->DeepCopy(false);
|
||||
auto status = candidate->Call(numerical_params);
|
||||
if (status != OK) {
|
||||
numerical_params->Delete();
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
status = reference_params->NumericalCheck(numerical_params);
|
||||
numerical_params->Delete();
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto status = candidate->Call(reusable_params[0]);
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// collect a small profile
|
||||
@ -293,22 +310,6 @@ class TunableOp {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (do_numerics_check) {
|
||||
ParamsT* numerical_params = params->DeepCopy(false);
|
||||
auto status = candidate->Call(numerical_params);
|
||||
if (status != OK) {
|
||||
numerical_params->Delete();
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
status = reference_params->NumericalCheck(numerical_params);
|
||||
numerical_params->Delete();
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// for warmup does user set max duration, max iters, or both?
|
||||
// warmup is skipped by default, i.e. warmup_iter = 0
|
||||
// warmup will be set to the non-zero value of max_warmup_duration
|
||||
|
||||
@ -213,22 +213,40 @@ static cudnn_grid_sample_backward_batch_rule(
return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
}

// uses functional formulation for one_hot under vmap to be compatible with
// fakeTensor/dynamic shapes and compiled functorch transforms.
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
// but requires explicit positive num_classes under vmap to avoid
// data-dependent output shapes.
// TODO: replace with targetable functionalization
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
auto shape = self.sym_sizes().vec();

// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.sym_numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
}
}

// disallow implicit inference under vmap; this would be data-dependent
// and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
"provide an explicit positive num_classes argument.");

const auto options = self.options();
at::Tensor index = at::arange(num_classes, options);
return at::eq(self.unsqueeze(-1), index).to(at::kLong);
// Disabling all of the following checks. This is OK because scatter has checks too.
// Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
// // non-empty tensor
// if (self.device().type() != at::kCUDA) {
// //for cuda, rely on device assert thrown by scatter
// TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
// }
// if (self.device().type() != at::kCUDA) {
// //rely on device asserts from scatter to avoid sync here
// TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
// }

shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
return ret.scatter(-1, self.unsqueeze(-1), 1);
}

template <typename A, A a, typename C>
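The batching rule above refuses to infer num_classes, so under torch.vmap the caller must pass it explicitly; plain eager calls can still infer it from the data. A short Python illustration of both sides:

import torch
import torch.nn.functional as F

idx = torch.tensor([[0, 2, 1], [1, 1, 0]])

# Explicit num_classes: decomposed as eq(idx.unsqueeze(-1), arange(C)).long()
out = torch.vmap(lambda row: F.one_hot(row, num_classes=3))(idx)
print(out.shape)  # torch.Size([2, 3, 3])

# Implicit inference (num_classes defaults to -1) is rejected under vmap
try:
    torch.vmap(lambda row: F.one_hot(row))(idx)
except RuntimeError as err:
    print("expected:", err)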
@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
}
}

auto shape = self.sym_sizes().vec();
auto shape = self.sizes().vec();

// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.sym_numel() == 0) {
if (self.numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
shape.push_back(num_classes);
return at::empty(shape, self.options());
}
}

@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
}
}

shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
shape.push_back(num_classes);
Tensor ret = at::zeros(shape, self.options());
ret.scatter_(-1, self.unsqueeze(-1), 1);
return ret;
}
@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
|
||||
} else if (dtype == ScalarType::Half) {
|
||||
[&]() {
|
||||
using scalar_t =
|
||||
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
|
||||
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
|
||||
const auto exp = exp_scalar.to<scalar_t>();
|
||||
using Vec = Vectorized<scalar_t>;
|
||||
cpu_kernel_vec(iter,
|
||||
|
||||
@ -1230,205 +1230,8 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
|
||||
);
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_tunable_scaled_gemm_rocm(
|
||||
cublasCommonArgs& args,
|
||||
const Tensor& mat1, const Tensor& mat2,
|
||||
const Tensor& scale_a, const Tensor& scale_b,
|
||||
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
const at::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#ifdef USE_ROCM
|
||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
|
||||
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
}
|
||||
AT_DISPATCH_V2(out_dtype, "_tunable_scaled_gemm", AT_WRAP([&] {
|
||||
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
|
||||
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
|
||||
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
|
||||
params.transa = args.transa;
|
||||
params.transb = args.transb;
|
||||
params.m = args.m;
|
||||
params.n = args.n;
|
||||
params.k = args.k;
|
||||
params.a = args.mata->data_ptr();
|
||||
params.a_scale_ptr = args.scale_mata_ptr;
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.lda = args.lda;
|
||||
params.a_dtype = args.mata->scalar_type();
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.a_scaling_type = args.scaling_mata_type.value();
|
||||
params.b = args.matb->data_ptr();
|
||||
params.b_scale_ptr = args.scale_matb_ptr;
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.ldb = args.ldb;
|
||||
params.b_dtype = args.matb->scalar_type();
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.b_scaling_type = args.scaling_matb_type.value();
|
||||
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
|
||||
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype) ? at::ScalarType::Half : out_dtype;
|
||||
params.c = args.result->data_ptr();
|
||||
params.c_scale_ptr = args.scale_result_ptr;
|
||||
params.ldc = args.result_ld;
|
||||
params.c_dtype = out_dtype;
|
||||
params.use_fast_accum = use_fast_accum;
|
||||
if (transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else if (!transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (!transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}),
|
||||
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
#undef TUNABLE_DISPATCH
|
||||
return out;
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_gemm_rocm only callable on ROCM devices");
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_scaled_gemm(
|
||||
const Tensor& mat1, const Tensor& mat2,
|
||||
const Tensor& scale_a, const Tensor& scale_b,
|
||||
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
|
||||
// ROCM enables the TunableOp path only
|
||||
// but can fallback to at::cuda::blas::scaled_gemm
|
||||
#ifdef USE_ROCM
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
bool tunable_op_enabled = tuning_ctx->IsTunableOpEnabled();
|
||||
#else
|
||||
bool tunable_op_enabled = false;
|
||||
#endif
|
||||
if (tunable_op_enabled) {
|
||||
// Only available on ROCM
|
||||
return _tunable_scaled_gemm_rocm(
|
||||
args,
|
||||
mat1, mat2,
|
||||
scale_a, scale_b,
|
||||
scaling_choice_a, scaling_choice_b,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out_dtype_,
|
||||
out);
|
||||
}
|
||||
else
|
||||
{
|
||||
at::cuda::blas::scaled_gemm(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
args.mata->data_ptr(),
|
||||
args.scale_mata_ptr,
|
||||
args.lda,
|
||||
args.mata->scalar_type(),
|
||||
args.scale_mata_dtype.value(),
|
||||
args.scaling_mata_type.value(),
|
||||
args.matb->data_ptr(),
|
||||
args.scale_matb_ptr,
|
||||
args.ldb,
|
||||
args.matb->scalar_type(),
|
||||
args.scale_matb_dtype.value(),
|
||||
args.scaling_matb_type.value(),
|
||||
bias ? bias->data_ptr(): nullptr,
|
||||
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
|
||||
args.result->data_ptr(),
|
||||
args.scale_result_ptr,
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
use_fast_accum);
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// NOTE(slayton58): This is defined as part of the _v2 code (way) below - declare the signature here
|
||||
// to help cleanup v1 call structure.
|
||||
Tensor&
|
||||
_scaled_rowwise_rowwise(
|
||||
const Tensor&, const Tensor&,
|
||||
const Tensor&, const Tensor&,
|
||||
const std::optional<Tensor>&,
|
||||
const c10::ScalarType,
|
||||
bool,
|
||||
Tensor&);
|
||||
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output matrices
|
||||
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
|
||||
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
|
||||
@ -1470,10 +1273,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
// by decreasing priority. We prefer "simpler" schemes as they are supported
|
||||
// more broadly (more GPU archs, more CUDA versions) and because they are more
|
||||
// efficient. This tends to matter only for small matmuls (e.g., 1x1x128).
|
||||
|
||||
// List of supported BlockWise pairs for FP8:
|
||||
// https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
|
||||
|
||||
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
|
||||
{
|
||||
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
|
||||
@ -1506,7 +1305,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type());
|
||||
#ifndef USE_ROCM
|
||||
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
|
||||
TORCH_CHECK_VALUE(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
|
||||
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
|
||||
"Multiplication of two Float8_e5m2 matrices is not supported");
|
||||
#endif
|
||||
if (use_fast_accum) {
|
||||
@ -1572,44 +1371,41 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
|
||||
// NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
|
||||
// and only for compute capability 9.0+. In other cases we use CUTLASS.
|
||||
// We are doing row-wise scaling
|
||||
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
|
||||
#ifndef USE_ROCM
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
|
||||
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|
||||
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) {
|
||||
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
return _scaled_rowwise_rowwise(
|
||||
mat1,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
bias,
|
||||
out.scalar_type(),
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
// We are doing row-wise scaling
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
|
||||
&& ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
|
||||
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|
||||
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) {
|
||||
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
at::cuda::detail::f8f8bf16_rowwise(
|
||||
mat1,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
return out;
|
||||
}
|
||||
#else
|
||||
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
|
||||
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
|
||||
Tensor b = mat2;
|
||||
if (_scaled_mm_is_fnuz()) {
|
||||
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fnuz,
|
||||
"Expected b.dtype() == at::kFloat8_e4m3fnuz, got: ", b.dtype());
|
||||
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fn,
|
||||
"Expected b.dtype() == at::kFloat8_e4m3fn, got: ", b.dtype());
|
||||
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
|
||||
}
|
||||
// Until more than bf16 is supported.
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16,
|
||||
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
|
||||
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
|
||||
#endif
|
||||
}
|
||||
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
|
||||
#ifdef USE_ROCM
|
||||
#if ROCM_VERSION >= 70000
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
|
||||
TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
|
||||
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
|
||||
|
||||
int packed_factor = 1;
|
||||
@ -1618,20 +1414,163 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
// effectively packing two elements into one byte.
|
||||
packed_factor = 2;
|
||||
}
|
||||
TORCH_CHECK_VALUE(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
|
||||
TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
|
||||
mat2.size(1) % 16 == 0,
|
||||
"M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling");
|
||||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
|
||||
#ifdef USE_ROCM
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
|
||||
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
}
|
||||
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
|
||||
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
|
||||
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
|
||||
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
|
||||
params.transa = args.transa;
|
||||
params.transb = args.transb;
|
||||
params.m = args.m;
|
||||
params.n = args.n;
|
||||
params.k = args.k;
|
||||
params.a = args.mata->data_ptr();
|
||||
params.a_scale_ptr = args.scale_mata_ptr;
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.lda = args.lda;
|
||||
params.a_dtype = args.mata->scalar_type();
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.a_scaling_type = args.scaling_mata_type.value();
|
||||
params.b = args.matb->data_ptr();
|
||||
params.b_scale_ptr = args.scale_matb_ptr;
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.ldb = args.ldb;
|
||||
params.b_dtype = args.matb->scalar_type();
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.b_scaling_type = args.scaling_matb_type.value();
|
||||
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
|
||||
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
|
||||
params.c = args.result->data_ptr();
|
||||
params.c_scale_ptr = args.scale_result_ptr;
|
||||
params.ldc = args.result_ld;
|
||||
params.c_dtype = out_dtype_;
|
||||
params.use_fast_accum = use_fast_accum;
|
||||
if (transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else if (!transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (!transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}),
|
||||
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
#undef TUNABLE_DISPATCH
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
at::cuda::blas::scaled_gemm(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
args.mata->data_ptr(),
|
||||
args.scale_mata_ptr,
|
||||
args.lda,
|
||||
args.mata->scalar_type(),
|
||||
args.scale_mata_dtype.value(),
|
||||
args.scaling_mata_type.value(),
|
||||
args.matb->data_ptr(),
|
||||
args.scale_matb_ptr,
|
||||
args.ldb,
|
||||
args.matb->scalar_type(),
|
||||
args.scale_matb_dtype.value(),
|
||||
args.scaling_matb_type.value(),
|
||||
bias ? bias->data_ptr(): nullptr,
|
||||
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
|
||||
args.result->data_ptr(),
|
||||
args.scale_result_ptr,
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
use_fast_accum);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -1971,6 +1910,159 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
|
||||
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
Tensor&
|
||||
_cutlass_scaled_gemm(
|
||||
const Tensor& mat1, const Tensor& mat2,
|
||||
const Tensor& scale_a, const Tensor& scale_b,
|
||||
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
|
||||
#ifdef USE_ROCM
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
|
||||
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
}
|
||||
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
|
||||
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
|
||||
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
|
||||
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
|
||||
params.transa = args.transa;
|
||||
params.transb = args.transb;
|
||||
params.m = args.m;
|
||||
params.n = args.n;
|
||||
params.k = args.k;
|
||||
params.a = args.mata->data_ptr();
|
||||
params.a_scale_ptr = args.scale_mata_ptr;
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.lda = args.lda;
|
||||
params.a_dtype = args.mata->scalar_type();
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.a_scaling_type = args.scaling_mata_type.value();
|
||||
params.b = args.matb->data_ptr();
|
||||
params.b_scale_ptr = args.scale_matb_ptr;
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.ldb = args.ldb;
|
||||
params.b_dtype = args.matb->scalar_type();
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.b_scaling_type = args.scaling_matb_type.value();
|
||||
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
|
||||
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
|
||||
params.c = args.result->data_ptr();
|
||||
params.c_scale_ptr = args.scale_result_ptr;
|
||||
params.ldc = args.result_ld;
|
||||
params.c_dtype = out_dtype_;
|
||||
params.use_fast_accum = use_fast_accum;
|
||||
if (transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else if (!transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (!transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}),
|
||||
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
#undef TUNABLE_DISPATCH
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
at::cuda::blas::scaled_gemm(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
args.mata->data_ptr(),
|
||||
args.scale_mata_ptr,
|
||||
args.lda,
|
||||
args.mata->scalar_type(),
|
||||
args.scale_mata_dtype.value(),
|
||||
args.scaling_mata_type.value(),
|
||||
args.matb->data_ptr(),
|
||||
args.scale_matb_ptr,
|
||||
args.ldb,
|
||||
args.matb->scalar_type(),
|
||||
args.scale_matb_dtype.value(),
|
||||
args.scaling_matb_type.value(),
|
||||
bias ? bias->data_ptr(): nullptr,
|
||||
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
|
||||
args.result->data_ptr(),
|
||||
args.scale_result_ptr,
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
use_fast_accum);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_scaled_tensorwise_tensorwise(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
@ -1990,7 +2082,7 @@ _scaled_tensorwise_tensorwise(
|
||||
auto scaling_choice_a = ScalingType::TensorWise;
|
||||
auto scaling_choice_b = ScalingType::TensorWise;
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -2026,7 +2118,7 @@ _scaled_rowwise_rowwise(
|
||||
if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
|
||||
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|
||||
|| (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
|
||||
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
at::cuda::detail::f8f8bf16_rowwise(
|
||||
mat_a,
|
||||
mat_b,
|
||||
@ -2052,38 +2144,11 @@ _scaled_rowwise_rowwise(
|
||||
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
|
||||
#endif
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
|
||||
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
|
||||
// and strides become somewhat meaningless
|
||||
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
|
||||
if (scale_type == ScalingType::BlockWise1x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
|
||||
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
|
||||
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
} else if (scale_type == ScalingType::BlockWise128x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(
|
||||
scale,
|
||||
0,
|
||||
ceil_div<int64_t>(t.size(0), 128),
|
||||
ceil_div<int64_t>(t.size(1), 128)),
|
||||
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
TORCH_CHECK(check_size_stride(
|
||||
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
|
||||
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
}
|
||||
}
|
||||
|
||||
Tensor&
_scaled_block1x128_block1x128(
const Tensor& mat_a, const Tensor& mat_b,
@ -2101,14 +2166,15 @@ _scaled_block1x128_block1x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())

TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));

auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;

// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);

_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

return out;
}
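
// Worked example (illustrative only): for a hypothetical mat_b of shape [512, 128], the
// checks above require scale_b to be a float32 tensor of shape [ceil(512/128), 128] = [4, 128],
// stored row-major, i.e. stride(1) == 1 and stride(0) == scale_b.size(1) == 128.
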
@ -2123,8 +2189,6 @@ _scaled_block128x128_block1x128(
|
||||
Tensor& out) {
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
@ -2132,14 +2196,15 @@ _scaled_block128x128_block1x128(
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1),
|
||||
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise128x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -2161,14 +2226,15 @@ _scaled_block1x128_block128x128(
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(1) == scale_b.size(0),
|
||||
"expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1));
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise128x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -2222,7 +2288,7 @@ _scaled_mxfp8_mxfp8(
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -2259,7 +2325,7 @@ _scaled_nvfp4_nvfp4(
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x16;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x16;
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
}
|
||||
|
||||
|
||||
@ -2508,9 +2574,7 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
@ -2521,16 +2585,6 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs require offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
@ -2615,9 +2669,6 @@ _f8_f8_bf16_rowwise_grouped_mm(
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
@ -2717,15 +2768,11 @@ _scaled_grouped_mm_cuda(
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
@ -2742,140 +2789,6 @@ _scaled_grouped_mm_cuda(
|
||||
out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
|
||||
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
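
// Sketch of the dispatch pattern used above (standalone and simplified; the acceptance
// rules, dtype tags and entry names below are placeholders, not the real check_* functions):
// walk a fixed (name, acceptance_fn, implementation) table and take the first match.
#include <array>

enum class SketchImpl { NONE, ROWWISE_ROWWISE, MXFP8_MXFP8 };
using sketch_accept_fn = bool (*)(int dtype_a, int dtype_b);

inline bool accept_rowwise_sketch(int a, int b) { return a == 0 && b == 0; }  // placeholder rule
inline bool accept_mxfp8_sketch(int a, int b) { return a == 1 && b == 1; }    // placeholder rule

struct SketchEntry {
  const char* name;
  sketch_accept_fn accept;
  SketchImpl impl;
};

static constexpr std::array<SketchEntry, 2> sketch_dispatch = {{
    {"rowwise_rowwise", accept_rowwise_sketch, SketchImpl::ROWWISE_ROWWISE},
    {"mxfp8_mxfp8", accept_mxfp8_sketch, SketchImpl::MXFP8_MXFP8},
}};

inline SketchImpl pick_impl_sketch(int dtype_a, int dtype_b) {
  for (const auto& entry : sketch_dispatch) {
    if (entry.accept(dtype_a, dtype_b)) {
      return entry.impl;  // first matching recipe wins, as in the loop above
    }
  }
  return SketchImpl::NONE;  // caller is expected to raise, mirroring the TORCH_CHECK_VALUE above
}
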
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
|
||||
@ -856,13 +856,9 @@ struct type_specialized_kernel_launcher {
|
||||
out_calc_t output_offset_calculator,
|
||||
loader_t loader,
|
||||
storer_t storer) {
|
||||
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
|
||||
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
|
||||
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
|
||||
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
|
||||
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
|
||||
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
|
||||
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
|
||||
if (ret_t == rt_binary_specializations[arg_index][0] &&
|
||||
arg0_t == rt_binary_specializations[arg_index][1] &&
|
||||
arg1_t == rt_binary_specializations[arg_index][2])
|
||||
launch_vectorized_templated_kernel<
|
||||
func_t,
|
||||
array_t,
|
||||
@ -870,9 +866,12 @@ struct type_specialized_kernel_launcher {
|
||||
out_calc_t,
|
||||
loader_t,
|
||||
storer_t,
|
||||
cret_t,
|
||||
carg0_t,
|
||||
carg1_t>(
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][0]>::t),
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][1]>::t),
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][2]>::t)>(
|
||||
numel,
|
||||
f,
|
||||
data,
|
||||
@ -880,7 +879,6 @@ struct type_specialized_kernel_launcher {
|
||||
output_offset_calculator,
|
||||
loader,
|
||||
storer);
|
||||
}
|
||||
}
|
||||
};
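
// Standalone sketch (simplified; the dtype tags and the 16-bit storage type are placeholders):
// the launcher above compares a runtime dtype tag against each compile-time specialization
// and, on a match, instantiates the kernel template with the corresponding C++ types.
#include <cstdio>

enum class SketchDType { Float, Half };

template <typename T>
void run_kernel_sketch(int numel) {
  std::printf("specialized kernel over %d elements of a %zu-byte type\n", numel, sizeof(T));
}

inline void dispatch_sketch(SketchDType runtime_dtype, int numel) {
  // each branch plays the role of one rt_binary_specializations entry above
  switch (runtime_dtype) {
    case SketchDType::Float:
      run_kernel_sketch<float>(numel);
      break;
    case SketchDType::Half:
      run_kernel_sketch<unsigned short>(numel);  // placeholder 16-bit storage type
      break;
  }
}
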
|
||||
|
||||
|
||||
@ -655,14 +655,8 @@ struct ReduceOp {
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
|
||||
// matching Triton, etc.
|
||||
// todo for AMD
|
||||
#ifdef USE_ROCM
|
||||
|
||||
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
||||
#else
|
||||
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
arg_t other = ops.warp_shfl_down(value[i], offset);
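
// Illustrative note (not part of the change): for dim_x = 32 the two loop forms above
// visit the same power-of-two offsets, only in opposite orders:
//   increasing (ROCm branch): offset = 1, 2, 4, 8, 16
//   decreasing (CUDA branch): offset = 16, 8, 4, 2, 1
// Both perform the same shfl_down butterfly reduction; only the accumulation order, and
// therefore the floating-point rounding, differs, which is what the earlier comment means
// by matching Triton's numerics.
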
|
||||
|
||||
@ -77,8 +77,8 @@ struct nansum_functor_complex {
|
||||
#if AT_USE_JITERATOR()
|
||||
void operator()(TensorIterator& iter) {
|
||||
std::string func = jiterator_stringify(
|
||||
arg_t combine(arg_t a, arg_t b) {
|
||||
return a + (std::isnan(b) ? arg_t{0.} : b);
|
||||
arg_t combine(arg_t a, scalar_t b) {
|
||||
return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
|
||||
}
|
||||
);
|
||||
jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(
|
||||
|
||||
@ -464,7 +464,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
}
|
||||
#endif
|
||||
int32_t trailingSize;
|
||||
int nDimsLocal = nDims;
|
||||
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
|
||||
if (isInOutAligned) {
|
||||
// in this case we can and should flatten the tensors after the cat dim
|
||||
@ -478,7 +477,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
// and divide all strides except last by elems_per_vec (last stride is 1 always)
|
||||
// for input, we will fix up the sizes and strides in the kernel directly
|
||||
kernelOutputParam = outputParam;
|
||||
nDimsLocal = dimension + 1;
|
||||
nDims = dimension + 1;
|
||||
constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
|
||||
auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
|
||||
kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
|
||||
@ -495,7 +494,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
cat_dim = nDimsLocal - cat_dim;
|
||||
cat_dim = nDims - cat_dim;
|
||||
break;
|
||||
default:
|
||||
cat_dim--;
|
||||
@ -526,7 +525,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
|
||||
}\
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
switch (nDimsLocal) {
|
||||
switch (nDims) {
|
||||
case 1:
|
||||
HANDLE_CASE(1);
|
||||
break;
|
||||
|
||||
@ -21,15 +21,9 @@ namespace {
|
||||
struct offset_t {
|
||||
int stride;
|
||||
int begin;
|
||||
__device__ int operator[](int i) const {
|
||||
__device__ int operator[](int i) {
|
||||
return stride * (begin + i);
|
||||
}
|
||||
#if CCCL_VERSION >= 3001000
|
||||
__device__ offset_t& operator+=(int i) {
|
||||
begin += i;
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
// Segmented sort by full sort algorithm:
|
||||
// Say we are sorting a (2, 3) tensor. We have in flattened form:
|
||||
|
||||
@ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// Helper function to compute output pixel range that can contribute to input pixel
|
||||
template <typename accscalar_t>
|
||||
__device__ __forceinline__ void compute_output_range(
|
||||
int input_pos,
|
||||
accscalar_t scale,
|
||||
int output_size,
|
||||
bool align_corners,
|
||||
int& min_output,
|
||||
int& max_output) {
|
||||
accscalar_t lo, hi;
|
||||
if (align_corners) {
|
||||
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
|
||||
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
|
||||
} else {
|
||||
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
}
|
||||
min_output = max(0, static_cast<int>(ceil(lo)));
|
||||
max_output = min(output_size - 1, static_cast<int>(floor(hi)));
|
||||
}
|
||||
#endif
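
// Worked example (illustrative only; assumes align_corners == false and scale == 0.5,
// i.e. a 2x upsample): for input row input_pos = 3 the helper above computes
//   lo = (3 - 0.5) / 0.5 - 0.5 = 4.5  ->  min_output = ceil(4.5)  = 5
//   hi = (3 + 1.5) / 0.5 - 0.5 = 8.5  ->  max_output = floor(8.5) = 8
// so only output rows 5..8 (clamped to [0, output_size - 1]) can receive gradient
// contributions from input row 3 in the backward loop below.
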
|
||||
|
||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_LAUNCH_BOUNDS_1(1024)
|
||||
@ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
const bool align_corners,
|
||||
scalar_t* __restrict__ idata,
|
||||
const scalar_t* __restrict__ odata) {
|
||||
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
#ifdef USE_ROCM
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
// Decode input pixel coordinates
|
||||
size_t index_temp = index;
|
||||
const int w1 = index_temp % width1;
|
||||
index_temp /= width1;
|
||||
const int h1 = index_temp % height1;
|
||||
const size_t nc_idx = index_temp / height1;
|
||||
|
||||
accscalar_t grad_sum = 0;
|
||||
|
||||
// Find range of output pixels that could interpolate from this input pixel
|
||||
int h2_min, h2_max, w2_min, w2_max;
|
||||
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
|
||||
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
|
||||
|
||||
// Iterate over potential output pixels
|
||||
for (int h2 = h2_min; h2 <= h2_max; h2++) {
|
||||
for (int w2 = w2_min; w2 <= w2_max; w2++) {
|
||||
// Compute source coordinates for this output pixel
|
||||
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rheight, h2, align_corners, /*cubic=*/false);
|
||||
const int h1_base = (int)h1r;
|
||||
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
|
||||
const accscalar_t h1lambda = h1r - h1_base;
|
||||
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
|
||||
|
||||
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rwidth, w2, align_corners, /*cubic=*/false);
|
||||
const int w1_base = (int)w1r;
|
||||
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
|
||||
const accscalar_t w1lambda = w1r - w1_base;
|
||||
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
|
||||
|
||||
// Check if our input pixel participates in this interpolation and accumulate all weights
|
||||
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
|
||||
// to the same pixel, so we need to accumulate weights from all matching positions
|
||||
accscalar_t weight = 0;
|
||||
|
||||
// Check all four interpolation positions and accumulate weights
|
||||
if (h1 == h1_base && w1 == w1_base) {
|
||||
weight += h0lambda * w0lambda; // top-left
|
||||
}
|
||||
if (h1 == h1_base && w1 == w1_base + w1p) {
|
||||
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base) {
|
||||
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
|
||||
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
|
||||
}
|
||||
|
||||
if (weight > 0) {
|
||||
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
|
||||
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write accumulated gradient (no atomics needed)
|
||||
idata[index] = static_cast<scalar_t>(grad_sum);
|
||||
}
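
// Boundary example (illustrative only): at the right edge w1_base == width1 - 1, so
// w1p == 0 and the "top-left" and "top-right" taps collapse onto the same input pixel;
// the weight terms h0lambda * w0lambda and h0lambda * w1lambda are then accumulated
// together (summing to h0lambda), which is why the code above adds up every matching
// position instead of assuming the four taps are distinct.
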
|
||||
#else
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
size_t index_temp = index;
|
||||
@ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
@ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
// threads are not covering the whole input tensor.
|
||||
grad_input.zero_();
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
const int num_threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
constexpr bool use_input = true;
|
||||
#else
|
||||
constexpr bool use_input = false;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half, at::ScalarType::BFloat16,
|
||||
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
|
||||
@ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
|
||||
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
|
||||
input_height,
|
||||
@ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
|
||||
|
||||
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
|
||||
num_threads,
|
||||
|
||||
@ -662,7 +662,7 @@ void svd_cusolver(const Tensor& A,
|
||||
const auto n = A.size(-1);
|
||||
const auto k = std::min(m, n);
|
||||
|
||||
static constexpr const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
|
||||
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
|
||||
|
||||
// The default heuristic is to use gesvdj driver
|
||||
#ifdef USE_ROCM
|
||||
|
||||
@ -466,11 +466,7 @@ struct ReduceJitOp {
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef USE_ROCM
|
||||
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
||||
#else
|
||||
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
arg_t other = reducer::warp_shfl_down(value[i], offset);
|
||||
|
||||
@ -3,9 +3,6 @@
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/core/SymBool.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
|
||||
|
||||
namespace at::native {
|
||||
|
||||
@ -22,30 +19,28 @@ C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint(
|
||||
"Expected normalized_shape to be at least 1-dimensional, i.e., ",
|
||||
"containing at least one element, but got normalized_shape = ",
|
||||
normalized_shape);
|
||||
if (weight.defined()) {
|
||||
TORCH_SYM_CHECK(
|
||||
sym_equals(weight.sym_sizes(), normalized_shape),
|
||||
"Expected weight to be of same shape as normalized_shape, but got ",
|
||||
"weight of shape ",
|
||||
weight.sym_sizes(),
|
||||
" and normalized_shape = ",
|
||||
normalized_shape);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!weight.defined() || weight.sym_sizes().equals(normalized_shape),
|
||||
"Expected weight to be of same shape as normalized_shape, but got ",
|
||||
"weight of shape ",
|
||||
weight.sym_sizes(),
|
||||
" and normalized_shape = ",
|
||||
normalized_shape);
|
||||
|
||||
const auto input_ndim = input.dim();
|
||||
const auto input_shape = input.sym_sizes();
|
||||
TORCH_CHECK_VALUE(
|
||||
input_ndim >= normalized_ndim,
|
||||
"Input tensor must have at least ", normalized_ndim, " dimensions, but got ", input_ndim);
|
||||
|
||||
auto expect_input_shape_msg = c10::str(
|
||||
"Given normalized_shape=", normalized_shape,
|
||||
", expected input with shape [*", c10::Join(", ", normalized_shape),
|
||||
"], but got input of size", input_shape);
|
||||
|
||||
TORCH_SYM_CHECK(
|
||||
sym_equals(input_shape.slice(input_ndim - normalized_ndim), normalized_shape),
|
||||
expect_input_shape_msg);
|
||||
if (input_ndim < normalized_ndim ||
|
||||
!input_shape.slice(input_ndim - normalized_ndim)
|
||||
.equals(normalized_shape)) {
|
||||
std::stringstream ss;
|
||||
ss << "Given normalized_shape=" << normalized_shape
|
||||
<< ", expected input with shape [*";
|
||||
for (auto size : normalized_shape) {
|
||||
ss << ", " << size;
|
||||
}
|
||||
ss << "], but got input of size" << input_shape;
|
||||
TORCH_CHECK(false, ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(
|
||||
|
||||
@ -99,9 +99,6 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape);
|
||||
MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
|
||||
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
|
||||
|
||||
// Determines whether a tensor is too large to use MPSGraph
|
||||
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI = true);
|
||||
|
||||
static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
|
||||
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
}
|
||||
|
||||
@ -439,22 +439,6 @@ static void check_mps_shape(MPSShape* shape) {
|
||||
}
|
||||
}
|
||||
|
||||
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI) {
|
||||
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
||||
if ((!tensor.is_contiguous() || tensor.storage_offset()) && useMPSStridedAPI && is_macOS_15_0_or_newer) {
|
||||
auto storage_numel = tensor.storage().nbytes() / tensor.element_size() - tensor.storage_offset();
|
||||
if (storage_numel > std::numeric_limits<int32_t>::max()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (auto size : tensor.sizes()) {
|
||||
if (size > std::numeric_limits<int32_t>::max()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
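
// Illustrative note (numbers are examples, not from the diff): a contiguous MPS tensor with
// a dimension of, say, 3'000'000'000 elements exceeds int32_t max (2'147'483'647) and is
// routed to the large-tensor path, as is a strided view on macOS 15+ whose backing storage
// holds more elements than int32_t can index.
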
|
||||
|
||||
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) {
|
||||
id<MTLBuffer> srcBuf = getMTLBufferStorage(t);
|
||||
|
||||
|
||||
@ -249,7 +249,7 @@ kernel void embedding_bag(
|
||||
|
||||
template <EmbeddingBagMode M, typename T>
|
||||
struct MaybeDivBagSize {
|
||||
inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> /*bag_size*/) {
|
||||
inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
@ -1,18 +0,0 @@
|
||||
#pragma once
|
||||
#include <c10/metal/common.h>
|
||||
|
||||
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
|
||||
struct CatLargeSharedParams {
|
||||
int32_t ndim;
|
||||
int32_t cat_dim;
|
||||
::c10::metal::array<idx_type_t, N> output_strides;
|
||||
::c10::metal::array<idx_type_t, N> output_sizes;
|
||||
};
|
||||
|
||||
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
|
||||
struct CatLargeInputParams {
|
||||
idx_type_t cat_dim_offset;
|
||||
idx_type_t input_element_offset;
|
||||
::c10::metal::array<idx_type_t, N> input_strides;
|
||||
::c10::metal::array<idx_type_t, N> input_sizes;
|
||||
};
|
||||
@ -1,82 +0,0 @@
|
||||
#include <ATen/native/mps/kernels/Shape.h>
|
||||
#include <c10/metal/utils.h>
|
||||
#include <metal_array>
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
using namespace c10::metal;
|
||||
|
||||
template <typename T_in, typename T_out>
|
||||
kernel void cat_large(
|
||||
constant T_in* input [[buffer(0)]],
|
||||
device T_out* output [[buffer(1)]],
|
||||
constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
|
||||
constant CatLargeInputParams<>& input_params [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
auto ndim = shared_params.ndim;
|
||||
auto cat_dim = shared_params.cat_dim;
|
||||
constant auto& output_strides = shared_params.output_strides;
|
||||
constant auto& output_sizes = shared_params.output_sizes;
|
||||
|
||||
auto cat_dim_offset = input_params.cat_dim_offset;
|
||||
auto input_element_offset = input_params.input_element_offset;
|
||||
constant auto& input_strides = input_params.input_strides;
|
||||
constant auto& input_sizes = input_params.input_sizes;
|
||||
|
||||
auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
|
||||
int64_t input_offset = 0;
|
||||
int64_t output_offset = 0;
|
||||
|
||||
for (auto dim = ndim - 1; dim >= 0; dim--) {
|
||||
auto dim_size = input_sizes[dim];
|
||||
auto input_dim_idx = input_element_idx % dim_size;
|
||||
auto output_dim_idx =
|
||||
input_dim_idx + ((dim == cat_dim) ? cat_dim_offset : 0);
|
||||
|
||||
input_offset += input_strides[dim] * input_dim_idx;
|
||||
output_offset += output_strides[dim] * output_dim_idx;
|
||||
|
||||
input_element_idx = input_element_idx / dim_size;
|
||||
}
|
||||
|
||||
output[output_offset] = static_cast<T_out>(input[input_offset]);
|
||||
}
|
||||
|
||||
#define REGISTER_CAT_LARGE_OP(T_in, T_out) \
|
||||
template [[host_name("cat_large_" #T_in "_" #T_out)]] \
|
||||
kernel void cat_large<T_in, T_out>( \
|
||||
constant T_in * input [[buffer(0)]], \
|
||||
device T_out * output [[buffer(1)]], \
|
||||
constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
|
||||
constant CatLargeInputParams<> & input_params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
|
||||
REGISTER_CAT_LARGE_OP(float, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(half, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(bfloat, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(int, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(uint, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(long, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(ulong, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(short, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(ushort, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(char, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(uchar, T_out); \
|
||||
REGISTER_CAT_LARGE_OP(bool, T_out);
|
||||
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
|
||||
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
|
||||
|
||||
REGISTER_CAT_LARGE_OP(float2, float2);
|
||||
REGISTER_CAT_LARGE_OP(half2, half2);
|
||||
@ -512,7 +512,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
|
||||
}
|
||||
|
||||
static MPSGraphTensor* normcdf(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
// (1.0f + erf(x*SQRT1_2)) * 0.5f;
|
||||
// (1.0f + erf(x*SQRT1_2)) * 0.5f * x;
|
||||
auto dataType = [inputTensor dataType];
|
||||
const float SQRT1_2 = 0.707106781186547524400844362104849039f;
|
||||
MPSGraphTensor* sqrt1_2 = [mpsGraph constantWithScalar:SQRT1_2 shape:@[ @1 ] dataType:dataType];
|
||||
|
||||
@ -54,10 +54,6 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
|
||||
using namespace mps;
|
||||
using CachedGraph = MPSBinaryCachedGraph;
|
||||
|
||||
if (self.numel() == 0 && other.numel() == 0) {
|
||||
return zeros({}, self.options());
|
||||
}
|
||||
|
||||
dot_check(self, other);
|
||||
|
||||
auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
|
||||
|
||||
@ -2,13 +2,9 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/mps/MPSProfiler.h>
|
||||
#include <ATen/native/TensorShape.h>
|
||||
#include <ATen/native/TypeProperties.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/kernels/Shape.h>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -20,13 +16,6 @@
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
#ifndef PYTORCH_JIT_COMPILE_SHADERS
|
||||
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
|
||||
#else
|
||||
#include <ATen/native/mps/Shape_metallib.h>
|
||||
#endif
|
||||
|
||||
namespace mps {
|
||||
|
||||
// Produces a shape with the `dim` dimension set to 0.
|
||||
@ -68,70 +57,6 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
|
||||
")");
|
||||
}
|
||||
}
|
||||
|
||||
// This implementation of cat is used only if one of the inputs or the output is
|
||||
// too large to use MPSGraph.
|
||||
// NOTE: `output` is expected to already have the correct size.
|
||||
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
|
||||
CatLargeSharedParams shared_params;
|
||||
|
||||
shared_params.ndim = output.dim();
|
||||
shared_params.cat_dim = dimension;
|
||||
|
||||
for (const auto dim : c10::irange(output.dim())) {
|
||||
shared_params.output_strides[dim] = output.stride(dim);
|
||||
shared_params.output_sizes[dim] = output.size(dim);
|
||||
}
|
||||
|
||||
int64_t cat_dim_offset = 0;
|
||||
size_t input_idx = 0;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
// Launch a separate kernels for each input. This will produce some overhead,
|
||||
// but that should be relatively minimal since at least one of the inputs is
|
||||
// very large. In order to launch only one kernel to process all inputs, we
|
||||
// would have to copy all the input tensor data into a packed buffer, which
|
||||
// would not be ideal.
|
||||
for (const Tensor& input : inputs) {
|
||||
if (input.numel() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Metal can only launch up to MAX_INT threads at one time. If the input has
|
||||
// more than that number of elements, launch multiple kernels with different
|
||||
// offsets into the data.
|
||||
const int64_t max_num_threads = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
|
||||
|
||||
for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
|
||||
auto num_threads = std::min(max_num_threads, numel_remaining);
|
||||
CatLargeInputParams input_params;
|
||||
|
||||
input_params.cat_dim_offset = cat_dim_offset;
|
||||
input_params.input_element_offset = input.numel() - numel_remaining;
|
||||
|
||||
for (const auto dim : c10::irange(input.dim())) {
|
||||
input_params.input_strides[dim] = input.stride(dim);
|
||||
input_params.input_sizes[dim] = input.size(dim);
|
||||
}
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
|
||||
auto pipeline_state = lib.getPipelineStateForFunc(
|
||||
fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
|
||||
getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
|
||||
[computeEncoder setComputePipelineState:pipeline_state];
|
||||
mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
|
||||
mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
|
||||
getMPSProfiler().endProfileKernel(pipeline_state);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
cat_dim_offset += input.size(dimension);
|
||||
input_idx++;
|
||||
}
|
||||
}
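
// Standalone sketch of the chunking performed above (illustrative; the 5e9-element size is
// hypothetical): each oversized input is split into launches of at most INT32_MAX threads,
// with the per-launch element offset mirroring input_element_offset.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

inline void chunk_launch_sketch(std::int64_t numel) {
  const std::int64_t max_threads = std::numeric_limits<std::int32_t>::max();
  for (std::int64_t remaining = numel; remaining > 0; remaining -= max_threads) {
    const std::int64_t num_threads = std::min(max_threads, remaining);
    const std::int64_t element_offset = numel - remaining;  // mirrors input_element_offset above
    std::printf("launch %lld threads starting at element %lld\n",
                static_cast<long long>(num_threads), static_cast<long long>(element_offset));
  }
}
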
|
||||
} // namespace mps
|
||||
|
||||
// topk
|
||||
@ -306,11 +231,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
// Compute size of the result in the cat dimension
|
||||
int64_t cat_dim_size = 0;
|
||||
idx = 0;
|
||||
bool has_large_tensor = false;
|
||||
for (const Tensor& tensor : materialized_inputs) {
|
||||
if (isTooLargeForMPSGraph(tensor)) {
|
||||
has_large_tensor |= true;
|
||||
}
|
||||
if (!should_skip(tensor)) {
|
||||
// TODO: Factor out `check_shape_except_dim`
|
||||
check_shape_except_dim(notSkippedTensor, tensor, dimension, idx);
|
||||
@ -328,12 +249,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
return;
|
||||
}
|
||||
|
||||
has_large_tensor |= isTooLargeForMPSGraph(out);
|
||||
|
||||
if (has_large_tensor) {
|
||||
return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
|
||||
}
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
std::vector<MPSGraphTensor*> inputTensors_;
|
||||
|
||||
@ -4545,7 +4545,6 @@
|
||||
- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
|
||||
dispatch:
|
||||
CPU, CUDA: _cdist_forward
|
||||
MTIA: _cdist_forward_mtia
|
||||
MPS: _cdist_forward_mps
|
||||
autogen: _cdist_forward.out
|
||||
tags: core
|
||||
@ -7183,12 +7182,6 @@
|
||||
CUDA: _scaled_grouped_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_grouped_mm_cuda_v2
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
|
||||
@ -178,30 +178,24 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_b
|
||||
0 & \text{ else }
|
||||
\end{cases}
|
||||
*/
|
||||
bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
|
||||
at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
|
||||
at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
|
||||
at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
|
||||
at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;
|
||||
auto zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max);
|
||||
|
||||
auto zero_point_rounded = _get_rounded_zero_point(zero_point_, quant_min, quant_max);
|
||||
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
|
||||
|
||||
TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
|
||||
TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);
|
||||
|
||||
TORCH_CHECK(X_.sizes() == dY_.sizes(), "`X` and `dY` are not the same size");
|
||||
TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
|
||||
TORCH_CHECK(
|
||||
quant_min <= 0 && quant_max >= 0,
|
||||
"Expecting `quant_min` <= 0 and `quant_max` >= 0");
|
||||
TORCH_CHECK(scale_.dim() == 1, "scale should be a 1-D tensor");
|
||||
TORCH_CHECK(zero_point_.dim() == 1, "zero point should be a 1-D tensor");
|
||||
TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
|
||||
TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
|
||||
TORCH_CHECK(
|
||||
scale_.numel() == zero_point_.numel(),
|
||||
scale.numel() == zero_point.numel(),
|
||||
"scale and zero-point need to have the same dimensions");
|
||||
TORCH_CHECK(
|
||||
scale_.numel() == X_.size(axis),
|
||||
scale.numel() == X.size(axis),
|
||||
"dimensions of scale and zero-point are not consistent with input tensor")
|
||||
|
||||
TORCH_CHECK(
|
||||
@ -210,42 +204,42 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_b
|
||||
"`zero_point` must be between `quant_min` and `quant_max`.");
|
||||
|
||||
TORCH_CHECK(
|
||||
axis >= 0 && axis < X_.dim(),
|
||||
axis >= 0 && axis < X.dim(),
|
||||
"`axis` must be between 0 and number of dimensions of input");
|
||||
|
||||
if (X_.numel() <= 0) {
|
||||
if (X.numel() <= 0) {
|
||||
return std::make_tuple(X, scale, zero_point);
|
||||
}
|
||||
|
||||
auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
|
||||
auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
|
||||
auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
|
||||
auto numDimensions = X_.ndimension();
|
||||
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
|
||||
auto numDimensions = X.ndimension();
|
||||
|
||||
// Create an axis mask for vectorizing and reshaping the scale and zero point tensors
|
||||
// into the same shapes as X along the channel axis.
|
||||
c10::DimVector axis_mask(numDimensions);
|
||||
for (const auto i : c10::irange(numDimensions)) {
|
||||
axis_mask[i] = (i == axis) ? X_.size(axis) : 1;
|
||||
axis_mask[i] = (i == axis) ? X.size(axis) : 1;
|
||||
}
|
||||
auto X_shape = X_.sizes();
|
||||
auto scale_vectorized = scale_.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
|
||||
auto X_shape = X.sizes();
|
||||
auto scale_vectorized = scale.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
|
||||
auto zero_point_vectorized = zero_point_rounded.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
|
||||
|
||||
auto iter = TensorIteratorConfig()
|
||||
.add_output(dX)
|
||||
.add_output(dScale_vec)
|
||||
.add_output(dZeroPoint_vec)
|
||||
.add_input(X_)
|
||||
.add_input(dY_)
|
||||
.add_input(X)
|
||||
.add_input(dY)
|
||||
.add_input(scale_vectorized)
|
||||
.add_input(zero_point_vectorized)
|
||||
.build();
|
||||
|
||||
fake_quant_grad_learnable_channel_stub(
|
||||
X_.device().type(), iter, quant_min, quant_max, grad_factor);
|
||||
X.device().type(), iter, quant_min, quant_max, grad_factor);
|
||||
|
||||
auto numElements = X_.ndimension() - 1;
|
||||
auto numElements = X.ndimension() - 1;
|
||||
|
||||
// Create a collection of axes that include all but the channel axis for
|
||||
// reduction when summing over the dScale and dZeroPoint tensors.
|
||||
|
||||
@ -80,19 +80,3 @@ TEST(XpuGeneratorTest, testMultithreadingGetSetCurrentSeed) {
|
||||
t2.join();
|
||||
EXPECT_EQ(gen1.current_seed(), initial_seed+3);
|
||||
}
|
||||
|
||||
TEST(XpuGeneratorTest, testRNGForking) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
if (!at::xpu::is_available()) return;
|
||||
auto default_gen = at::xpu::detail::getDefaultXPUGenerator();
|
||||
auto current_gen = at::xpu::detail::createXPUGenerator();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(default_gen.mutex());
|
||||
current_gen = default_gen.clone(); // capture the current state of default generator
|
||||
}
|
||||
auto target_value = at::randn({1000}, at::kXPU);
|
||||
// Dramatically alter the internal state of the main generator
|
||||
auto x = at::randn({100000}, at::kXPU);
|
||||
auto forked_value = at::randn({1000}, current_gen, at::kXPU);
|
||||
ASSERT_EQ(target_value.sum().item<double>(), forked_value.sum().item<double>());
|
||||
}
|
||||
|
||||
@ -1,45 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
namespace at {
|
||||
|
||||
struct PhiloxXpuState {
|
||||
PhiloxXpuState() = default;
|
||||
PhiloxXpuState(uint64_t seed, uint64_t offset) {
|
||||
seed_.val = seed;
|
||||
offset_.val = offset;
|
||||
}
|
||||
// for graph capture
|
||||
PhiloxXpuState(
|
||||
int64_t* seed,
|
||||
int64_t* offset_extragraph,
|
||||
uint32_t offset_intragraph) {
|
||||
seed_.ptr = seed;
|
||||
offset_.ptr = offset_extragraph;
|
||||
offset_intragraph_ = offset_intragraph;
|
||||
captured_ = true;
|
||||
}
|
||||
|
||||
union Payload {
|
||||
uint64_t val;
|
||||
int64_t* ptr;
|
||||
};
|
||||
|
||||
Payload seed_{};
|
||||
Payload offset_{};
|
||||
uint32_t offset_intragraph_ = 0;
|
||||
bool captured_ = false;
|
||||
};
|
||||
|
||||
namespace xpu::philox {
|
||||
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxXpuState arg) {
|
||||
if (arg.captured_) {
|
||||
return std::make_tuple(
|
||||
static_cast<uint64_t>(*arg.seed_.ptr),
|
||||
static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
|
||||
} else {
|
||||
return std::make_tuple(arg.seed_.val, arg.offset_.val);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xpu::philox
|
||||
} // namespace at
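
// Note (descriptive only): the Payload union above lets the same state type serve both
// execution modes: in eager mode seed_ and offset_ hold literal values consumed directly,
// while under graph capture they hold pointers whose contents are filled in at replay time,
// with offset_intragraph_ recording the extra offset accumulated inside the captured graph.
// unpack() above picks between the two representations based on captured_.
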
|
||||
@ -1,14 +1,9 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/xpu/XPUGeneratorImpl.h>
|
||||
#include <ATen/xpu/XPUGraphsUtils.h>
|
||||
#include <c10/core/StreamGuard.h>
|
||||
#include <c10/util/CallOnce.h>
|
||||
#include <c10/xpu/XPUFunctions.h>
|
||||
|
||||
constexpr uint64_t PHILOX_ROUND_SIZE = 4;
|
||||
|
||||
namespace at {
|
||||
namespace xpu::detail {
|
||||
namespace {
|
||||
@ -63,82 +58,29 @@ Generator createXPUGenerator(DeviceIndex device) {
|
||||
|
||||
} // namespace xpu::detail
|
||||
|
||||
// Creates a clone of this XPU Generator State.
|
||||
c10::intrusive_ptr<XPUGeneratorState> XPUGeneratorState::clone() {
|
||||
return make_intrusive<XPUGeneratorState>(
|
||||
seed_, philox_offset_per_thread_, offset_intragraph_);
|
||||
}
|
||||
|
||||
// Function to increase the internal offset based on the specified increment.
|
||||
void XPUGeneratorState::increase(uint64_t increment) {
|
||||
increment = ((increment + PHILOX_ROUND_SIZE - 1) / PHILOX_ROUND_SIZE) *
|
||||
PHILOX_ROUND_SIZE;
|
||||
if (at::xpu::currentStreamCaptureStatus() !=
|
||||
at::xpu::CaptureStatus::Executing) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
capturing_,
|
||||
"Attempt to increase offset for a XPU generator not in capture mode.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
|
||||
"Increment causes overflow in the offset value.");
|
||||
offset_intragraph_ += increment;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!capturing_,
|
||||
"Offset increment outside graph capture encountered unexpectedly.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
philox_offset_per_thread_ % 4 == 0,
|
||||
"RNG offset must be a multiple of 4.");
|
||||
philox_offset_per_thread_ += increment;
|
||||
}
|
||||
}
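
// Sketch of the rounding rule applied above (illustrative only): offsets advance in whole
// Philox rounds of PHILOX_ROUND_SIZE (= 4), so any requested increment is first rounded up.
constexpr unsigned long long round_up_to_philox_round_sketch(unsigned long long inc) {
  return ((inc + 4 - 1) / 4) * 4;
}
static_assert(round_up_to_philox_round_sketch(1) == 4, "1 rounds up to one 4-wide round");
static_assert(round_up_to_philox_round_sketch(10) == 12, "10 rounds up to three rounds");
static_assert(round_up_to_philox_round_sketch(8) == 8, "exact multiples are unchanged");
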
|
||||
|
||||
XPUGeneratorImpl::XPUGeneratorImpl(DeviceIndex device_index)
|
||||
: GeneratorImpl{
|
||||
Device(DeviceType::XPU, device_index),
|
||||
DispatchKeySet(c10::DispatchKey::XPU)} {
|
||||
at::xpu::assertNotCapturing("Cannot construct a new XPUGeneratorImpl");
|
||||
state_ = make_intrusive<XPUGeneratorState>();
|
||||
}
|
||||
|
||||
XPUGeneratorImpl::XPUGeneratorImpl(
|
||||
DeviceIndex device_index,
|
||||
intrusive_ptr<XPUGeneratorState> state)
|
||||
: GeneratorImpl{Device(DeviceType::XPU, device_index), DispatchKeySet(c10::DispatchKey::XPU)},
|
||||
state_(std::move(state)) {}
|
||||
DispatchKeySet(c10::DispatchKey::XPU)} {}
|
||||
|
||||
void XPUGeneratorImpl::set_current_seed(uint64_t seed) {
|
||||
if (C10_LIKELY(
|
||||
at::xpu::currentStreamCaptureStatus() ==
|
||||
at::xpu::CaptureStatus::Executing)) {
|
||||
state_->seed_ = seed;
|
||||
state_->philox_offset_per_thread_ = 0;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
state_->seed_ == seed,
|
||||
"XPUGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.");
|
||||
}
|
||||
seed_ = seed;
|
||||
set_philox_offset_per_thread(0);
|
||||
}
|
||||
|
||||
void XPUGeneratorImpl::set_offset(uint64_t offset) {
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::set_offset");
|
||||
set_philox_offset_per_thread(offset);
|
||||
}
|
||||
|
||||
uint64_t XPUGeneratorImpl::get_offset() const {
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::get_offset");
|
||||
return state_->philox_offset_per_thread_;
|
||||
return philox_offset_per_thread_;
|
||||
}
|
||||
|
||||
uint64_t XPUGeneratorImpl::current_seed() const {
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::current_seed");
|
||||
return state_->seed_;
|
||||
return seed_;
|
||||
}
|
||||
|
||||
uint64_t XPUGeneratorImpl::seed() {
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::seed");
|
||||
auto random = c10::detail::getNonDeterministicRandom(true);
|
||||
this->set_current_seed(random);
|
||||
return random;
|
||||
@ -168,65 +110,39 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
|
||||
}
|
||||
|
||||
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
|
||||
at::xpu::assertNotCapturing(
|
||||
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
|
||||
static const size_t seed_size = sizeof(uint64_t);
|
||||
static const size_t offset_size = sizeof(uint64_t);
|
||||
static const size_t total_size = seed_size + offset_size;
|
||||
|
||||
at::detail::check_rng_state(new_state);
|
||||
|
||||
bool no_philox_seed = false;
|
||||
auto new_state_size = new_state.numel();
|
||||
if (new_state_size == total_size - offset_size) {
|
||||
no_philox_seed = true;
|
||||
} else {
|
||||
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
|
||||
}
|
||||
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
|
||||
|
||||
uint64_t input_seed = 0;
|
||||
uint64_t input_seed;
|
||||
auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
|
||||
memcpy(&input_seed, new_rng_state, seed_size);
|
||||
this->set_current_seed(input_seed);
|
||||
uint64_t philox_offset = 0;
|
||||
if (!no_philox_seed) {
|
||||
memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
|
||||
}
|
||||
uint64_t philox_offset;
|
||||
memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
|
||||
this->set_philox_offset_per_thread(philox_offset);
|
||||
}
|
||||
|
||||
void XPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
  TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
  state_->philox_offset_per_thread_ = offset;
  philox_offset_per_thread_ = offset;
}

uint64_t XPUGeneratorImpl::philox_offset_per_thread() const {
  return state_->philox_offset_per_thread_;
}

PhiloxXpuState XPUGeneratorImpl::philox_xpu_state(uint64_t increment) {
  if (at::xpu::currentStreamCaptureStatus() !=
      at::xpu::CaptureStatus::Executing) {
    uint32_t offset = state_->offset_intragraph_;
    state_->increase(increment);
    return PhiloxXpuState(
        state_->seed_extragraph_.data_ptr<int64_t>(),
        state_->offset_extragraph_.data_ptr<int64_t>(),
        offset);
  } else {
    uint64_t offset = state_->philox_offset_per_thread_;
    state_->increase(increment);
    return PhiloxXpuState(state_->seed_, offset);
  }
  return philox_offset_per_thread_;
}

std::pair<uint64_t, uint64_t> XPUGeneratorImpl::philox_engine_inputs(
    uint64_t increment) {
  at::xpu::assertNotCapturing(
      "Refactor this op to use XPUGeneratorImpl::philox_xpu_state. Cannot call XPUGeneratorImpl::philox_engine_inputs");
  uint64_t offset = state_->philox_offset_per_thread_;
  state_->increase(increment);
  return std::make_pair(state_->seed_, offset);
  increment = ((increment + 3) / 4) * 4;
  TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
  uint64_t offset = this->philox_offset_per_thread_;
  this->philox_offset_per_thread_ += increment;
  return std::make_pair(this->seed_, offset);
}

DeviceType XPUGeneratorImpl::device_type() {
@ -238,8 +154,9 @@ std::shared_ptr<XPUGeneratorImpl> XPUGeneratorImpl::clone() const {
}

XPUGeneratorImpl* XPUGeneratorImpl::clone_impl() const {
  at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::clone_impl");
  auto gen = new XPUGeneratorImpl(this->device().index(), state_->clone());
  auto gen = new XPUGeneratorImpl(this->device().index());
  gen->set_current_seed(this->seed_);
  gen->set_philox_offset_per_thread(this->philox_offset_per_thread_);
  return gen;
}
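
philox_xpu_state returns pointer-based state while a stream is being captured and an immediate (seed, offset) pair otherwise. A hedged sketch of the host-side call pattern a random op might use; only philox_xpu_state and the inherited mutex_ come from the code above, while the function name, the increment value, and the ATen includes (elided here) are assumptions:

#include <mutex>

// Reserve Philox rounds under the generator lock and hand the resulting
// state to the kernel; the kernel decides at run time whether to read the
// captured pointers or the immediate seed/offset pair.
void launch_random_kernel(at::XPUGeneratorImpl* gen, uint64_t counter_offset) {
  std::lock_guard<std::mutex> lock(gen->mutex_);
  auto rng_state = gen->philox_xpu_state(counter_offset);
  // ... enqueue the SYCL kernel with rng_state (dispatch elided in this sketch) ...
  (void)rng_state;
}
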
@ -1,43 +1,12 @@
#pragma once

#include <ATen/core/Generator.h>
#include <ATen/core/TensorBase.h>
#include <ATen/xpu/PhiloxXpuState.h>
#include <unordered_set>

namespace at {

namespace xpu {
struct XPUGraph;
}

struct XPUGeneratorState : public c10::intrusive_ptr_target {
  uint64_t seed_;
  uint64_t philox_offset_per_thread_;
  uint32_t offset_intragraph_;
  bool capturing_{};
  at::TensorBase seed_extragraph_{};
  at::TensorBase offset_extragraph_{};

  XPUGeneratorState(
      uint64_t seed = default_rng_seed_val,
      uint64_t philox_offset_per_thread = 0,
      uint32_t offset_intragraph = 0)
      : seed_(seed),
        philox_offset_per_thread_(philox_offset_per_thread),
        offset_intragraph_(offset_intragraph) {}

  void increase(uint64_t increment);

  c10::intrusive_ptr<XPUGeneratorState> clone();
};

struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
  // Constructors
  XPUGeneratorImpl(DeviceIndex device_index = -1);
  XPUGeneratorImpl(
      DeviceIndex device_index,
      c10::intrusive_ptr<XPUGeneratorState> state_);
  ~XPUGeneratorImpl() override = default;

  // XPUGeneratorImpl methods
@ -49,18 +18,15 @@ struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
  uint64_t seed() override;
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;

  void set_philox_offset_per_thread(uint64_t offset);
  uint64_t philox_offset_per_thread() const;

  PhiloxXpuState philox_xpu_state(uint64_t increment);
  // will remove once all ops are refactored to use philox_xpu_state.
  std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
  static c10::DeviceType device_type();

 private:
  XPUGeneratorImpl* clone_impl() const override;
  c10::intrusive_ptr<XPUGeneratorState> state_;
  uint64_t seed_ = default_rng_seed_val;
  uint64_t philox_offset_per_thread_ = 0;
};

namespace xpu::detail {
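
XPUGeneratorState::increase is only declared in the hunk above. A sketch of the semantics it plausibly has, assuming behavior analogous to the CUDA generator state; the 4-alignment mirrors the TORCH_CHECK(offset % 4 == 0) in the .cpp hunk, and the exact body is an assumption, not part of this diff:

// Illustrative only: round the increment up to a multiple of 4 and advance
// whichever offset is active, the intragraph counter while capturing,
// otherwise the regular Philox offset.
void XPUGeneratorState::increase(uint64_t increment) {
  increment = ((increment + 3) / 4) * 4;
  if (capturing_) {
    offset_intragraph_ += static_cast<uint32_t>(increment);
  } else {
    philox_offset_per_thread_ += increment;
  }
}
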
@ -1,22 +0,0 @@
#pragma once

#include <c10/xpu/XPUGraphsC10Utils.h>

namespace at::xpu {

inline CaptureStatus currentStreamCaptureStatus() {
  return c10::xpu::currentStreamCaptureStatusMayInitCtx();
}

inline void assertNotCapturing(const std::string& attempt) {
  auto status = currentStreamCaptureStatus();
  TORCH_CHECK(
      status == CaptureStatus::Executing,
      attempt,
      " during XPU graph capture. If you need this call to be captured, "
      "please file an issue. "
      "Current xpuStreamCaptureStatus: ",
      status);
}

} // namespace at::xpu
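
The header above exposes assertNotCapturing as the guard that the generator methods rely on. A hypothetical call site, with the op name made up purely for illustration:

// Guard an operation whose RNG bookkeeping cannot be replayed from inside
// an XPU graph; the TORCH_CHECK fires if the current stream is being captured.
void nondeterministic_fallback_op() {
  at::xpu::assertNotCapturing("Cannot call nondeterministic_fallback_op");
  // ... non-capturable work ...
}
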
@ -15,8 +15,6 @@ flaky_models = {
    "moondream",  # discovered in https://github.com/pytorch/pytorch/pull/159291
    # discovered in https://github.com/pytorch/pytorch/issues/161419. Its not flaky but really hard to repro, so skipping it
    "mobilenetv3_large_100",
    # https://github.com/pytorch/pytorch/issues/163670
    "vision_maskrcnn",
}

@ -63,10 +61,6 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
        "swsl_resnext101_32x16d",
        "torchrec_dlrm",
        "vgg16",
        "BERT_pytorch",
        "coat_lite_mini",
        "mobilenet_v3_large",
        "vision_maskrcnn",
        # LLM
        "meta-llama/Llama-3.2-1B",
        "google/gemma-2-2b",
@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -290,7 +290,7 @@ vgg16,pass,0
vision_maskrcnn,pass,21
vision_maskrcnn,pass,20

@ -206,7 +206,7 @@ vgg16,pass,6
vision_maskrcnn,pass,40
vision_maskrcnn,pass,39
@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -130,7 +130,7 @@ maml_omniglot,pass,0
microbench_unbacked_tolist_sum,pass,0
microbench_unbacked_tolist_sum,pass,1

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -130,7 +130,7 @@ maml_omniglot,pass,0
microbench_unbacked_tolist_sum,pass,0
microbench_unbacked_tolist_sum,pass,1

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -254,7 +254,7 @@ vgg16,pass,0
vision_maskrcnn,fail_accuracy,30
vision_maskrcnn,fail_accuracy,29
@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -290,7 +290,7 @@ vgg16,pass,0
vision_maskrcnn,pass,21
vision_maskrcnn,pass,20

@ -202,7 +202,7 @@ vgg16,pass,6
vision_maskrcnn,pass,40
vision_maskrcnn,pass,39

@ -114,7 +114,7 @@ maml_omniglot,pass,0
microbench_unbacked_tolist_sum,pass,0
microbench_unbacked_tolist_sum,pass,1

@ -114,7 +114,7 @@ maml_omniglot,pass,0
microbench_unbacked_tolist_sum,pass,0
microbench_unbacked_tolist_sum,pass,1

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -242,7 +242,7 @@ stable_diffusion_unet,pass_due_to_skip,0
torch_multimodal_clip,pass,0
torch_multimodal_clip,pass,3

@ -254,7 +254,7 @@ vgg16,pass,0
vision_maskrcnn,fail_accuracy,30
vision_maskrcnn,fail_accuracy,29
@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -290,7 +290,7 @@ vgg16,pass,0
vision_maskrcnn,pass,18
vision_maskrcnn,pass,20

@ -202,7 +202,7 @@ vgg16,pass,6
vision_maskrcnn,pass,37
vision_maskrcnn,pass,39

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -290,7 +290,7 @@ vgg16,pass,0
vision_maskrcnn,pass,21
vision_maskrcnn,pass,20

@ -206,7 +206,7 @@ vgg16,pass,6
vision_maskrcnn,pass,40
vision_maskrcnn,pass,39

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
vit_base_patch16_siglip_256,pass,7

@ -290,7 +290,7 @@ vgg16,pass,0
vision_maskrcnn,pass,18
vision_maskrcnn,pass,20

@ -206,7 +206,7 @@ vgg16,pass,6
vision_maskrcnn,pass,37
vision_maskrcnn,pass,39

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

@ -290,7 +290,7 @@ vgg16,eager_two_runs_differ,0
vision_maskrcnn,pass,21
vision_maskrcnn,pass,20

@ -206,7 +206,7 @@ vgg16,pass,0
vision_maskrcnn,pass,40
vision_maskrcnn,pass,39

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0

@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0
Some files were not shown because too many files have changed in this diff.