Compare commits


1 Commit

Author SHA1 Message Date
fc4e215ee7 tmp 2025-10-16 11:05:14 -07:00
355 changed files with 3204 additions and 7926 deletions

View File

@ -187,22 +187,19 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
export USE_CUFILE=0
else
DEPS_LIST+=(
"/usr/local/cuda/lib64/libnvToolsExt.so.1"
"/usr/local/cuda/lib64/libcublas.so.12"
"/usr/local/cuda/lib64/libcublasLt.so.12"
"/usr/local/cuda/lib64/libcudart.so.12"
"/usr/local/cuda/lib64/libnvrtc.so.12"
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
DEPS_SONAME+=(
"libnvToolsExt.so.1"
"libcublas.so.12"
"libcublasLt.so.12"
"libcudart.so.12"
"libnvrtc.so.12"
"libcupti.so.12")
if [[ $CUDA_VERSION != 12.9* ]]; then
DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
DEPS_SONAME+=("libnvToolsExt.so.1")
fi
fi
else
echo "Using nvidia libs from pypi."

View File

@ -1615,7 +1615,6 @@ test_operator_benchmark() {
TEST_REPORTS_DIR=$(pwd)/test/test-reports
mkdir -p "$TEST_REPORTS_DIR"
TEST_DIR=$(pwd)
ARCH=$(uname -m)
test_inductor_set_cpu_affinity
@ -1630,7 +1629,7 @@ test_operator_benchmark() {
pip_install pandas
python check_perf_csv.py \
--actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \
--expected "${ARCH}_expected_ci_operator_benchmark_eager_float32_cpu.csv"
--expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
}
test_operator_microbenchmark() {

View File

@ -65,7 +65,7 @@ runs:
cd .ci/lumen_cli
python3 -m pip install -e .
)
MAX_JOBS="$(nproc --ignore=10)"
MAX_JOBS="$(nproc --ignore=6)"
export MAX_JOBS
# Split the comma-separated list and build each target

View File

@ -1,63 +0,0 @@
name: Get changed files
description: >-
Returns a space-separated list of changed files for a PR, or '*' when not in a PR.
Mirrors the logic from the original workflow but packaged as a reusable composite action.
inputs:
all_files:
description: 'Whether to return all files instead of just changed files'
required: false
default: 'false'
outputs:
changed-files:
description: "List of changed files (space-separated) or '*' if not in a PR"
value: ${{ steps.get-files.outputs.changed-files }}
runs:
using: composite
steps:
- id: get-files
name: Get changed files (bash)
shell: bash
env:
GH_TOKEN: ${{ github.token }}
run: |
# Call the bundled entrypoint script. Running via bash avoids needing execute bit set.
# Check if we're in a pull request context
if [ "${{ github.event_name }}" = "pull_request" ] || [ "${{ github.event_name }}" = "pull_request_target" ]; then
echo "Running in PR context"
# Get the PR number from the github context
PR_NUMBER="${{ github.event.number }}"
# Check if all_files is requested
if [ "${{ inputs.all_files }}" = "true" ]; then
echo "all_files input is true, returning all files"
echo "changed-files=*" >> "$GITHUB_OUTPUT"
else
# Use gh CLI to get changed files in the PR with explicit repo
CHANGED_FILES=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER/files --paginate --jq '.[] | select(.status != "removed") | .filename' | tr '\n' ' ' | sed 's/ $//')
# See https://github.com/pytorch/pytorch/pull/134215#issuecomment-2332128790
PYI_FILES_TO_ADD=""
for file in ${CHANGED_FILES}; do
if [[ "${file}" == *".pyi.in" ]]; then
PYI_FILES_TO_ADD="${PYI_FILES_TO_ADD} ${file//.in/}"
fi
done
CHANGED_FILES="${CHANGED_FILES}${PYI_FILES_TO_ADD}"
if [ -z "$CHANGED_FILES" ]; then
echo "No changed files found, setting to '*'"
CHANGED_FILES="*"
fi
echo "Changed files: $CHANGED_FILES"
echo "changed-files=$CHANGED_FILES" >> "$GITHUB_OUTPUT"
fi
else
echo "Not in PR context, setting changed files to '*'"
echo "changed-files=*" >> "$GITHUB_OUTPUT"
fi

View File

@ -1 +1 @@
1b013f5b5a87a1882eb143c26d79d091150d6a37
8ad2aa5d354d1bf432339113860185d5a5d1abbd

View File

@ -1 +1 @@
faffd5cf673615583da6517275e361cb3dbc77e6
f5c6c2ec6490455e86f67b2a25c10390d60a27f7

View File

@ -0,0 +1,64 @@
name: Get Changed Files
on:
workflow_call:
inputs:
all_files:
description: "Whether to return all files instead of just changed files"
required: false
type: boolean
default: false
outputs:
changed-files:
description: "List of changed files (space-separated) or '*' if not in a PR"
value: ${{ jobs.get-changed-files.outputs.changed-files }}
jobs:
get-changed-files:
runs-on: ubuntu-latest
outputs:
changed-files: ${{ steps.get-files.outputs.changed-files }}
steps:
- name: Get changed files
id: get-files
env:
GH_TOKEN: ${{ github.token }}
run: |
# Check if we're in a pull request context
if [ "${{ github.event_name }}" = "pull_request" ] || [ "${{ github.event_name }}" = "pull_request_target" ]; then
echo "Running in PR context"
# Get the PR number from the github context
PR_NUMBER="${{ github.event.number }}"
# Check if all_files is requested
if [ "${{ inputs.all_files }}" = "true" ]; then
echo "all_files input is true, returning all files"
echo "changed-files=*" >> "$GITHUB_OUTPUT"
else
# Use gh CLI to get changed files in the PR with explicit repo
CHANGED_FILES=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER/files --paginate --jq '.[] | select(.status != "removed") | .filename' | tr '\n' ' ' | sed 's/ $//')
# See https://github.com/pytorch/pytorch/pull/134215#issuecomment-2332128790
PYI_FILES_TO_ADD=""
for file in ${CHANGED_FILES}; do
if [[ "${file}" == *".pyi.in" ]]; then
PYI_FILES_TO_ADD="${PYI_FILES_TO_ADD} ${file//.in/}"
fi
done
CHANGED_FILES="${CHANGED_FILES}${PYI_FILES_TO_ADD}"
if [ -z "$CHANGED_FILES" ]; then
echo "No changed files found, setting to '*'"
CHANGED_FILES="*"
fi
echo "Changed files: $CHANGED_FILES"
echo "changed-files=$CHANGED_FILES" >> "$GITHUB_OUTPUT"
fi
else
echo "Not in PR context, setting changed files to '*'"
echo "changed-files=*" >> "$GITHUB_OUTPUT"
fi
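
The workflow above keeps lint coverage for generated stubs by mapping any changed `.pyi.in` template onto its `.pyi` output before emitting the list; note that bash's `${file//.in/}` strips every `.in` occurrence, not just the suffix. A rough Python equivalent of that mapping, offered only as a sketch (the helper name is hypothetical and not part of the workflow):

def expand_pyi_templates(changed_files: str) -> str:
    # Mirror of the .pyi.in handling above: for every changed .pyi.in template,
    # also report the generated .pyi stub; fall back to '*' for an empty list.
    files = changed_files.split()
    extra = [f.replace(".in", "") for f in files if f.endswith(".pyi.in")]
    return " ".join(files + extra) if files else "*"

# e.g. "torch/_C/_nn.pyi.in setup.py" -> "torch/_C/_nn.pyi.in setup.py torch/_C/_nn.pyi"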

View File

@ -27,8 +27,9 @@ jobs:
fail-fast: false
matrix:
python-version: [ '3.12' ]
# TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
device: [ 'cu128', 'cu129', 'cu130' ]
device: [ 'cu128', 'cu129' ]
include:
- platform: manylinux_2_28_x86_64
device: cu128
@ -38,10 +39,6 @@ jobs:
device: cu129
manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
runner: linux.12xlarge.memory
- platform: manylinux_2_28_x86_64
device: cu130
manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
runner: linux.12xlarge.memory
- platform: manylinux_2_28_aarch64
device: cu128
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
@ -50,11 +47,6 @@ jobs:
device: cu129
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
runner: linux.arm64.r7g.12xlarge.memory
exclude:
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
# xformers is updated to support 13.0
- platform: manylinux_2_28_aarch64
device: cu130
name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
runs-on: ${{ matrix.runner }}
timeout-minutes: 480
@ -177,12 +169,7 @@ jobs:
fail-fast: false
matrix:
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
device: [ 'cu128', 'cu129', 'cu130' ]
exclude:
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
# xformers is updated to support 13.0
- platform: manylinux_2_28_aarch64
device: cu130
device: [ 'cu128', 'cu129' ]
env:
PLATFORM: ${{ matrix.platform }}
BUILD_DEVICE: ${{ matrix.device }}

View File

@ -12,6 +12,7 @@ on:
- landchecks/*
tags:
- ciflow/pull/*
- ciflow/trunk/*
workflow_dispatch:
permissions: read-all
@ -30,36 +31,30 @@ jobs:
get-changed-files:
if: github.repository_owner == 'pytorch'
name: Get changed files
runs-on: ubuntu-latest
outputs:
changed-files-matrix: ${{ steps.output-test-matrix.outputs.changed-files-matrix }}
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 1
submodules: false
- name: Get changed files
id: get-changed-files
uses: ./.github/actions/get-changed-files
with:
all_files: ${{ contains(github.event.pull_request.labels.*.name, 'lint-all-files') || contains(github.event.pull_request.labels.*.name, 'Reverted') || github.event_name == 'push' }}
- name: Output test-matrix
id: output-test-matrix
run: |
set -euox pipefail
MATRIX=$(jq -n '[{"changed-files": "${{ steps.get-changed-files.outputs.changed-files }}"}, {"changed-files": "*"}]' | jq -c 'unique')
echo "changed-files-matrix={\"include\":$MATRIX}" >> "$GITHUB_OUTPUT"
uses: ./.github/workflows/_get-changed-files.yml
with:
all_files: ${{ contains(github.event.pull_request.labels.*.name, 'lint-all-files') || contains(github.event.pull_request.labels.*.name, 'Reverted') || github.event_name == 'push' }}
lintrunner-clang:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
# Needed to prevent deduping on HUD
name: lintrunner-clang-${{ matrix.changed-files == '*' && 'all' || 'partial' }}
name: lintrunner-clang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
needs: [get-label-type, get-changed-files]
strategy:
matrix: ${{ fromJson(needs.get-changed-files.outputs.changed-files-matrix) }}
# Only run if there are changed files relevant to clangtidy / clangformat
if: github.repository_owner == 'pytorch'
if: |
github.repository_owner == 'pytorch' && (
needs.get-changed-files.outputs.changed-files == '*' ||
contains(needs.get-changed-files.outputs.changed-files, '.h') ||
contains(needs.get-changed-files.outputs.changed-files, '.cpp') ||
contains(needs.get-changed-files.outputs.changed-files, '.cc') ||
contains(needs.get-changed-files.outputs.changed-files, '.cxx') ||
contains(needs.get-changed-files.outputs.changed-files, '.hpp') ||
contains(needs.get-changed-files.outputs.changed-files, '.hxx') ||
contains(needs.get-changed-files.outputs.changed-files, '.cu') ||
contains(needs.get-changed-files.outputs.changed-files, '.cuh') ||
contains(needs.get-changed-files.outputs.changed-files, '.mm') ||
contains(needs.get-changed-files.outputs.changed-files, '.metal')
)
with:
timeout: 120
runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
@ -70,7 +65,7 @@ jobs:
submodules: true
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
CHANGED_FILES="${{ matrix.changed-files }}"
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
if [ "$CHANGED_FILES" = "*" ]; then
export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT --all-files"
else
@ -83,12 +78,15 @@ jobs:
# fails to find types when it should
lintrunner-mypy:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
name: lintrunner-mypy-${{ matrix.changed-files == '*' && 'all' || 'partial' }}
name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
needs: [get-label-type, get-changed-files]
# Only run if there are changed files relevant to mypy
if: github.repository_owner == 'pytorch'
strategy:
matrix: ${{ fromJson(needs.get-changed-files.outputs.changed-files-matrix) }}
if: |
github.repository_owner == 'pytorch' && (
needs.get-changed-files.outputs.changed-files == '*' ||
contains(needs.get-changed-files.outputs.changed-files, '.py') ||
contains(needs.get-changed-files.outputs.changed-files, '.pyi')
)
with:
timeout: 120
runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
@ -99,20 +97,13 @@ jobs:
submodules: true
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
echo "Running mypy"
CHANGED_FILES="${{ matrix.changed-files }}"
if [[ "$CHANGED_FILES" == *".py"* || "$CHANGED_FILES" == "*" ]]; then
ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
else
echo "No .py files changed, skipping mypy"
fi
ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
lintrunner-noclang:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
strategy:
matrix: ${{ fromJson(needs.get-changed-files.outputs.changed-files-matrix) }}
name: lintrunner-noclang-${{ matrix.changed-files == '*' && 'all' || 'partial' }}
name: lintrunner-noclang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
needs: [get-label-type, get-changed-files]
with:
timeout: 120
@ -124,7 +115,7 @@ jobs:
submodules: true
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
CHANGED_FILES="${{ matrix.changed-files }}"
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
echo "Running all other linters"
if [ "$CHANGED_FILES" = '*' ]; then
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
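
The clang and mypy jobs above are now gated with GitHub's `contains()` over the whole space-separated changed-files string, which is a substring test rather than a per-file suffix match (so `'.py'` also matches `.pyi` paths and `'.h'` matches `.hpp`). A rough local mirror of that decision, written as a hedged Python sketch with hypothetical helper names:

CLANG_TOKENS = (".h", ".cpp", ".cc", ".cxx", ".hpp", ".hxx", ".cu", ".cuh", ".mm", ".metal")

def should_run_clang_lint(changed_files: str) -> bool:
    # contains() in the workflow is a plain substring check over the whole list.
    return changed_files == "*" or any(tok in changed_files for tok in CLANG_TOKENS)

def should_run_mypy(changed_files: str) -> bool:
    return changed_files == "*" or ".py" in changed_files or ".pyi" in changed_files

# '.py' matching a '.pyi' path mirrors contains(changed-files, '.py') in the workflow.
assert should_run_mypy("torch/_C/_nn.pyi tools/setup_helpers/env.py")
assert not should_run_clang_lint("README.md docs/source/conf.py")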

View File

@ -7,11 +7,9 @@ on:
workflow_dispatch:
inputs:
test_mode:
type: choice
options:
- 'short'
- 'long'
- 'all'
required: false
type: string
default: 'short'
description: tag filter for operator benchmarks; options are short, long, or all
schedule:
# Run at 07:00 UTC every Sunday
@ -30,25 +28,38 @@ permissions:
contents: read
jobs:
x86-opbenchmark-build:
opbenchmark-build:
if: github.repository_owner == 'pytorch'
name: x86-opbenchmark-build
name: opbenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-py3.10-gcc11-build
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
]}
secrets: inherit
x86-opbenchmark-test:
name: x86-opbenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: x86-opbenchmark-build
opbenchmark-on-demand-build:
if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
name: opbenchmark-on-demand-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-py3.10-gcc11-build
docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
]}
secrets: inherit
opbenchmark-test:
name: opbenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: opbenchmark-build
with:
build-environment: linux-jammy-py3.10-gcc11-build
docker-image: ${{ needs.opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.opbenchmark-build.outputs.test-matrix }}
secrets: inherit

View File

@ -256,7 +256,6 @@ endif()
IF(USE_FBGEMM_GENAI)
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@ -293,64 +292,58 @@ IF(USE_FBGEMM_GENAI)
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
)
target_include_directories(fbgemm_genai PRIVATE
target_include_directories(fbgemm_genai PUBLIC
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${fbgemm_genai_mx8mx8bf16_grouped}
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
else()
if(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Add FBGEMM_GENAI include directories for torch_ops.h
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
elseif(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
endif()
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PUBLIC
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
endif()
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PRIVATE
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
# Add FBGEMM_GENAI include directories for torch_ops.h
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
endif()
endif()
@ -699,6 +692,12 @@ if(USE_CUDA AND NOT USE_ROCM)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
# Add FBGEMM_GENAI include directories for torch_ops.h
if(USE_FBGEMM_GENAI)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
endif()
if($ENV{ATEN_STATIC_CUDA})
if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS

View File

@ -13,7 +13,6 @@
#include <c10/core/ScalarType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@ -151,7 +150,6 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
BLASType = "unknown";
}
return BLASType;
}
// Similar to Compute Type in GemmRocblas.h
@ -246,25 +244,33 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio
namespace detail {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
if (!config.enabled) {
return true; // skip when disabled
}
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {size}, options);
at::Tensor oth = at::from_blob(other_c, {size}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);
const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
if (ok) {
TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
} else {
TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
return ok;
if (last_succeed_atol == 1) {
return false;
}
else {
TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}
return true;
}
}
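
The restored `NumericalCheck` above flattens both result buffers to float tensors and sweeps a fixed grid of tolerances from loose to tight, logging the last (tightest) pair that still passes and failing only when even the loosest pair does not. A rough Python sketch of that sweep, assuming ordinary tensors in place of the raw device buffers the C++ code wraps:

import torch

def numerical_check_sweep(ref: torch.Tensor, oth: torch.Tensor) -> bool:
    # Compare as flattened float tensors, as in detail::NumericalCheck above.
    ref_f, oth_f = ref.float().flatten(), oth.float().flatten()
    tols = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    last_atol = last_rtol = None
    for atol in tols:
        for rtol in tols:
            if torch.allclose(ref_f, oth_f, rtol=rtol, atol=atol):
                last_atol, last_rtol = atol, rtol  # keep the tightest passing pair
    if last_atol is None:
        return False  # nothing passed; mirrors the last_succeed_atol == 1 case
    print(f"verify numerics: atol={last_atol}, rtol={last_rtol}")
    return True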
@ -349,10 +355,8 @@ struct GemmParams : OpParams {
}
TuningStatus NumericalCheck(GemmParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
}
char transa{};
@ -445,10 +449,8 @@ struct GemmAndBiasParams : OpParams {
}
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
}
char transa{};
@ -544,10 +546,8 @@ struct GemmStridedBatchedParams : OpParams {
}
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
}
char transa{};
@ -663,9 +663,7 @@ struct ScaledGemmParams : OpParams {
}
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
}
char transa{};

View File

@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@ -173,9 +173,10 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| write_file_on_exit(val: bool) -> None | Default is True. |
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
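
With this change the numerical check goes back to a plain on/off switch driven by `PYTORCH_TUNABLEOP_NUMERICAL_CHECK` (the `atol_rtol` format and `set_numerical_check_tolerances` are removed). A minimal usage sketch, assuming a CUDA or ROCm build and the `torch.cuda.tunable` APIs listed above; `PYTORCH_TUNABLEOP_ENABLED` is the usual switch for TunableOp itself and is set here for completeness:

import os

# Configure before any CUDA work so the settings are picked up when tuning starts.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"  # 0 = off, 1 = on

import torch
import torch.cuda.tunable as tunable

tunable.set_filename("my_tunableop_results.csv")  # results are written here
a = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
c = a @ b              # GEMMs go through the tunable path with numerics checking
tunable.write_file()   # no argument: falls back to get_filename()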

View File

@ -107,30 +107,14 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
}
void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
bool is_new = false;
ResultEntry inserted = ResultEntry::Null();
std::scoped_lock l{lock_};
// ---- mutate maps under results lock ----
{
std::scoped_lock l{lock_};
auto& km = results_[op_signature]; // creates if missing
is_new = (km.find(params_signature) == km.end());
AddImpl(op_signature, params_signature, std::move(best), km);
if (is_new) {
inserted = km.at(params_signature); // snapshot for I/O after unlocking
}
}
if (!is_new) return; // only write once per unique (op, params)
TuningContext* ctx = getTuningContext();
if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());
if (is_new && realtime_out_ && realtime_out_->good()) {
AppendResultLine(op_signature, params_signature, inserted);
}
auto it = results_.find(op_signature);
if (it == results_.end()) {
it = results_.insert({op_signature, {}}).first;
}
AddImpl(op_signature, params_signature, std::move(best), it->second);
}
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
@ -166,77 +150,6 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
}
}
void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
std::scoped_lock fl{realtime_file_mutex_};
if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
return;
}
if (realtime_out_ && realtime_filename_ != filename) {
realtime_out_->flush();
realtime_out_->close();
realtime_out_.reset();
validators_written_ = false;
}
bool file_exists = false;
bool file_empty = true;
{
std::ifstream check_file(filename);
if (check_file.good()) {
file_exists = true;
file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
}
}
realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
if (!realtime_out_->good()) {
TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
realtime_out_.reset();
return;
}
if(!file_exists || file_empty) {
for(const auto& [key, val] : validators) {
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
realtime_out_->flush();
}
validators_written_ = true;
TUNABLE_LOG2("Wrote validators to realtime output file");
}
realtime_filename_ = filename;
}
void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
std::scoped_lock fl{realtime_file_mutex_};
if(!realtime_out_ || !realtime_out_->good()) {
return;
}
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
realtime_out_->flush(); //ensure immediate write to disk
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
}
void TuningResultsManager::CloseRealtimeAppend() {
std::scoped_lock fl{realtime_file_mutex_};
if(realtime_out_) {
realtime_out_->flush();
realtime_out_->close();
realtime_out_.reset();
TUNABLE_LOG2("Closed realtime output file");
}
}
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
std::scoped_lock l{lock_};
@ -483,6 +396,7 @@ TuningContext::TuningContext() :
tuning_enable_{true},
record_untuned_enable_{false},
manager_initialized_{false},
write_file_on_exit_{true},
numerics_check_enable_{false},
max_tuning_duration_ms_{30},
max_tuning_iterations_{100},
@ -503,8 +417,20 @@ TuningContext::~TuningContext() {
// but doesn't do any computation itself.
return;
}
TUNABLE_LOG1("Closing File");
GetTuningResultsManager().CloseRealtimeAppend(); // Since, we do instant logging by default now.
auto filename = GetFilename();
if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
if (results_count_from_input_file_ > 0) {
TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
}
else {
TUNABLE_LOG1("writing file ", filename);
}
if (!WriteFile(filename)) {
TUNABLE_LOG1("failed to write file ", filename);
}
}
}
if (untuned_file_.good()) {
untuned_file_.close();
@ -585,54 +511,20 @@ std::ofstream& TuningContext::GetUntunedFile(){
return untuned_file_;
}
void TuningContext::WriteFileOnExit(bool value) {
write_file_on_exit_ = value;
}
void TuningContext::EnableNumericsCheck(bool value) {
numerics_check_enable_ = value;
}
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (!env_opt.has_value()) {
return numerics_cfg_;
}
const std::string& env = env_opt.value();
if (env == "0") {
return NumericalCheckConfig(false, 1e-5, 1e-5);
}
const size_t underscore = env.find('_');
TORCH_CHECK(
underscore != std::string::npos,
"Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
"Expected 'atol_rtol', got: ",
env);
double atol = 0.0;
double rtol = 0.0;
try {
atol = std::stod(env.substr(0, underscore));
rtol = std::stod(env.substr(underscore + 1));
} catch (const std::exception& e) {
TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
}
TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
return NumericalCheckConfig(true, atol, rtol);
}
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
numerics_cfg_ = {enabled, atol, rtol};
}
bool TuningContext::IsNumericsCheckEnabled() const {
const auto cfg = GetNumericalCheckConfig();
return cfg.enabled || numerics_check_enable_;
const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (env == "1") {
return true;
}
return numerics_check_enable_;
}
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
@ -742,6 +634,11 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
auto filename = GetFilename();
if (!filename.empty() && !IsRecordUntunedEnabled()) {
ReadFile(filename);
// attempt immediately to open file for writing to catch errors early
std::ofstream file(filename, std::ios::out | std::ios::app);
if (!file.good()) {
TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
}
}
});
return manager_;
@ -847,6 +744,27 @@ bool TuningContext::ReadFile(const std::string& filename_) {
return true;
}
bool TuningContext::WriteFile(const std::string& filename_) {
std::string filename = filename_.empty() ? GetFilename() : filename_;
std::ofstream file(filename, std::ios::out | std::ios::trunc);
if (!file.good()) {
TUNABLE_LOG1("error opening tuning results file for writing ", filename);
return false;
}
auto validators = GetTuningResultsValidator().GetAllValidators();
for (const auto& [key, val] : validators) {
file << "Validator," << key << "," << val << std::endl;
}
auto results = GetTuningResultsManager().Dump();
for (const auto& [op_sig, kernelmap] : results) {
for (const auto& [param_sig, result] : kernelmap) {
file << op_sig << "," << param_sig << "," << result << std::endl;
}
}
file.close();
return true;
}
namespace {
struct MaybeDelete {
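
Both the realtime-append path removed above and the restored `WriteFile` emit the same plain CSV layout: `Validator,<key>,<value>` lines followed by `<op_signature>,<params_signature>,<result>` lines. A small reader for that layout, as a hedged Python sketch (the helper is hypothetical and only assumes the rows written by `WriteFile`):

import csv

def read_tunableop_results(path: str):
    # Parse validators and per-op results out of a tunableop_results.csv-style file.
    validators, results = {}, {}
    with open(path, newline="") as f:
        for row in csv.reader(f):
            if not row:
                continue
            if row[0] == "Validator" and len(row) >= 3:
                validators[row[1]] = row[2]
            elif len(row) >= 3:
                # Everything after the op and params signatures is the result payload.
                results.setdefault(row[0], {})[row[1]] = row[2:]
    return validators, results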

View File

@ -103,24 +103,10 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
const std::string& params_signature, const std::string& blas_signature);
void InitRealtimeAppend(
const std::string& filename,
const std::unordered_map<std::string, std::string>& validators);
void AppendResultLine(const std::string& op_sig,
const std::string& param_sig,
const ResultEntry& result);
void CloseRealtimeAppend(); // For clean shutdown
private:
std::mutex lock_;
std::mutex realtime_file_mutex_;
std::unique_ptr<std::ofstream> realtime_out_;
std::string realtime_filename_;
ResultsMap results_;
UntunedMap untuned_results_;
bool validators_written_ = false;
};
@ -148,16 +134,6 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
GetValidateFuncs validators_;
};
struct NumericalCheckConfig {
bool enabled{false};
double atol{1e-5};
double rtol{1e-5};
NumericalCheckConfig() = default;
NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
};
class TORCH_CUDA_CPP_API TuningContext {
public:
TuningContext();
@ -179,8 +155,6 @@ class TORCH_CUDA_CPP_API TuningContext {
void EnableNumericsCheck(bool value);
bool IsNumericsCheckEnabled() const;
void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
NumericalCheckConfig GetNumericalCheckConfig() const;
void SetMaxTuningDurationMs(int max_duration_ms);
int GetMaxTuningDurationMs() const;
@ -211,7 +185,10 @@ class TORCH_CUDA_CPP_API TuningContext {
void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
std::string GetFilename() const;
void WriteFileOnExit(bool value);
bool ReadFile(const std::string& filename={});
bool WriteFile(const std::string& filename={});
template<class... Types>
void Log(int level, Types... args) {
@ -230,6 +207,7 @@ class TORCH_CUDA_CPP_API TuningContext {
bool tuning_enable_;
bool record_untuned_enable_;
bool manager_initialized_;
bool write_file_on_exit_;
bool numerics_check_enable_;
int max_tuning_duration_ms_;
int max_tuning_iterations_;
@ -244,8 +222,6 @@ class TORCH_CUDA_CPP_API TuningContext {
std::ofstream untuned_file_;
size_t results_count_from_input_file_;
bool is_shutting_down_;
NumericalCheckConfig numerics_cfg_{};
};
TORCH_CUDA_CPP_API TuningContext* getTuningContext();

View File

@ -267,10 +267,27 @@ class TunableOp {
for (size_t i = 0; i < op_names_.size(); i++) {
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
else {
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
// collect a small profile
@ -293,22 +310,6 @@ class TunableOp {
continue;
}
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
// for warmup does user set max duration, max iters, or both?
// warmup is skipped by default, i.e. warmup_iter = 0
// warmup will be set to the non-zero value of max_warmup_duration

View File

@ -213,22 +213,40 @@ static cudnn_grid_sample_backward_batch_rule(
return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
}
// uses functional formulation for one_hot under vmap to be compatible with
// fakeTensor/dynamic shapes and compiled functorch transforms.
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
// but requires explicit positive num_classes under vmap to avoid
// data-dependent output shapes.
// TODO: replace with targetable functionalization
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
auto shape = self.sym_sizes().vec();
// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.sym_numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
}
}
// disallow implicit inference under vmap; this would be data-dependent
// and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
"provide an explicit positive num_classes argument.");
const auto options = self.options();
at::Tensor index = at::arange(num_classes, options);
return at::eq(self.unsqueeze(-1), index).to(at::kLong);
// Disabling all of the following checks. This is OK because scatter has checks too.
// Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
// // non-empty tensor
// if (self.device().type() != at::kCUDA) {
// //for cuda, rely on device assert thrown by scatter
// TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
// }
// if (self.device().type() != at::kCUDA) {
// //rely on device asserts from scatter to avoid sync here
// TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
// }
shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
return ret.scatter(-1, self.unsqueeze(-1), 1);
}
template <typename A, A a, typename C>
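
Both sides of this change keep the same user-facing contract: under `vmap`, `one_hot` needs an explicit positive `num_classes`, since inferring it from the data would make the output shape data-dependent. A short usage sketch of that contract (assuming `torch.func.vmap` on a recent build):

import torch
import torch.nn.functional as F
from torch.func import vmap

idx = torch.tensor([[0, 2, 1], [1, 1, 0]])  # a batch of index vectors

# Explicit num_classes is fine under vmap: the output shape is known up front.
out = vmap(lambda x: F.one_hot(x, num_classes=3))(idx)
assert out.shape == (2, 3, 3)

# Relying on inference (the default num_classes=-1) is rejected under vmap.
try:
    vmap(lambda x: F.one_hot(x))(idx)
except RuntimeError as err:
    print("expected:", err)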

View File

@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
}
}
auto shape = self.sym_sizes().vec();
auto shape = self.sizes().vec();
// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.sym_numel() == 0) {
if (self.numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
shape.push_back(num_classes);
return at::empty(shape, self.options());
}
}
@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
}
}
shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
shape.push_back(num_classes);
Tensor ret = at::zeros(shape, self.options());
ret.scatter_(-1, self.unsqueeze(-1), 1);
return ret;
}

View File

@ -1906,9 +1906,11 @@ Tensor& index_fill_(
"This also applies to advanced indexing e.g. tensor[mask] = scalar");
}
TORCH_CHECK(
self.is_complex() || !source.isComplex(),
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
if (!self.is_complex() && source.isComplex()) {
TORCH_CHECK(
false,
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
}
// Handle the case when `self` is 0-dim
Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,

View File

@ -1230,205 +1230,8 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
);
}
Tensor&
_tunable_scaled_gemm_rocm(
cublasCommonArgs& args,
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
const at::ScalarType out_dtype,
Tensor& out) {
#ifdef USE_ROCM
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype) ? at::ScalarType::Half : out_dtype;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_gemm_rocm only callable on ROCM devices");
#endif
}
Tensor&
_scaled_gemm(
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
// ROCM enables the TunableOp path only
// but can fallback to at::cuda::blas::scaled_gemm
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
bool tunable_op_enabled = tuning_ctx->IsTunableOpEnabled();
#else
bool tunable_op_enabled = false;
#endif
if (tunable_op_enabled) {
// Only available on ROCM
return _tunable_scaled_gemm_rocm(
args,
mat1, mat2,
scale_a, scale_b,
scaling_choice_a, scaling_choice_b,
bias,
use_fast_accum,
out_dtype_,
out);
}
else
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
return out;
}
}
} // namespace
// NOTE(slayton58): This is defined as part of the _v2 code (way) below - declare the signature here
// to help cleanup v1 call structure.
Tensor&
_scaled_rowwise_rowwise(
const Tensor&, const Tensor&,
const Tensor&, const Tensor&,
const std::optional<Tensor>&,
const c10::ScalarType,
bool,
Tensor&);
// Computes matrix multiply + bias while applying scaling to input and output matrices
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
@ -1506,7 +1309,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type());
#ifndef USE_ROCM
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
TORCH_CHECK_VALUE(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
"Multiplication of two Float8_e5m2 matrices is not supported");
#endif
if (use_fast_accum) {
@ -1572,44 +1375,41 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
// and only for compute capability 9.0+. In other cases we use CUTLASS.
// We are doing row-wise scaling
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) {
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
return _scaled_rowwise_rowwise(
mat1,
mat2,
scale_a,
scale_b,
bias,
out.scalar_type(),
use_fast_accum,
out);
}
// We are doing row-wise scaling
auto dprops = at::cuda::getCurrentDeviceProperties();
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
&& ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat1,
mat2,
scale_a,
scale_b,
bias,
use_fast_accum,
out);
return out;
}
#else
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
Tensor b = mat2;
if (_scaled_mm_is_fnuz()) {
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fnuz,
"Expected b.dtype() == at::kFloat8_e4m3fnuz, got: ", b.dtype());
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
}
else {
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fn,
"Expected b.dtype() == at::kFloat8_e4m3fn, got: ", b.dtype());
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
}
// Until more than bf16 is supported.
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16,
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
#endif
}
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
#ifdef USE_ROCM
#if ROCM_VERSION >= 70000
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
int packed_factor = 1;
@ -1618,20 +1418,163 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// effectively packing two elements into one byte.
packed_factor = 2;
}
TORCH_CHECK_VALUE(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
mat2.size(1) % 16 == 0,
"M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling");
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
}
#endif
return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
}
return out;
}
namespace {
@ -1971,6 +1914,159 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
Tensor&
_cutlass_scaled_gemm(
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
}
return out;
}
Tensor&
_scaled_tensorwise_tensorwise(
const Tensor& mat_a, const Tensor& mat_b,
@ -1990,7 +2086,7 @@ _scaled_tensorwise_tensorwise(
auto scaling_choice_a = ScalingType::TensorWise;
auto scaling_choice_b = ScalingType::TensorWise;
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@ -2026,7 +2122,7 @@ _scaled_rowwise_rowwise(
if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat_a,
mat_b,
@ -2052,38 +2148,11 @@ _scaled_rowwise_rowwise(
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
#endif
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
if (scale_type == ScalingType::BlockWise1x128) {
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
} else if (scale_type == ScalingType::BlockWise128x128) {
TORCH_CHECK_VALUE(check_size_stride(
scale,
0,
ceil_div<int64_t>(t.size(0), 128),
ceil_div<int64_t>(t.size(1), 128)),
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
TORCH_CHECK(check_size_stride(
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
}
}
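As a rough illustration of what the checks above demand, for a hypothetical (M, K) = (256, 512) operand the expected scale shapes work out as below; ceil_div_i64 is a local stand-in for the ceil_div helper used in the surrounding code.

// BlockWise1x128  : one scale per row per 128-wide column block -> (M, ceil(K/128)) = (256, 4)
// BlockWise128x128: one scale per 128x128 tile                  -> (ceil(M/128), ceil(K/128)) = (2, 4)
#include <cstdint>
#include <cstdio>

static int64_t ceil_div_i64(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t M = 256, K = 512;  // made-up operand shape
  std::printf("BlockWise1x128   scale: %lld x %lld\n",
              static_cast<long long>(M),
              static_cast<long long>(ceil_div_i64(K, 128)));
  std::printf("BlockWise128x128 scale: %lld x %lld\n",
              static_cast<long long>(ceil_div_i64(M, 128)),
              static_cast<long long>(ceil_div_i64(K, 128)));
  return 0;
}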
Tensor&
_scaled_block1x128_block1x128(
const Tensor& mat_a, const Tensor& mat_b,
@ -2101,14 +2170,15 @@ _scaled_block1x128_block1x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.size(1));
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@ -2123,8 +2193,6 @@ _scaled_block128x128_block1x128(
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
@ -2132,14 +2200,15 @@ _scaled_block128x128_block1x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));
auto scaling_choice_a = ScalingType::BlockWise128x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@ -2161,14 +2230,15 @@ _scaled_block1x128_block128x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0));
TORCH_CHECK_VALUE(scale_b.stride(1) == scale_b.size(0),
"expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1));
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise128x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@ -2222,7 +2292,7 @@ _scaled_mxfp8_mxfp8(
#endif
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
Tensor&
@ -2259,7 +2329,7 @@ _scaled_nvfp4_nvfp4(
auto scaling_choice_a = ScalingType::BlockWise1x16;
auto scaling_choice_b = ScalingType::BlockWise1x16;
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
@ -2508,9 +2578,7 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType& swizzle_a,
const Tensor& scale_b,
const SwizzleType& swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;
@ -2521,16 +2589,6 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
#ifdef USE_ROCM
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
#else
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
#endif
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
fbgemm_gpu::mx8mx8bf16_grouped_mm(
@ -2615,9 +2673,6 @@ _f8_f8_bf16_rowwise_grouped_mm(
const std::optional<Tensor>& bias,
bool use_fast_accum,
Tensor& out) {
// FP8 per-tensor and per-row scaling expect fp32 scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"For grouped FP8 rowwise, both scales must be float32 tensors");
#ifndef USE_ROCM
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
mat_a,
@ -2717,15 +2772,11 @@ _scaled_grouped_mm_cuda(
#endif
if (is_mx8mx8bf16) {
// Note: Passing implied SwizzleType here, correctness of scale previously checked
// in `check_scale` call
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a,
SwizzleType::SWIZZLE_32_4_4,
scale_b,
SwizzleType::SWIZZLE_32_4_4,
offs.value(),
out);
}
@ -2742,140 +2793,6 @@ _scaled_grouped_mm_cuda(
out);
}
namespace {
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
} // anonymous namespace
Tensor
_scaled_grouped_mm_cuda_v2(
const Tensor& mat_a, const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
if (contraction_dim.size() > 0) {
const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
mat_b.size(dim_b));
// Note: only (-1, -2) is currently supported
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
} else {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
}
TORCH_CHECK_VALUE(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
// routines
if (offs.has_value()) {
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
}
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
// Conversion of implicitly-defined enums to explicit
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
// at this point we can start working out what we want to be doing
// Try to do as few steps as possible.
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
bool ok = accept_fn(mat_a.scalar_type(),
scale_recipe_a_enum,
scale_a,
mat_b.scalar_type(),
scale_recipe_b_enum,
scale_b);
if (ok) {
gemm_impl = scaled_gemm_impl;
break;
}
}
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
"No gemm implementation was found");
switch (gemm_impl) {
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
return _f8_f8_bf16_rowwise_grouped_mm(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
offs,
bias,
use_fast_accum,
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0],
swizzle_a_enum[0],
scale_b[0],
swizzle_b_enum[0],
offs.value(),
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
}
}
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,

View File

@ -856,13 +856,9 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -870,9 +866,12 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
cret_t,
carg0_t,
carg1_t>(
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
numel,
f,
data,
@ -880,7 +879,6 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};
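The launcher above compares the runtime ScalarTypes against a constexpr specialization table and, on a match, instantiates the templated kernel with the mapped C++ types. A self-contained sketch of that runtime-to-compile-time dispatch pattern is below; ScalarType, TypeMap, launch_for and try_launch are hypothetical stand-ins for c10::ScalarType, c10::impl::ScalarTypeToCPPType and launch_vectorized_templated_kernel.

#include <array>
#include <cstdio>

enum class ScalarType { Float, Half, BFloat16 };

template <ScalarType S> struct TypeMap;
template <> struct TypeMap<ScalarType::Float>    { using type = float; };
template <> struct TypeMap<ScalarType::Half>     { using type = unsigned short; };  // placeholder for a half type
template <> struct TypeMap<ScalarType::BFloat16> { using type = unsigned short; };  // placeholder for a bfloat16 type

// Compile-time table of (ret, arg0, arg1) specializations.
constexpr std::array<std::array<ScalarType, 3>, 2> specializations{{
    {ScalarType::Float, ScalarType::Half, ScalarType::Half},
    {ScalarType::Float, ScalarType::BFloat16, ScalarType::BFloat16},
}};

template <typename RetT, typename Arg0T, typename Arg1T>
void launch_for() {
  // The real code launches launch_vectorized_templated_kernel<...> here.
  std::printf("launching kernel specialized for sizes (%zu, %zu, %zu)\n",
              sizeof(RetT), sizeof(Arg0T), sizeof(Arg1T));
}

template <int Index>
void try_launch(ScalarType ret, ScalarType a0, ScalarType a1) {
  constexpr ScalarType sret = specializations[Index][0];
  constexpr ScalarType sa0  = specializations[Index][1];
  constexpr ScalarType sa1  = specializations[Index][2];
  if (ret == sret && a0 == sa0 && a1 == sa1) {
    launch_for<typename TypeMap<sret>::type,
               typename TypeMap<sa0>::type,
               typename TypeMap<sa1>::type>();
  }
}

int main() {
  // Only the table entry matching the runtime dtypes instantiates and launches.
  try_launch<0>(ScalarType::Float, ScalarType::Half, ScalarType::Half);  // match
  try_launch<1>(ScalarType::Float, ScalarType::Half, ScalarType::Half);  // no match, no launch
  return 0;
}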

View File

@ -38,41 +38,12 @@ __device__ inline int min(int a, int b) {
#define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched
#endif
template <typename index_t>
static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) {
const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);
return (size + pad < kernel_extent) ? index_t(0) : (size + pad - kernel_extent) / stride + 1;
static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
}
template <typename index_t>
static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) {
return std::min((size + pad) / stride + 1, pooled_size);
}
static inline bool can_use_int32_nhwc(
int64_t nbatch, int64_t channels,
int64_t height, int64_t width,
int64_t pooled_height, int64_t pooled_width,
int64_t in_stride_n, int64_t in_stride_c,
int64_t in_stride_h, int64_t in_stride_w)
{
constexpr int64_t int_max = std::numeric_limits<int>::max();
int64_t max_intra_batch =
(height ? (height - 1) * in_stride_h : 0) +
(width ? (width - 1) * in_stride_w : 0) +
(channels? (channels - 1) * in_stride_c : 0);
int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch;
if (max_input_offset > int_max) return false;
int64_t out_batch_stride = pooled_height * pooled_width * channels;
if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false;
if (height * width > int_max) return false;
return true;
static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) {
return min((size + pad) / stride + 1, pooled_size);
}
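A quick host-side check of the p_start/p_end arithmetic above, using made-up pooling parameters (kernel=3, dilation=1, stride=2, pad=1): output window p covers inputs [p*stride - pad, p*stride - pad + (kernel-1)*dilation], so for input index 5 only windows 2 ([3,5]) and 3 ([5,7]) apply, i.e. the half-open range [2, 4).

#include <algorithm>
#include <cassert>

static int p_start_ref(int h, int pad, int kernel, int dilation, int stride) {
  const int kernel_extent = (kernel - 1) * dilation + 1;
  return (h + pad < kernel_extent) ? 0 : (h + pad - kernel_extent) / stride + 1;
}

static int p_end_ref(int h, int pad, int pooled_size, int stride) {
  return std::min((h + pad) / stride + 1, pooled_size);  // exclusive end
}

int main() {
  assert(p_start_ref(5, /*pad=*/1, /*kernel=*/3, /*dilation=*/1, /*stride=*/2) == 2);
  assert(p_end_ref(5, /*pad=*/1, /*pooled_size=*/4, /*stride=*/2) == 4);
  return 0;
}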
// kernels borrowed from Caffe
@ -114,25 +85,21 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom
}
}
template <typename scalar_t, typename index_t>
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_forward_nhwc(
const scalar_t* bottom_data,
const int nbatch,
const index_t channels, const index_t height, const index_t width,
const index_t pooled_height, const index_t pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const index_t in_stride_n, const index_t in_stride_c,
const index_t in_stride_h, const index_t in_stride_w,
const int kernel_stride_C, const int kernel_size_C,
scalar_t* top_data, int64_t* top_mask) {
extern __shared__ unsigned char smem_raw[];
index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw);
scalar_t *out_cached = reinterpret_cast<scalar_t*>(
out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z);
__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch,
const int64_t channels, const int64_t height,
const int64_t width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const int in_stride_n, const int in_stride_c,
const int in_stride_h, const int in_stride_w,
const int kernel_stride_C, const int kernel_size_C,
scalar_t* top_data, int64_t* top_mask) {
extern __shared__ int smem[];
int *out_mask_cached = smem;
scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]);
// flattening cta for pre-computation & smem initialization;
int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
@ -151,26 +118,26 @@ __global__ void max_pool_forward_nhwc(
int channel_id = blockIdx.x / nbatch;
int channel_offset = threadIdx.x + channel_id * blockDim.x;
top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n;
top_data = top_data + batch_id * pooled_height * pooled_width * channels;
top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
bottom_data = bottom_data + batch_id * in_stride_n;
out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
out_mask_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z;
int oW = (static_cast<int>(pooled_width) + gridDim.y - 1) / gridDim.y;
int oH = (pooled_height + gridDim.z-1) / gridDim.z;
int oW = (pooled_width + gridDim.y-1) / gridDim.y;
int ostartH = threadIdx.z + blockIdx.z*oH;
int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height));
int oendH = ::min(ostartH+oH, pooled_height);
int ostartW = threadIdx.y + blockIdx.y*oW;
int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width));
int oendW = ::min(ostartW+oW, pooled_width);
for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h;
index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height);
int hstart = oh * stride_h - pad_h;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w;
index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width);
int wstart = ow * stride_w - pad_w;
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
@ -218,12 +185,12 @@ __global__ void max_pool_forward_nhwc(
// Else do it Non-Prefetch...
else
#endif
for (index_t ih = hstart; ih < hend; ih += dilation_h) {
for (index_t iw = wstart; iw < wend; iw += dilation_w) {
for (int ih = hstart; ih < hend; ih += dilation_h) {
for (int iw = wstart; iw < wend; iw += dilation_w) {
int cached_index = threadIdx.x;
const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
scalar_t val = ptr_input[c * in_stride_c];
for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
scalar_t val = ptr_input[c*in_stride_c];
if ((val > out_cached[cached_index]) || at::_isnan(val)) {
out_cached[cached_index] = val;
out_mask_cached[cached_index] = ih * width + iw;
@ -233,15 +200,15 @@ __global__ void max_pool_forward_nhwc(
}
}
scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels;
int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels;
int cached_index = threadIdx.x;
for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
ptr_output_data[c] = out_cached[cached_index];
ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]);
ptr_output_mask[c] = out_mask_cached[cached_index];
out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
out_mask_cached[cached_index] = index_t(0);
out_mask_cached[cached_index] = 0;
cached_index += blockDim.x;
}
}
@ -495,11 +462,6 @@ const Tensor& indices) {
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
const dim3 block(block_x, block_y, block_z);
bool use_int32 = can_use_int32_nhwc(
nbatch, nInputPlane, inputHeight, inputWidth,
outputHeight, outputWidth,
in_stride_n, in_stride_c, in_stride_h, in_stride_w);
int kernel_stride_C = ceil_div(
safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
int kernel_size_C = ceil_div(
@ -514,41 +476,18 @@ const Tensor& indices) {
ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD));
const dim3 grid(grid_x, grid_y, grid_z);
size_t shmem_size;
size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z;
size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
if (use_int32) {
shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t));
TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
"shared memory too small");
max_pool_forward_nhwc<scalar_t, int32_t>
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
input_data, static_cast<int>(nbatch),
static_cast<int32_t>(nInputPlane),
static_cast<int32_t>(inputHeight),
static_cast<int32_t>(inputWidth),
static_cast<int32_t>(outputHeight),
static_cast<int32_t>(outputWidth),
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
static_cast<int32_t>(in_stride_n),
static_cast<int32_t>(in_stride_c),
static_cast<int32_t>(in_stride_h),
static_cast<int32_t>(in_stride_w),
kernel_stride_C, kernel_size_C,
output_data, indices_data);
} else {
shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t));
TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
"shared memory too small");
max_pool_forward_nhwc<scalar_t, int64_t>
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
input_data, static_cast<int>(nbatch),
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
in_stride_n, in_stride_c, in_stride_h, in_stride_w,
kernel_stride_C, kernel_size_C,
output_data, indices_data);
}
max_pool_forward_nhwc<scalar_t>
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
input_data, nbatch,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
in_stride_n, in_stride_c,
in_stride_h, in_stride_w,
kernel_stride_C, kernel_size_C,
output_data, indices_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}

View File

@ -655,14 +655,8 @@ struct ReduceOp {
}
__syncthreads();
// Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
// matching Triton, etc.
// todo for AMD
#ifdef USE_ROCM
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = ops.warp_shfl_down(value[i], offset);

View File

@ -77,8 +77,8 @@ struct nansum_functor_complex {
#if AT_USE_JITERATOR()
void operator()(TensorIterator& iter) {
std::string func = jiterator_stringify(
arg_t combine(arg_t a, arg_t b) {
return a + (std::isnan(b) ? arg_t{0.} : b);
arg_t combine(arg_t a, scalar_t b) {
return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
}
);
jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(

View File

@ -464,7 +464,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
}
#endif
int32_t trailingSize;
int nDimsLocal = nDims;
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
if (isInOutAligned) {
// in this case we can and should flatten the tensors after the cat dim
@ -478,7 +477,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
// and divide all strides except last by elems_per_vec (last stride is 1 always)
// for input, we will fix up the sizes and strides in the kernel directly
kernelOutputParam = outputParam;
nDimsLocal = dimension + 1;
nDims = dimension + 1;
constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
@ -495,7 +494,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
case 0:
break;
case 1:
cat_dim = nDimsLocal - cat_dim;
cat_dim = nDims - cat_dim;
break;
default:
cat_dim--;
@ -526,7 +525,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
}\
C10_CUDA_KERNEL_LAUNCH_CHECK();
switch (nDimsLocal) {
switch (nDims) {
case 1:
HANDLE_CASE(1);
break;

View File

@ -21,15 +21,9 @@ namespace {
struct offset_t {
int stride;
int begin;
__device__ int operator[](int i) const {
__device__ int operator[](int i) {
return stride * (begin + i);
}
#if CCCL_VERSION >= 3001000
__device__ offset_t& operator+=(int i) {
begin += i;
return *this;
}
#endif
};
// Segmented sort by full sort algorithm:.
// Say we are sorting a (2, 3) tensor. We have in flattened form:

View File

@ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
}
}
#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
int input_pos,
accscalar_t scale,
int output_size,
bool align_corners,
int& min_output,
int& max_output) {
accscalar_t lo, hi;
if (align_corners) {
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
} else {
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
}
min_output = max(0, static_cast<int>(ceil(lo)));
max_output = min(output_size - 1, static_cast<int>(floor(hi)));
}
#endif
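compute_output_range inverts the forward source-index mapping src(h2) = scale*(h2 + 0.5) - 0.5 (align_corners == false): input row h1 only receives weight from output rows whose source index lands in (h1 - 1, h1 + 1), which is where the lo/hi bounds in the code come from. A small host-side sketch with made-up values (scale = 0.5, i.e. a 2x upsample; h1 = 5) checks those bounds:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float scale = 0.5f;  // input_size / output_size for a 2x upsample
  const int h1 = 5;          // input (gradient) row of interest
  const int out_size = 16;

  const float lo = (h1 - 0.5f) / scale - 0.5f;                                  // 8.5
  const float hi = (h1 + 1.5f) / scale - 0.5f;                                  // 12.5
  const int h2_min = std::max(0, static_cast<int>(std::ceil(lo)));              // 9
  const int h2_max = std::min(out_size - 1, static_cast<int>(std::floor(hi)));  // 12

  // Every h2 in [h2_min, h2_max] maps to a source index whose floor (or floor+1) is h1.
  for (int h2 = h2_min; h2 <= h2_max; ++h2) {
    const float src = scale * (h2 + 0.5f) - 0.5f;
    std::printf("h2=%d src=%.2f base=%d\n", h2, src, static_cast<int>(src));
  }
  return 0;
}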
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame(
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
index += blockDim.x * gridDim.x) {
// Decode input pixel coordinates
size_t index_temp = index;
const int w1 = index_temp % width1;
index_temp /= width1;
const int h1 = index_temp % height1;
const size_t nc_idx = index_temp / height1;
accscalar_t grad_sum = 0;
// Find range of output pixels that could interpolate from this input pixel
int h2_min, h2_max, w2_min, w2_max;
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
// Iterate over potential output pixels
for (int h2 = h2_min; h2 <= h2_max; h2++) {
for (int w2 = w2_min; w2 <= w2_max; w2++) {
// Compute source coordinates for this output pixel
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
rheight, h2, align_corners, /*cubic=*/false);
const int h1_base = (int)h1r;
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
const accscalar_t h1lambda = h1r - h1_base;
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
rwidth, w2, align_corners, /*cubic=*/false);
const int w1_base = (int)w1r;
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
const accscalar_t w1lambda = w1r - w1_base;
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
// Check if our input pixel participates in this interpolation and accumulate all weights
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
// to the same pixel, so we need to accumulate weights from all matching positions
accscalar_t weight = 0;
// Check all four interpolation positions and accumulate weights
if (h1 == h1_base && w1 == w1_base) {
weight += h0lambda * w0lambda; // top-left
}
if (h1 == h1_base && w1 == w1_base + w1p) {
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base) {
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
}
if (weight > 0) {
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
}
}
}
// Write accumulated gradient (no atomics needed)
idata[index] = static_cast<scalar_t>(grad_sum);
}
#else
const size_t o_numel = nc * width2 * height2;
const size_t i_numel = nc * width1 * height1;
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
index += blockDim.x * gridDim.x) {
size_t index_temp = index;
@ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame(
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
true);
}
#endif
}
template <typename scalar_t, typename accscalar_t>
@ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template(
// threads are not covering the whole input tensor.
grad_input.zero_();
const size_t num_kernels = nbatch * channels * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
return;
}
#ifdef USE_ROCM
constexpr bool use_input = true;
#else
constexpr bool use_input = false;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * output_height * output_width;
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
input_height,
@ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
num_threads,

View File

@ -466,11 +466,7 @@ struct ReduceJitOp {
__syncthreads();
#ifdef USE_ROCM
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = reducer::warp_shfl_down(value[i], offset);

View File

@ -160,12 +160,8 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
}
static bool mkldnn_conv_enabled_fpmath_mode_tf32(){
#if defined(__x86_64__) || defined(_M_X64)
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
cpuinfo_has_x86_amx_fp16();
#else
return false; //TF32 not supported on power system
#endif
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
cpuinfo_has_x86_amx_fp16();
}
static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {

View File

@ -74,12 +74,8 @@ static bool use_mkldnn_bf32_linear() {
}
static bool use_mkldnn_tf32_linear() {
#if defined(__x86_64__) || defined(_M_X64)
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
cpuinfo_has_x86_amx_fp16();
#else
return false; // TF32 not supported on power system
#endif
}
Tensor mkldnn_linear(

View File

@ -114,13 +114,8 @@ static bool use_mkldnn_bf32_matmul() {
return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16;
}
static bool use_mkldnn_tf32_matmul() {
#if defined(__x86_64__) || defined(_M_X64)
return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
#else
return false; // TF32 not supported on power system
#endif
return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
}
// returns an ideep::tensor

View File

@ -712,7 +712,7 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) {
} else if (scalar.isBoolean()) {
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool));
} else if (scalar.isComplex()) {
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexFloat));
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexDouble));
} else {
TORCH_INTERNAL_ASSERT(scalar.isIntegral(false));
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong));

View File

@ -54,10 +54,6 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
using namespace mps;
using CachedGraph = MPSBinaryCachedGraph;
if (self.numel() == 0 & other.numel() == 0) {
return zeros({}, self.options());
}
dot_check(self, other);
auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);

View File

@ -907,8 +907,6 @@ Tensor& index_fill_mps_(Tensor& self, int64_t dim, const Tensor& index, const Te
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
"index_fill_(): Expected dtype int32 or int64 for index");
TORCH_CHECK(dim == 0 || dim < self.dim(), "index_fill_(): Indexing dim ", dim, " is out of bounds of tensor");
TORCH_CHECK(self.is_complex() || !source.is_complex(),
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
// MPS.scatter crashes if used with complex dtypes
TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "index_fill_(): Complex types are yet not supported");

View File

@ -7183,12 +7183,6 @@
CUDA: _scaled_grouped_mm_cuda
tags: needs_exact_strides
- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_grouped_mm_cuda_v2
tags: needs_exact_strides
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
variants: function
dispatch:
@ -7384,7 +7378,7 @@
- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_mask
SparseCPU, SparseCUDA: sparse_mask
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed
autogen: sparse_mask.out

View File

@ -1,8 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -15,8 +13,6 @@
#include <ATen/ops/mul_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/argsort.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/copy_sparse_to_sparse.h>
#include <ATen/ops/mul.h>
@ -440,137 +436,4 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
return out;
}
using OptTensor = std::optional<Tensor>;
static void sparse_mask_apply_out_mps_kernel(
Tensor& result,
const Tensor& src_in,
const Tensor& mask_in,
bool accumulate_matches,
bool require_same_sizes,
bool coalesce_mask) {
TORCH_CHECK(src_in.is_sparse() && mask_in.is_sparse(),
"sparse_mask: expected both inputs to be sparse COO");
TORCH_CHECK(src_in.is_mps() && mask_in.is_mps(),
"sparse_mask: expected tensors to be on MPS device");
TORCH_CHECK(src_in.sparse_dim() == mask_in.sparse_dim(),
"sparse_mask: sparse_dim mismatch: ", src_in.sparse_dim(), " vs ", mask_in.sparse_dim());
if (require_same_sizes) {
TORCH_CHECK(src_in.sizes().equals(mask_in.sizes()),
"sparse_mask: sizes must match exactly (no broadcasting)");
}
auto src = src_in.coalesce();
auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
const int64_t src_nnz = src._nnz();
const int64_t mask_nnz = mask._nnz();
const int64_t sd = src.sparse_dim();
result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
auto commonDtype = at::result_type(src, mask);
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
"Can't convert result type ", commonDtype, " to output ", result.scalar_type());
if (mask_nnz == 0) {
alias_into_sparse(
result,
mask._indices().narrow(1, 0, 0),
at::empty({0}, result.options().dtype(result.scalar_type())));
result._coalesced_(mask.is_coalesced());
return;
}
TORCH_CHECK(sd > 0 || (src_nnz <= 1 && mask_nnz <= 1),
"sparse_mask: invalid sparse_dim or nnz");
if (sd == 0) {
auto out_indices = mask._indices().narrow(1, 0, 1);
auto out_values = src_nnz
? src._values().narrow(0, 0, 1).to(commonDtype)
: at::zeros({1}, at::device(result.device()).dtype(commonDtype));
alias_into_sparse(result, out_indices, out_values);
result._coalesced_(mask.is_coalesced());
return;
}
if (src_nnz == 0) {
auto out_indices = mask._indices().contiguous();
auto src_values = src._values().to(commonDtype);
auto out_val_sizes = src_values.sizes().vec();
out_val_sizes[0] = mask_nnz;
auto out_values = at::zeros(out_val_sizes, src_values.options());
alias_into_sparse(result, out_indices, out_values);
result._coalesced_(mask.is_coalesced());
return;
}
auto mask_indices = mask._indices().contiguous();
auto src_indices = src._indices().contiguous();
auto src_values = src._values().to(commonDtype).contiguous();
auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
const bool A_is_src = (src_nnz <= mask_nnz);
const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
auto A_keys = A_is_src ? src_keys : mask_keys;
auto B_keys = A_is_src ? mask_keys : src_keys;
const auto device = result.device();
auto stream = getCurrentMPSStream();
auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), A_is_src);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
auto out_val_sizes = src_values.sizes().vec();
out_val_sizes[0] = mask_nnz;
auto out_values = at::zeros(out_val_sizes, src_values.options());
if (M > 0) {
auto src_match = outA_idx.narrow(0, 0, M);
auto mask_match = outB_idx.narrow(0, 0, M);
auto src_rows = src_values.index_select(0, src_match);
if (accumulate_matches) {
out_values.index_add_(0, mask_match, src_rows);
} else {
out_values.index_copy_(0, mask_match, src_rows);
}
}
alias_into_sparse(result, mask_indices, out_values);
result._coalesced_(mask.is_coalesced());
}
static void sparse_mask_intersection_out_mps_kernel(
Tensor& result,
const Tensor& lhs,
const Tensor& rhs,
const OptTensor& = std::nullopt) {
sparse_mask_apply_out_mps_kernel(
result,
/*src_in=*/lhs,
/*mask_in=*/rhs,
/*accumulate_matches=*/false,
/*require_same_sizes=*/false,
/*coalesce_mask=*/false);
}
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
} // namespace at::native

View File

@ -3,9 +3,6 @@
using namespace metal;
template <typename T> struct MulAccum { using type = float; };
template <> struct MulAccum<float2> { using type = float2; };
template <typename T>
kernel void dense_sparse_mul_kernel(
device const T* dense [[buffer(0)]],
@ -32,9 +29,8 @@ kernel void dense_sparse_mul_kernel(
ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
using accum_t = typename MulAccum<T>::type;
const accum_t a = static_cast<accum_t>(values[val_idx]);
const accum_t b = static_cast<accum_t>(dense[dense_idx]);
const auto a = static_cast<float>(values[val_idx]);
const auto b = static_cast<float>(dense[dense_idx]);
out_values[val_idx] = static_cast<T>(a * b);
}
@ -134,8 +130,6 @@ kernel void fused_gather_mul_kernel(
INSTANTIATE_DENSE_SPARSE_MUL(float);
INSTANTIATE_DENSE_SPARSE_MUL(half);
INSTANTIATE_DENSE_SPARSE_MUL(bfloat);
INSTANTIATE_DENSE_SPARSE_MUL(long);
INSTANTIATE_DENSE_SPARSE_MUL(float2);
#define INSTANTIATE_FUSED_GATHER_MUL(DTYPE) \
template [[host_name("fused_gather_mul_kernel_" #DTYPE)]] kernel void \

View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7

View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0

View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
vit_base_patch16_siglip_256,pass,7


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,fail_accuracy,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


View File

@ -10,18 +10,10 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -63,11 +55,3 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
vit_base_patch16_siglip_256,pass,7


View File

@ -271,6 +271,8 @@ class TimmRunner(BenchmarkRunner):
memory_format=torch.channels_last if channels_last else None,
)
self.num_classes = model.num_classes
data_config = resolve_data_config(
vars(self._args) if timmversion >= "0.8.0" else self._args,
model=model,
@ -300,6 +302,7 @@ class TimmRunner(BenchmarkRunner):
example_inputs = [
example_inputs,
]
self.target = self._gen_target(batch_size, device)
self.loss = torch.nn.CrossEntropyLoss().to(device)
@ -367,6 +370,11 @@ class TimmRunner(BenchmarkRunner):
tolerance = 1e-2
return tolerance, cosine
def _gen_target(self, batch_size, device):
return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
self.num_classes
)
def compute_loss(self, pred):
# High loss values make gradient checking harder, as small changes in
# accumulation order upset accuracy checks.

View File

@ -1,8 +1,6 @@
adv_inception_v3 128
beit_base_patch16_224 128
convnextv2_nano.fcmae_ft_in22k_in1k 128
deit_base_distilled_patch16_224 128
deit_tiny_patch16_224.fb_in1k 128
dm_nfnet_f0 128
ghostnet_100 512
inception_v3 128
@ -14,5 +12,3 @@ repvgg_a2 128
swin_base_patch4_window7_224 128
tf_efficientnet_b0 128
visformer_small 128
vit_base_patch14_dinov2.lvd142m 128
vit_base_patch16_siglip_256 128

View File

@ -1,8 +1,6 @@
adv_inception_v3,128
beit_base_patch16_224,64
convnextv2_nano.fcmae_ft_in22k_in1k,128
deit_base_distilled_patch16_224,64
deit_tiny_patch16_224.fb_in1k,128
dm_nfnet_f0,128
ghostnet_100,128
inception_v3,128
@ -14,5 +12,3 @@ repvgg_a2,128
swin_base_patch4_window7_224,64
tf_efficientnet_b0,128
visformer_small,128
vit_base_patch14_dinov2.lvd142m,128
ViT-B-16-SigLIP-i18n-256,128

View File

@ -1729,10 +1729,8 @@ def define_buck_targets(
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/backends/backend_interface.cpp",
],
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
}),
compiler_flags = get_pt_compiler_flags(),
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = get_no_as_needed_linker_flag(),
@ -2025,9 +2023,6 @@ def define_buck_targets(
"ovr_config//os:android-x86_64": [
"-mssse3",
],
}) + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
}),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [

View File

@ -13,22 +13,20 @@ constexpr size_t kRoundUpPowerOfTwoEnd = 64 * 1024ul * kMB; // 64GB
AcceleratorAllocatorConfig& AcceleratorAllocatorConfig::instance() {
static AcceleratorAllocatorConfig instance;
#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env) \
auto env##_name = c10::utils::get_env(#env); \
if (env##_name.has_value()) { \
instance.parseArgs(env##_name.value()); \
return true; \
#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env, deprecated) \
auto env##_name = c10::utils::get_env(#env); \
if (env##_name.has_value()) { \
if (deprecated) { \
TORCH_WARN_ONCE(#env " is deprecated, use PYTORCH_ALLOC_CONF instead"); \
} \
instance.parseArgs(env##_name.value()); \
return true; \
}
static bool env_flag [[maybe_unused]] = []() {
// Parse allocator configuration from environment variables.
// The first two entries are kept for backward compatibility with legacy
// CUDA and HIP environment variable names. The new unified variable
// (PYTORCH_ALLOC_CONF) should be used going forward.
// Note: keep the parsing order and logic stable to avoid potential
// performance regressions in internal tests.
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF, false)
// Keep this for backwards compatibility
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF, /*deprecated=*/true)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF, /*deprecated=*/true)
return false;
}();
#undef C10_ALLOCATOR_CONFIG_PARSE_ENV
@ -129,7 +127,8 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
std::fill(
std::next(
roundup_power2_divisions_.begin(),
static_cast<std::vector<size_t>::difference_type>(last_index)),
static_cast<std::vector<size_t>::difference_type>(
last_index + 1)),
roundup_power2_divisions_.end(),
value);
} else {
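
A rough, standalone sketch of what the two-argument C10_ALLOCATOR_CONFIG_PARSE_ENV variant above boils down to, for illustration only: plain std::getenv and fprintf stand in for c10::utils::get_env and TORCH_WARN_ONCE, and the unified PYTORCH_ALLOC_CONF is consulted before the deprecated per-backend names, matching the order shown in the hunk.

#include <cstdio>
#include <cstdlib>
#include <optional>
#include <string>

// Stand-in for c10::utils::get_env (assumption: the real helper returns
// std::optional<std::string>).
static std::optional<std::string> get_env_var(const char* name) {
  const char* v = std::getenv(name);
  return v ? std::optional<std::string>(v) : std::nullopt;
}

// Mirrors the macro expansion order: the first variable that is set wins, and
// the legacy names emit a deprecation warning before being parsed.
static bool parse_allocator_env(void (*parse_args)(const std::string&)) {
  struct Entry {
    const char* name;
    bool deprecated;
  };
  const Entry entries[] = {
      {"PYTORCH_ALLOC_CONF", false},
      {"PYTORCH_CUDA_ALLOC_CONF", true},  // kept for backward compatibility
      {"PYTORCH_HIP_ALLOC_CONF", true},
  };
  for (const auto& e : entries) {
    if (auto val = get_env_var(e.name)) {
      if (e.deprecated) {
        std::fprintf(
            stderr, "%s is deprecated, use PYTORCH_ALLOC_CONF instead\n", e.name);
      }
      parse_args(*val);
      return true;
    }
  }
  return false;
}

int main() {
  parse_allocator_env(+[](const std::string& conf) {
    std::printf("parsing: %s\n", conf.c_str());
  });
}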

View File

@ -28,8 +28,101 @@
namespace c10 {
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
// regarding macros.
// [dtype Macros note] For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
//
// For adding a new dtype: In the beginning, we had an idea that there was a
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them. But over the years we added weird types which couldn't
// be handled uniformly everywhere and so in the end we ended up with some
// mish-mosh of some helper macros, but mostly use sites making a call about
// what dtypes they can or can't support. So if you want to add a new dtype,
// the preferred resolution is to find a dtype similar to what you want,
// grep for it and edit all the sites you find this way. If you need to add
// a completely new kind of dtype, you're going to have to laboriously audit
// all of the sites everywhere to figure out how it should work. Consulting
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
//
// TODO: To add unsigned int types here, we must define accumulate type.
// But uint8 currently accumulates into int64, so we would have to make
// an inconsistent choice for the larger types. Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<c10::Half>, ComplexHalf) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
namespace impl {
// These are used to map ScalarTypes to C++ types.
template <c10::ScalarType N>
struct ScalarTypeToCPPType;
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
template <> \
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
using type = cpp_type; \
\
/* This is a workaround for the CUDA bug which prevents */ \
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
/* ambiguous reference which can't be resolved. For some reason it */ \
/* can't pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
static type t; \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
#undef SPECIALIZE_ScalarTypeToCPPType
template <c10::ScalarType N>
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
} // namespace impl
template <typename T>
struct CppTypeToScalarType;
@ -45,6 +138,130 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
// NB: despite its generic sounding name, the macros that don't take _AND
// are mostly only used by tensorexpr
#define AT_FORALL_INT_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long)
#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double)
// These macros are often controlling how many template instantiations we
// create for kernels. It is typically inappropriate to add new dtypes here,
// instead, new types should be added to use sites on a case-by-case basis.
// We generally are not accepting new dtypes due to binary size concerns.
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE4>::t), \
SCALARTYPE4) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE5>::t), \
SCALARTYPE5) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE6>::t), \
SCALARTYPE6) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE7>::t), \
SCALARTYPE7)
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
_(c10::qint32, QInt32) \
_(c10::quint4x2, QUInt4x2) \
_(c10::quint2x4, QUInt2x4)
#define AT_FORALL_FLOAT8_TYPES(_) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble)
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;
@ -52,6 +269,19 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
case ScalarType::name: \
return #name;
switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
#undef DEFINE_CASE
}
inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
case ScalarType::name: \
@ -295,6 +525,12 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
inline std::ostream& operator<<(
std::ostream& stream,
at::ScalarType scalar_type) {
return stream << toString(scalar_type);
}
// Returns a pair of strings representing the names for each dtype.
// The returned pair is (name, legacy_name_if_applicable)
C10_API std::pair<std::string, std::string> getDtypeNames(
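
As a side note for readers new to this header's style: the AT_FORALL_* lists are classic X-macros, written once and expanded differently at each use site (toString and elementSize above are two such expansions). A minimal standalone toy, not PyTorch code, showing the same pattern:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// The type list is written once...
#define FORALL_DEMO_TYPES(_) \
  _(uint8_t, Byte)           \
  _(int32_t, Int)            \
  _(float, Float)            \
  _(double, Double)

// ...and expanded once per use site: here into an enum...
enum class DemoType {
#define DEFINE_ENUM(ctype, name) name,
  FORALL_DEMO_TYPES(DEFINE_ENUM)
#undef DEFINE_ENUM
};

// ...and here into a per-type switch, the same shape as elementSize(ScalarType).
inline std::size_t demoElementSize(DemoType t) {
  switch (t) {
#define DEFINE_CASE(ctype, name) \
  case DemoType::name:           \
    return sizeof(ctype);
    FORALL_DEMO_TYPES(DEFINE_CASE)
#undef DEFINE_CASE
  }
  return 0;
}

int main() {
  std::printf("%zu\n", demoElementSize(DemoType::Float));  // prints 4
}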

View File

@ -1,5 +1,6 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/llvmMathExtras.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
@ -7,119 +8,386 @@
namespace c10::cuda::CUDACachingAllocator {
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
CUDAAllocatorConfig::CUDAAllocatorConfig()
: m_max_split_size(std::numeric_limits<size_t>::max()),
m_max_non_split_rounding_size(kLargeBuffer),
m_garbage_collection_threshold(0),
m_pinned_num_register_threads(1),
m_pinned_reserve_segment_size_mb(0),
m_expandable_segments(false),
#if CUDA_VERSION >= 12030
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::UNSPECIFIED),
#else
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::POSIX_FD),
#endif
m_release_lock_on_cudamalloc(false),
m_pinned_use_cuda_host_register(false),
m_graph_capture_record_stream_reuse(false),
m_pinned_use_background_threads(false) {
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}
size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
size_t log_size = (63 - llvm::countLeadingZeros(size));
// Our intervals start at 1MB and end at 64GB
const size_t interval_start =
63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
const size_t interval_end =
63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
TORCH_CHECK(
(interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
"kRoundUpPowerOfTwoIntervals mismatch");
int index = static_cast<int>(log_size) - static_cast<int>(interval_start);
index = std::max(0, index);
index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
return instance().m_roundup_power2_divisions[index];
}
void CUDAAllocatorConfig::lexArgs(
const std::string& env,
std::vector<std::string>& config) {
std::vector<char> buf;
for (char ch : env) {
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
buf.clear();
}
config.emplace_back(1, ch);
} else if (ch != ' ') {
buf.emplace_back(ch);
}
}
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
}
}
void CUDAAllocatorConfig::consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c) {
TORCH_CHECK(
i < config.size() && config[i] == std::string(1, c),
"Error parsing CachingAllocator settings, expected ",
c,
"");
}
size_t CUDAAllocatorConfig::parseMaxSplitSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_split_size_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_split_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_non_split_rounding_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_non_split_rounding_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
double val1 = stod(config[i]);
TORCH_CHECK(
val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
TORCH_CHECK(
val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
m_garbage_collection_threshold = val1;
} else {
TORCH_CHECK(
false, "Error, expecting garbage_collection_threshold value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
bool first_value = true;
if (++i < config.size()) {
if (std::string_view(config[i]) == "[") {
size_t last_index = 0;
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
while (++i < config.size() && std::string_view(config[i]) != "]") {
const std::string& val1 = config[i];
size_t val2 = 0;
consumeToken(config, ++i, ':');
if (++i < config.size()) {
val2 = stoi(config[i]);
} else {
TORCH_CHECK(
false, "Error parsing roundup_power2_divisions value", "");
}
TORCH_CHECK(
val2 == 0 || llvm::isPowerOf2_64(val2),
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ",
"");
if (std::string_view(val1) == ">") {
std::fill(
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
last_index)),
m_roundup_power2_divisions.end(),
val2);
} else {
size_t val1_long = stoul(val1);
TORCH_CHECK(
llvm::isPowerOf2_64(val1_long),
"For roundups, the intervals have to be power of 2 ",
"");
size_t index = 63 - llvm::countLeadingZeros(val1_long);
index = std::max((size_t)0, index);
index = std::min(index, m_roundup_power2_divisions.size() - 1);
if (first_value) {
std::fill(
m_roundup_power2_divisions.begin(),
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
index)),
val2);
first_value = false;
}
if (index < m_roundup_power2_divisions.size()) {
m_roundup_power2_divisions[index] = val2;
}
last_index = index;
}
if (std::string_view(config[i + 1]) != "]") {
consumeToken(config, ++i, ',');
}
}
} else { // Keep this for backwards compatibility
size_t val1 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val1),
"For roundups, the divisions has to be power of 2 ",
"");
std::fill(
m_roundup_power2_divisions.begin(),
m_roundup_power2_divisions.end(),
val1);
}
} else {
TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseAllocatorConfig(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync) {
// For ease of maintenance and understanding, the CUDA and ROCm
// implementations of this function are separated. This avoids having many
// #ifdef's throughout.
#ifdef USE_ROCM
// Ease burden on ROCm users by allowing either cuda or hip tokens.
// cuda token is broken up to prevent hipify matching it.
#define PYTORCH_TOKEN1 \
"cud" \
"aMallocAsync"
#define PYTORCH_TOKEN2 "hipMallocAsync"
tokenizer.checkToken(++i, ":");
i++; // Move to the value after the colon
#ifdef USE_ROCM
TORCH_CHECK(
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
(tokenizer[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
used_cudaMallocAsync =
(tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2);
TORCH_INTERNAL_ASSERT(
tokenizer[i] == get()->name() ||
(tokenizer[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time, ",
tokenizer[i],
" != ",
get()->name());
#else // USE_ROCM
TORCH_CHECK(
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1)),
"Unknown allocator backend, "
"options are native and " PYTORCH_TOKEN1);
used_cudaMallocAsync = (tokenizer[i] == PYTORCH_TOKEN1);
TORCH_INTERNAL_ASSERT(
tokenizer[i] == get()->name(),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time, ",
tokenizer[i],
" != ",
get()->name());
if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else // CUDA_VERSION >= 11040
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif // CUDA_VERSION >= 11040
((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
(config[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
used_cudaMallocAsync =
(config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
TORCH_INTERNAL_ASSERT(
config[i] == get()->name() ||
(config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time, ",
config[i],
" != ",
get()->name());
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
}
#endif // USE_ROCM
return i;
#undef PYTORCH_TOKEN1
#undef PYTORCH_TOKEN2
#else // USE_ROCM
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
((config[i] == "native") || (config[i] == "cudaMallocAsync")),
"Unknown allocator backend, "
"options are native and cudaMallocAsync");
used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
TORCH_INTERNAL_ASSERT(
config[i] == get()->name(),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time");
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
}
return i;
#endif // USE_ROCM
}
void CUDAAllocatorConfig::parseArgs(const std::string& env) {
void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
// If empty, set the default values
m_max_split_size = std::numeric_limits<size_t>::max();
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
m_garbage_collection_threshold = 0;
bool used_cudaMallocAsync = false;
bool used_native_specific_option = false;
c10::CachingAllocator::ConfigTokenizer tokenizer(env);
for (size_t i = 0; i < tokenizer.size(); i++) {
const auto& key = tokenizer[i];
if (key == "backend") {
i = parseAllocatorConfig(tokenizer, i, used_cudaMallocAsync);
if (!env.has_value()) {
return;
}
{
std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
m_last_allocator_settings = env.value();
}
std::vector<std::string> config;
lexArgs(env.value(), config);
for (size_t i = 0; i < config.size(); i++) {
std::string_view config_item_view(config[i]);
if (config_item_view == "max_split_size_mb") {
i = parseMaxSplitSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "max_non_split_rounding_mb") {
i = parseMaxNonSplitRoundingSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "garbage_collection_threshold") {
i = parseGarbageCollectionThreshold(config, i);
used_native_specific_option = true;
} else if (config_item_view == "roundup_power2_divisions") {
i = parseRoundUpPower2Divisions(config, i);
used_native_specific_option = true;
} else if (config_item_view == "backend") {
i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
} else if (config_item_view == "expandable_segments") {
used_native_specific_option = true;
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for expandable_segments");
config_item_view = config[i];
m_expandable_segments = (config_item_view == "True");
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
key == "release_lock_on_hipmalloc" ||
key ==
config_item_view == "release_lock_on_hipmalloc" ||
config_item_view ==
"release_lock_on_c"
"udamalloc") {
used_native_specific_option = true;
tokenizer.checkToken(++i, ":");
m_release_lock_on_cudamalloc = tokenizer.toBool(++i);
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for release_lock_on_cudamalloc");
config_item_view = config[i];
m_release_lock_on_cudamalloc = (config_item_view == "True");
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
key == "pinned_use_hip_host_register" ||
key ==
config_item_view == "pinned_use_hip_host_register" ||
config_item_view ==
"pinned_use_c"
"uda_host_register") {
i = parsePinnedUseCudaHostRegister(tokenizer, i);
i = parsePinnedUseCudaHostRegister(config, i);
used_native_specific_option = true;
} else if (key == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(tokenizer, i);
} else if (config_item_view == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(config, i);
used_native_specific_option = true;
} else if (key == "pinned_reserve_segment_size_mb") {
i = parsePinnedReserveSegmentSize(tokenizer, i);
} else if (config_item_view == "pinned_reserve_segment_size_mb") {
i = parsePinnedReserveSegmentSize(config, i);
used_native_specific_option = true;
} else if (key == "graph_capture_record_stream_reuse") {
i = parseGraphCaptureRecordStreamReuse(tokenizer, i);
} else if (config_item_view == "pinned_use_background_threads") {
i = parsePinnedUseBackgroundThreads(config, i);
used_native_specific_option = true;
} else if (config_item_view == "graph_capture_record_stream_reuse") {
i = parseGraphCaptureRecordStreamReuse(config, i);
used_native_specific_option = true;
} else {
const auto& keys =
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
TORCH_CHECK(
keys.find(key) != keys.end(),
"Unrecognized key '",
key,
"' in CUDA allocator config.");
// Skip the key and its value
i = tokenizer.skipKey(i);
false, "Unrecognized CachingAllocator option: ", config_item_view);
}
if (i + 1 < tokenizer.size()) {
tokenizer.checkToken(++i, ",");
if (i + 1 < config.size()) {
consumeToken(config, ++i, ',');
}
}
@ -131,51 +399,97 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) {
}
size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i) {
tokenizer.checkToken(++i, ":");
m_pinned_use_cuda_host_register = tokenizer.toBool(++i);
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_cuda_host_register");
m_pinned_use_cuda_host_register = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_cuda_host_register value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i) {
tokenizer.checkToken(++i, ":");
m_graph_capture_record_stream_reuse = tokenizer.toBool(++i);
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for graph_capture_record_stream_reuse");
m_graph_capture_record_stream_reuse = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting graph_capture_record_stream_reuse value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i) {
tokenizer.checkToken(++i, ":");
size_t val2 = tokenizer.toSizeT(++i);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2, got ",
val2);
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK(
val2 <= maxThreads,
"Number of register threads should be less than or equal to ",
maxThreads,
", got ",
val2);
m_pinned_num_register_threads = val2;
consumeToken(config, ++i, ':');
if (++i < config.size()) {
size_t val2 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
} else {
TORCH_CHECK(
false, "Error, expecting pinned_num_register_threads value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parsePinnedReserveSegmentSize(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i) {
tokenizer.checkToken(++i, ":");
size_t val2 = tokenizer.toSizeT(++i);
TORCH_CHECK(val2 > 0, "Pinned reserve segment size has to be greater than 0");
m_pinned_reserve_segment_size_mb = val2;
consumeToken(config, ++i, ':');
if (++i < config.size()) {
size_t val2 = stoi(config[i]);
TORCH_CHECK(
val2 > 0, "Pinned reserve segment size has to be greater than 0 ", "");
m_pinned_reserve_segment_size_mb = val2;
} else {
TORCH_CHECK(
false, "Error, expecting pinned_reserve_segment_size_mb value", "");
}
return i;
}
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig)
size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_background_threads");
m_pinned_use_background_threads = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_background_threads value", "");
}
return i;
}
// General caching allocator utilities
void setAllocatorSettings(const std::string& env) {
CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
}
} // namespace c10::cuda::CUDACachingAllocator
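
To make the accepted syntax concrete, here is a minimal standalone tokenizer in the spirit of lexArgs above (an illustrative sketch, not the library code): it splits a setting string into word tokens plus the ',', ':', '[' and ']' delimiters that the parse*() helpers then walk with consumeToken.

#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> lex_config(const std::string& env) {
  std::vector<std::string> out;
  std::string buf;
  for (char ch : env) {
    if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
      if (!buf.empty()) {
        out.push_back(buf);  // flush the word collected so far
        buf.clear();
      }
      out.emplace_back(1, ch);  // delimiters become single-character tokens
    } else if (ch != ' ') {
      buf.push_back(ch);  // spaces are dropped, everything else accumulates
    }
  }
  if (!buf.empty()) {
    out.push_back(buf);
  }
  return out;
}

int main() {
  // Prints: max_split_size_mb : 256 , expandable_segments : True
  for (const auto& tok : lex_config("max_split_size_mb:256,expandable_segments:True")) {
    std::cout << tok << ' ';
  }
  std::cout << '\n';
}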

View File

@ -1,11 +1,16 @@
#pragma once
#include <c10/core/AllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <mutex>
#include <string>
#include <vector>
namespace c10::cuda::CUDACachingAllocator {
enum class Expandable_Segments_Handle_Type : int {
@ -18,23 +23,20 @@ enum class Expandable_Segments_Handle_Type : int {
class C10_CUDA_API CUDAAllocatorConfig {
public:
static size_t max_split_size() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
return instance().m_max_split_size;
}
static double garbage_collection_threshold() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
garbage_collection_threshold();
return instance().m_garbage_collection_threshold;
}
static bool expandable_segments() {
bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig::
use_expandable_segments();
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
if (enabled) {
if (instance().m_expandable_segments) {
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
}
return false;
#else
return enabled;
return instance().m_expandable_segments;
#endif
}
@ -65,8 +67,7 @@ class C10_CUDA_API CUDAAllocatorConfig {
}
static bool pinned_use_background_threads() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
pinned_use_background_threads();
return instance().m_pinned_use_background_threads;
}
static size_t pinned_reserve_segment_size_mb() {
@ -80,23 +81,24 @@ class C10_CUDA_API CUDAAllocatorConfig {
return 128;
}
static size_t roundup_power2_divisions(size_t size) {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions(size);
}
// This is used to round-up allocation size to nearest power of 2 divisions.
// More description below in function roundup_power2_next_division
// As an example, if we want 4 divisions between 2's power, this can be done
// using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
static size_t roundup_power2_divisions(size_t size);
static std::vector<size_t> roundup_power2_divisions() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions();
return instance().m_roundup_power2_divisions;
}
static size_t max_non_split_rounding_size() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
max_non_split_rounding_size();
return instance().m_max_non_split_rounding_size;
}
static std::string last_allocator_settings() {
return c10::CachingAllocator::getAllocatorSettings();
std::lock_guard<std::mutex> lock(
instance().m_last_allocator_settings_mutex);
return instance().m_last_allocator_settings;
}
static CUDAAllocatorConfig& instance() {
@ -109,75 +111,70 @@ class C10_CUDA_API CUDAAllocatorConfig {
env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
// Note: keep the parsing order and logic stable to avoid potential
// performance regressions in internal tests.
if (!env.has_value()) {
env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
}
if (env.has_value()) {
inst->parseArgs(env.value());
}
inst->parseArgs(env);
return inst;
})();
return *s_instance;
}
// Use `Construct On First Use Idiom` to avoid `Static Initialization Order`
// issue.
static const std::unordered_set<std::string>& getKeys() {
static std::unordered_set<std::string> keys{
"backend",
// keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_cud"
"amalloc",
"pinned_use_cud"
"a_host_register",
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_hipmalloc",
"pinned_use_hip_host_register",
"graph_capture_record_stream_reuse",
"pinned_reserve_segment_size_mb",
"pinned_num_register_threads"};
return keys;
}
void parseArgs(const std::string& env);
void parseArgs(const std::optional<std::string>& env);
private:
CUDAAllocatorConfig() = default;
CUDAAllocatorConfig();
static void lexArgs(const std::string& env, std::vector<std::string>& config);
static void consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c);
size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
size_t parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i);
size_t parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i);
size_t parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i);
size_t parseAllocatorConfig(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync);
size_t parsePinnedUseCudaHostRegister(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedNumRegisterThreads(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedReserveSegmentSize(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i);
size_t parseGraphCaptureRecordStreamReuse(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i);
std::atomic<size_t> m_pinned_num_register_threads{1};
std::atomic<size_t> m_pinned_reserve_segment_size_mb{0};
std::atomic<Expandable_Segments_Handle_Type> m_expandable_segments_handle_type
#if CUDA_VERSION >= 12030
{Expandable_Segments_Handle_Type::UNSPECIFIED};
#else
{Expandable_Segments_Handle_Type::POSIX_FD};
#endif
std::atomic<bool> m_release_lock_on_cudamalloc{false};
std::atomic<bool> m_pinned_use_cuda_host_register{false};
std::atomic<bool> m_graph_capture_record_stream_reuse{false};
std::atomic<size_t> m_max_split_size;
std::atomic<size_t> m_max_non_split_rounding_size;
std::vector<size_t> m_roundup_power2_divisions;
std::atomic<double> m_garbage_collection_threshold;
std::atomic<size_t> m_pinned_num_register_threads;
std::atomic<size_t> m_pinned_reserve_segment_size_mb;
std::atomic<bool> m_expandable_segments;
std::atomic<Expandable_Segments_Handle_Type>
m_expandable_segments_handle_type;
std::atomic<bool> m_release_lock_on_cudamalloc;
std::atomic<bool> m_pinned_use_cuda_host_register;
std::atomic<bool> m_graph_capture_record_stream_reuse;
std::atomic<bool> m_pinned_use_background_threads;
std::string m_last_allocator_settings;
std::mutex m_last_allocator_settings_mutex;
};
// Keep this for backwards compatibility
using c10::CachingAllocator::setAllocatorSettings;
// General caching allocator utilities
C10_CUDA_API void setAllocatorSettings(const std::string& env);
} // namespace c10::cuda::CUDACachingAllocator
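
For the roundup_power2_divisions knob mentioned in the comment above (e.g. PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4), the following is a simplified sketch of the size-to-interval mapping that roundup_power2_divisions(size) performs in the .cpp: sizes are bucketed into the 16 power-of-two intervals between 1 MiB and 64 GiB, clamped at both ends. A plain loop stands in for llvm::countLeadingZeros, and the division table itself is left out.

#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr std::size_t kIntervals = 16;  // 1 MiB (2^20) .. 64 GiB (2^36)

std::size_t interval_index(std::size_t size) {
  // floor(log2(size)), written with a loop instead of llvm::countLeadingZeros.
  std::size_t log_size = 0;
  while (size >>= 1) {
    ++log_size;
  }
  const std::size_t interval_start = 20;  // log2(1 MiB)
  int index = static_cast<int>(log_size) - static_cast<int>(interval_start);
  index = std::max(0, index);                                 // clamp below 1 MiB
  index = std::min(index, static_cast<int>(kIntervals) - 1);  // clamp above 64 GiB
  return static_cast<std::size_t>(index);
}

int main() {
  std::printf("%zu %zu %zu\n",
              interval_index(512 * 1024),        // below 1 MiB  -> bucket 0
              interval_index(2 * 1024 * 1024),   // 2 MiB        -> bucket 1
              interval_index(1ull << 40));       // 1 TiB clamps -> bucket 15
}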

View File

@ -64,6 +64,10 @@ namespace cuda::CUDACachingAllocator {
using namespace c10::CachingAllocator;
using namespace c10::CachingDeviceAllocator;
// Included here as this is externally used in CUDAAllocatorConfig
const size_t kLargeBuffer =
20971520; // "large" allocations may be packed in 20 MiB blocks
namespace Native {
//

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/core/AllocatorConfig.h>
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
@ -50,9 +49,10 @@ namespace c10::cuda::CUDACachingAllocator {
// Preserved only for BC reasons
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::CachingAllocator::kLargeBuffer;
using c10::CachingDeviceAllocator::DeviceStats;
extern const size_t kLargeBuffer;
typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();
// Struct containing info of an allocation block (i.e. a fractional part of a

View File

@ -67,8 +67,8 @@ TEST(AllocatorConfigTest, allocator_config_test) {
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 2);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 2);
// EXPECT_EQ(
// AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 1);
EXPECT_EQ(
@ -101,8 +101,8 @@ TEST(AllocatorConfigTest, allocator_config_test) {
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 1);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 0);
// EXPECT_EQ(
// AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 2);
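For readers unfamiliar with the knob exercised by this test: `roundup_power2_divisions` splits each power-of-two size interval into a fixed number of divisions and rounds requests up to the next division boundary. The sketch below is an illustrative re-implementation of that rounding rule, not the allocator's actual code path, and the example sizes are assumptions rather than the values configured in the test above.

```python
# Illustrative re-implementation of the roundup_power2_divisions rounding rule.
def rounded_size(nbytes: int, divisions: int) -> int:
    k = nbytes.bit_length() - 1          # 2**k <= nbytes < 2**(k + 1)
    if nbytes == 1 << k:
        return nbytes                    # exact powers of two are left alone
    step = (1 << k) // divisions         # granularity inside the interval
    return ((nbytes + step - 1) // step) * step

# A 1200-byte request with 4 divisions falls in [1024, 2048) and rounds to 1280.
print(rounded_size(1200, 4))  # 1280
```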

View File

@ -65,7 +65,7 @@ struct default_constructible
namespace impl {
template <typename T>
constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>* /*unused*/)
constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>*)
{
return true;
}
@ -76,7 +76,7 @@ class type : public modifier<M, type<T, Tag, M...>>...
{
public:
template <typename TT = T, typename = std::enable_if_t<std::is_trivially_constructible<TT>{}>>
explicit type(uninitialized_t /*unused*/)
explicit type(uninitialized_t)
noexcept
{
}
@ -138,7 +138,7 @@ private:
namespace impl {
template <typename T, typename Tag, typename ... Ms>
constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>* /*unused*/) { return true;}
constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>*) { return true;}
constexpr bool is_strong_type_func(...) { return false;}
template <typename T, typename Tag, typename ... Ms>
constexpr T underlying_type(strong::type<T, Tag, Ms...>*);

View File

@ -20,6 +20,8 @@ constexpr size_t kMinBlockSize = 512;
constexpr size_t kSmallSize = 1048576;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// "large" allocations may be packed in 20 MiB blocks
constexpr size_t kLargeBuffer = 20971520;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
@ -433,18 +435,6 @@ class DeviceCachingAllocator {
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device);
auto device_total = device_prop.global_mem_size;
// Estimate the available device memory when the SYCL runtime does not
// support the corresponding aspect (ext_intel_free_memory).
size_t device_free = device_prop.global_mem_size -
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto& raw_device = c10::xpu::get_raw_device(device);
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
// affected devices.
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
device_free =
raw_device.get_info<sycl::ext::intel::info::device::free_memory>();
}
auto allocated_bytes =
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
@ -467,9 +457,7 @@ class DeviceCachingAllocator {
static_cast<int>(device),
" has a total capacity of ",
format_size(device_total),
" of which ",
format_size(device_free),
" is free. Of the allocated memory ",
". Of the allocated memory ",
format_size(allocated_bytes),
" is allocated by PyTorch, and ",
format_size(reserved_bytes - allocated_bytes),
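The hunk above drops the free-memory estimate from the XPU out-of-memory message and keeps only the totals tracked by the caching allocator itself. As a rough Python-side analogue of how the reported quantities relate, here is a sketch using the CUDA query functions (chosen because their names are stable; the XPU equivalents are assumed to differ between releases and are not shown):

```python
# Rough sketch of the quantities an OOM report like the one above is built from.
import torch

if torch.cuda.is_available():
    free_bytes, total_bytes = torch.cuda.mem_get_info()  # driver's view of the device
    allocated = torch.cuda.memory_allocated()            # bytes held by live tensors
    reserved = torch.cuda.memory_reserved()              # bytes cached by the allocator
    print(
        f"total={total_bytes} free={free_bytes} "
        f"allocated_by_pytorch={allocated} "
        f"reserved_but_unallocated={reserved - allocated}"
    )
```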

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/core/AllocatorConfig.h>
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/xpu/XPUStream.h>

View File

@ -68,6 +68,14 @@
.. autofunction:: get_validators
```
```{eval-rst}
.. autofunction:: write_file_on_exit
```
```{eval-rst}
.. autofunction:: write_file
```
```{eval-rst}
.. autofunction:: read_file
```
@ -87,7 +95,3 @@
```{eval-rst}
.. autofunction:: get_rotating_buffer_size
```
```{eval-rst}
.. autofunction:: set_numerical_check_tolerances
```

View File

@ -23,7 +23,6 @@ Submodules
flex_attention
bias
experimental
varlen
.. toctree::
:hidden:
@ -31,4 +30,3 @@ Submodules
nn.attention.flex_attention
nn.attention.bias
nn.attention.experimental
nn.attention.varlen

View File

@ -1,17 +0,0 @@
```{eval-rst}
.. role:: hidden
:class: hidden-section
```
# torch.nn.attention.varlen
```{eval-rst}
.. automodule:: torch.nn.attention.varlen
.. currentmodule:: torch.nn.attention.varlen
```
```{eval-rst}
.. autofunction:: varlen_attn
```
```{eval-rst}
.. autoclass:: AuxRequest
```

View File

@ -228,4 +228,3 @@ Low-Precision functions
ScalingType
SwizzleType
scaled_mm
scaled_grouped_mm

View File

@ -1,12 +1,14 @@
```{eval-rst}
.. currentmodule:: torch.compiler.config
```
# torch.compiler.config
```{eval-rst}
.. automodule:: torch.compiler.config
:members:
:undoc-members:
:show-inheritance:
```
```{eval-rst}
.. autodata:: torch.compiler.config.job_id
```

View File

@ -816,10 +816,6 @@ Operator Tags
.. py:module:: torch.types
.. py:module:: torch.version
.. Compiler configuration module - documented in torch.compiler.config.md
.. py:module:: torch.compiler.config
:noindex:
.. Hidden aliases (e.g. torch.functional.broadcast_tensors()). We want `torch.broadcast_tensors()` to
be visible only.
.. toctree::

View File

@ -10,7 +10,6 @@ set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
)

View File

@ -1,76 +0,0 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/ScalarType.h>
TEST(TestScalarType, ScalarTypeToCPPTypeT) {
using torch::headeronly::ScalarType;
using torch::headeronly::impl::ScalarTypeToCPPTypeT;
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
EXPECT_EQ(typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE));
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
{ \
EXPECT_EQ( \
typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE)); \
count++; \
}
#define TEST_FORALL(M, EXPECTEDCOUNT, ...) \
TEST(TestScalarType, M) { \
using torch::headeronly::ScalarType; \
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
int8_t count = 0; \
M(__VA_ARGS__ DEFINE_CHECK); \
EXPECT_EQ(count, EXPECTEDCOUNT); \
}
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ, 14)
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, 18)
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, 46)
TEST_FORALL(AT_FORALL_INT_TYPES, 5)
TEST_FORALL(AT_FORALL_SCALAR_TYPES, 7)
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND, 8, Bool, )
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND2, 9, Bool, Half, )
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND3, 10, Bool, Half, ComplexFloat, )
TEST_FORALL(
AT_FORALL_SCALAR_TYPES_AND7,
14,
Bool,
Half,
ComplexHalf,
ComplexFloat,
ComplexDouble,
UInt16,
UInt32, )
TEST_FORALL(AT_FORALL_QINT_TYPES, 5)
TEST_FORALL(AT_FORALL_FLOAT8_TYPES, 5)
TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)
#undef DEFINE_CHECK
#undef TEST_FORALL
TEST(TestScalarType, toString) {
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
TEST(TestScalarType, operator_left_shift) {
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(_, name) \
{ \
std::stringstream ss; \
ss << ScalarType::name; \
EXPECT_EQ(ss.str(), #name); \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}

View File

@ -559,7 +559,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
FAIL = 138
pc = start_processes(
name="echo",
entrypoint=bin("echo4.py"),
entrypoint=bin("echo1.py"),
args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(

View File

@ -9,6 +9,7 @@
import argparse
import os
import sys
import time
if __name__ == "__main__":
@ -23,5 +24,6 @@ if __name__ == "__main__":
print(f"exit {exitcode} from {rank}", file=sys.stderr)
sys.exit(exitcode)
else:
time.sleep(1000)
print(f"{args.msg} stdout from {rank}")
print(f"{args.msg} stderr from {rank}", file=sys.stderr)

View File

@ -1,29 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import sys
import time
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="test binary, exits with exitcode")
parser.add_argument("--exitcode", type=int, default=0)
parser.add_argument("msg", type=str)
args = parser.parse_args()
rank = int(os.environ["RANK"])
exitcode = args.exitcode
if exitcode != 0:
print(f"exit {exitcode} from {rank}", file=sys.stderr)
sys.exit(exitcode)
else:
time.sleep(1000)
print(f"{args.msg} stdout from {rank}")
print(f"{args.msg} stderr from {rank}", file=sys.stderr)

View File

@ -536,23 +536,6 @@ class TestScheduleLowering(TestCase):
"compute": ["0F0", "0F1", " ", "0B0", "0B1"],
"comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"],
},
{
"compute": ["0F0", "0F1", "1F0", "1F1", "1B0", "1B1", "0B0", "0B1"],
"comms": [
"0UNSHARD",
"1UNSHARD",
"0F0",
"0F1",
"1F0",
"1F1",
"1B0",
"1B1",
"1RESHARD",
"0B0",
"0B1",
"0RESHARD",
],
},
],
)
def test_unshard_reshard(self, test_info):

View File

@ -1,9 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import functools
import itertools
import random
import unittest
from typing import Any, Callable, ClassVar, Optional
from typing import Callable, ClassVar, Optional, Union
import torch
import torch.distributed as dist
@ -31,7 +32,6 @@ from torch.distributed.tensor.experimental._load_balancer import (
_HeadTailLoadBalancer,
_LoadBalancer,
_PerDocumentHeadTailLoadBalancer,
_PTRRLoadBalancer,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
@ -39,7 +39,6 @@ from torch.nn.attention.flex_attention import (
_mask_mod_signature,
AuxOutput,
AuxRequest,
BlockMask,
create_block_mask,
flex_attention,
)
@ -392,7 +391,9 @@ def generate_random_lengths_in_chunks(
return [num_chunks * chunk_size for num_chunks in num_chunks_per_document]
def length_to_offsets(lengths: list[list[int]], device: str | torch.device) -> Tensor:
def length_to_offsets(
lengths: list[list[int]], device: Union[str, torch.device]
) -> Tensor:
"""Converts a list of lengths to a list of offsets.
Args:
@ -474,9 +475,8 @@ class CPFlexAttentionTest(DTensorTestBase):
*,
qkv_size: int,
B: int = 1,
block_mask,
lb_type: str,
document_lengths: Optional[list[list[int]]] = None,
mask_func: _mask_mod_signature = causal_mask,
lb: Optional[_LoadBalancer] = None,
) -> None:
torch.use_deterministic_algorithms(True)
torch.cuda.manual_seed(1234)
@ -486,14 +486,6 @@ class CPFlexAttentionTest(DTensorTestBase):
dim = 32
nheads = 8
seq_dim = 2
lb = self._get_load_balancer(
lb_type,
{
"seq_length": qkv_size,
"document_lengths": document_lengths,
"block_mask": block_mask,
},
)
qkv = [
torch.rand(
@ -505,6 +497,15 @@ class CPFlexAttentionTest(DTensorTestBase):
for _ in range(3)
]
block_mask = compiled_create_block_mask(
mask_func,
B=B,
H=1,
Q_LEN=qkv_size,
KV_LEN=qkv_size,
device=self.device_type,
)
expect_out, expect_aux = compiled_flex_attention(
*qkv, block_mask=block_mask, return_aux=AuxRequest(lse=True)
)
@ -546,8 +547,6 @@ class CPFlexAttentionTest(DTensorTestBase):
# backward run
cp_out.sum().backward()
atol = 2e-06
rtol = 1e-05
# unshard the output
cp_out, cp_lse = context_parallel_unshard(
device_mesh,
@ -555,8 +554,8 @@ class CPFlexAttentionTest(DTensorTestBase):
seq_dims=[seq_dim] * 2,
load_balancer=lb,
)
torch.testing.assert_close(cp_out, expect_out, atol=atol, rtol=rtol)
torch.testing.assert_close(cp_lse, expect_aux.lse, atol=atol, rtol=rtol)
torch.testing.assert_close(cp_out, expect_out)
torch.testing.assert_close(cp_lse, expect_aux.lse)
# unshard the gradient
cp_qkv_grad = context_parallel_unshard(
@ -568,38 +567,7 @@ class CPFlexAttentionTest(DTensorTestBase):
qkv_grad = [t.grad for t in qkv]
for grad, cp_grad in zip(qkv_grad, cp_qkv_grad):
torch.testing.assert_close(grad, cp_grad, atol=atol, rtol=rtol)
def _get_load_balancer(
self, lb_type: str, kwargs: dict[str, Any]
) -> Optional[_LoadBalancer]:
seq_length = kwargs["seq_length"]
document_lengths = kwargs["document_lengths"]
block_mask = kwargs["block_mask"]
# generate load balancer
if lb_type == "None":
load_balancer = None # no load-balance
elif lb_type == "_HeadTailLoadBalancer":
assert isinstance(seq_length, int)
load_balancer = _HeadTailLoadBalancer(
seq_length, self.world_size, torch.device(self.device_type)
)
elif lb_type == "_PerDocumentHeadTailLoadBalancer":
assert isinstance(document_lengths, list)
load_balancer = _PerDocumentHeadTailLoadBalancer(
document_lengths, self.world_size, torch.device(self.device_type)
)
elif lb_type == "_PTRRLoadBalancer":
assert isinstance(block_mask, BlockMask)
load_balancer = _PTRRLoadBalancer(
block_mask,
self.world_size,
)
else:
raise ValueError(f"load_balancer type {lb_type} is not supported!")
return load_balancer
torch.testing.assert_close(grad, cp_grad)
@skip_if_lt_x_gpu(2)
@with_comms
@ -607,65 +575,33 @@ class CPFlexAttentionTest(DTensorTestBase):
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
)
def test_cp_flex_attention_causal_mask(self) -> None:
seq_length_list = [256 * self.world_size, 2048]
load_balance_type_list = [
"None",
"_HeadTailLoadBalancer",
"_PTRRLoadBalancer",
]
restore_enable_load_balance = _cp_options.enable_load_balance
# NOTE: Each (seq_len, load_balance_type) tuple introduces 2
# create_block_mask compilations: 1 for single-rank flex_attention and 1 for
# CP flex_attention. In order to avoid the "exceeds_recompile_limit" error,
# we need to increase the cache_size_limit to 2 * num_of_sub_test_runs which
# will be the total number of compilations in our test case.
torch._dynamo.config.cache_size_limit = (len(seq_length_list) + 1) * (
1 + len(load_balance_type_list)
)
for enable_load_balance in [
False, # test w/o load-balancing
True, # test w/ the default load-balancing
]:
_cp_options.enable_load_balance = enable_load_balance
self.run_subtests(
{
"qkv_size": [
(256 if enable_load_balance else 128) * self.world_size,
2048,
],
},
self._test_cp_flex_attention,
)
for qkv_size, lb_type in itertools.product(
seq_length_list, load_balance_type_list
):
block_mask = compiled_create_block_mask(
causal_mask,
B=1,
H=1,
Q_LEN=qkv_size,
KV_LEN=qkv_size,
device=self.device_type,
)
self._test_cp_flex_attention(
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
)
_cp_options.enable_load_balance = restore_enable_load_balance
# NOTE: Context Parallel should not be used for small attentions (block_size < 128)
qkv_size = 64 * self.world_size
block_mask = compiled_create_block_mask(
causal_mask,
B=1,
H=1,
Q_LEN=qkv_size,
KV_LEN=qkv_size,
device=self.device_type,
)
for lb_type in ["None", "_HeadTailLoadBalancer"]:
with self.assertRaisesRegex(
NotImplementedError,
f"Q_LEN {qkv_size} is not divisible",
):
self._test_cp_flex_attention(
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
)
for lb_type in ["_PTRRLoadBalancer"]:
with self.assertRaisesRegex(
NotImplementedError,
"must be divisible by group_size",
):
self._test_cp_flex_attention(
qkv_size=qkv_size, block_mask=block_mask, lb_type=lb_type
)
with self.assertRaisesRegex(
NotImplementedError, "Q_LEN 128 is not divisible by CP mesh world size"
):
self.run_subtests(
{"qkv_size": [64 * self.world_size]},
self._test_cp_flex_attention,
)
# TODO: merge with the above test
@skip_if_lt_x_gpu(2)
@ -674,71 +610,77 @@ class CPFlexAttentionTest(DTensorTestBase):
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
)
def test_cp_flex_attention_document_mask(self) -> None:
restore_enable_load_balance = _cp_options.enable_load_balance
random.seed(10)
# parameters for testing
doc_count = 28
enable_load_balance_list = [True, False]
batch_size_list = [2, 4, 8]
max_seq_len_list = [
256 * self.world_size,
2048,
# 128 * self.world_size # NOTE: Mismatched elements: 8 / 131072 (0.0%),
]
load_balance_type = [
"None",
"_HeadTailLoadBalancer",
"_PerDocumentHeadTailLoadBalancer",
"_PTRRLoadBalancer",
]
# NOTE: Each (batch_size, seq_len, load_balance_type) tuple introduces 2
# NOTE: Each (enable_load_balance, batch_size, seq_len) tuple introduces 2
# create_block_mask compilations: 1 for single-rank flex_attention and 1 for
# CP flex_attention. In order to avoid the "exceeds_recompile_limit" error,
# we need to increase the cache_size_limit to 2 * num_of_sub_test_runs which
# will be the total number of compilations in our test case.
# we need to increase the cache_size_limit to 12 which is the total number
# of compilations in our test case.
torch._dynamo.config.cache_size_limit = (
2 * len(batch_size_list) * len(max_seq_len_list) * len(load_balance_type)
2
* len(enable_load_balance_list)
* len(batch_size_list)
* len(max_seq_len_list)
)
# TODO: change this for-loop to run_subtests
# Use a for-loop instead of run_subtests because we need to initialize the mask
# for each subtest. This can be baked into self._test_cp_flex_attention as
# a str argument denoting mask type.
for batch_size, max_seq_len, lb_type in itertools.product(
batch_size_list,
max_seq_len_list,
load_balance_type,
for enable_load_balance, batch_size, max_seq_len in itertools.product(
enable_load_balance_list, batch_size_list, max_seq_len_list
):
_cp_options.enable_load_balance = enable_load_balance
# initialize document mask
lengths = [
(
generate_random_lengths_in_chunks(
max_seq_len, doc_count, chunk_size=2 * self.world_size
)
if lb_type == "_PerDocumentHeadTailLoadBalancer"
if enable_load_balance
else generate_random_lengths(max_seq_len, doc_count)
)
for _ in range(batch_size)
]
offsets = length_to_offsets(lengths, self.device_type)
document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
block_mask = compiled_create_block_mask(
document_causal_mask,
B=batch_size,
H=1,
Q_LEN=max_seq_len,
KV_LEN=max_seq_len,
device=self.device_type,
# generate load balancer
load_balancer = (
_PerDocumentHeadTailLoadBalancer(
lengths, self.world_size, torch.device(self.device_type)
)
if enable_load_balance
else None
)
self._test_cp_flex_attention(
# construct testing function
test_func = functools.partial(
self._test_cp_flex_attention,
qkv_size=max_seq_len,
B=batch_size,
lb_type=lb_type,
block_mask=block_mask,
document_lengths=lengths,
lb=load_balancer,
mask_func=document_causal_mask,
)
test_func()
_cp_options.enable_load_balance = restore_enable_load_balance
class TestCPCustomOps(DTensorTestBase):
@property
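The NOTE comments in the hunks above concern Dynamo's recompile budget: each distinct mask/sequence-length combination compiles `create_block_mask` and the flex-attention call again, and once the per-function cache fills up Dynamo raises an exceeds-recompile-limit error, so the tests bump `torch._dynamo.config.cache_size_limit` up front. A minimal standalone sketch of that budgeting pattern, with shapes and the limit arithmetic chosen only for illustration:

```python
# Minimal sketch of the cache_size_limit budgeting pattern described in the
# NOTE comments above; the shapes and the "2 slots per variant" arithmetic
# are illustrative assumptions, not values taken from the test.
import torch

seq_lens = [256, 512, 1024]

torch._dynamo.config.cache_size_limit = max(
    torch._dynamo.config.cache_size_limit, 2 * len(seq_lens)
)

@torch.compile
def step(x):
    return torch.nn.functional.softmax(x, dim=-1)

for s in seq_lens:
    step(torch.randn(1, s))  # each new shape may trigger a recompilation
```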

View File

@ -239,7 +239,9 @@ class DTensorExportTest(TestCase):
"view_9",
"t_15",
"detach",
"detach_3",
"detach_1",
"detach_6",
"detach_7",
"threshold_backward_1",
"t_16",
"mm_6",
@ -257,8 +259,10 @@ class DTensorExportTest(TestCase):
"sum_1",
"view_7",
"t_7",
"detach_1",
"detach_2",
"detach_3",
"detach_4",
"detach_5",
"threshold_backward",
"mm_2",
"t_9",

View File

@ -20,7 +20,6 @@ from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -1146,22 +1145,6 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
sharded_dt, mesh, tgt_placement, shard_order=None
)
@with_comms
def test_shard_order_same_data_as_strided_shard(self):
device_mesh = init_device_mesh(self.device_type, (4, 2))
x = torch.randn(8, 4, device=self.device_type)
# specify right-to-left order use _StridedShard
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
# specify right-to-left order use ordered shard
x_ordered_dt = self.distribute_tensor(
x,
device_mesh,
placements=[Shard(0), Shard(0)],
shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),),
)
self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
if __name__ == "__main__":
run_tests()

View File

@ -70,8 +70,6 @@ def get_patches():
"force_disable_caches": True,
# Messes up existing test strings
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
# interferes with testing, / custom estimation
"test_configs.assume_bucketing_reduces_latency": False,
}
@ -366,8 +364,6 @@ def get_bucket_patches(compute_multiplier=1.0):
"force_disable_caches": True,
# messes up test strings
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
# interferes with testing, / custom estimation
"test_configs.assume_bucketing_reduces_latency": False,
}
@ -583,7 +579,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches(2.0))
def test_bucketing_split_for_overlap_blocking_no_deps(self):
def test_bucketing_split_for_overlap_blocking(self):
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
def func(a, b, c, d, *, ranks):

View File

@ -2,7 +2,6 @@
# Owner(s): ["oncall: distributed"]
import os
import unittest
from datetime import timedelta
import torch
import torch.distributed as dist
@ -41,13 +40,6 @@ from torch.utils._typing_utils import not_none
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()
try:
import torch._C._distributed_c10d.ProcessGroupNCCL
_NCCL_AVAILABLE = True
except ImportError:
_NCCL_AVAILABLE = False
def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
os.environ["MASTER_ADDR"] = addr
@ -970,85 +962,6 @@ class TestDeviceMeshGetItem(DTensorTestBase):
# check flattened mesh dependency
self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_4d)
@with_comms
def test_unflatten_mesh_2d(self):
mesh_shape = (4, 2)
mesh_dim_names = ("dp", "tp")
mesh_2d = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
unflatten_mesh = mesh_2d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
self.assertEqual(
unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"]
)
self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh)
self.assertEqual(mesh_2d["tp"].get_group(), unflatten_mesh["tp"].get_group())
# Slicing the unflattened dim name out of the root mesh is not supported.
with self.assertRaises(KeyError):
self.assertEqual(mesh_2d["dp_shard"].mesh, unflatten_mesh["dp_shard"].mesh)
@with_comms
def test_unflatten_mesh_3d(self):
# Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism (EP).
global_mesh = init_device_mesh(
self.device_type,
(8,),
mesh_dim_names=("world",),
)
non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp"))
unflatten_mesh = mesh_3d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
self.assertEqual(
unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "cp", "tp"]
)
self.assertEqual(mesh_3d["tp"].mesh, unflatten_mesh["tp"].mesh)
self.assertEqual(mesh_3d["tp"].get_group(), unflatten_mesh["tp"].get_group())
self.assertEqual(mesh_3d["cp"].mesh, unflatten_mesh["cp"].mesh)
self.assertEqual(mesh_3d["cp"].get_group(), unflatten_mesh["cp"].get_group())
# Test unflatten with backend override set.
if not _NCCL_AVAILABLE:
return
opts = dist.ProcessGroupNCCL.Options()
opts._timeout = timedelta(seconds=30)
mesh_2d = global_mesh._unflatten(
0,
(1, 8),
("pp", "spmd"),
backend_override={"pp": "fake", "spmd": ("nccl", opts)},
)
opts = dist.ProcessGroupNCCL.Options()
opts._timeout = timedelta(seconds=60)
mesh_4d = mesh_2d._unflatten(
1,
(2, 2, 2),
("dp", "cp", "tp"),
backend_override={"dp": "nccl", "cp": "nccl", "tp": ("nccl", opts)},
)
self.assertEqual(mesh_4d["pp"].get_group()._get_backend_name(), "custom")
spmd_pg = mesh_2d["spmd"].get_group()
self.assertEqual(spmd_pg._get_backend_name(), "nccl")
w = spmd_pg.allreduce(torch.rand(10).cuda(self.rank))
self.assertTrue(
spmd_pg._get_backend(
torch.device(f"cuda:{self.rank}")
)._verify_work_timeout(w, timedelta(seconds=30))
)
w.wait()
tp_pg = mesh_4d["tp"].get_group()
self.assertEqual(tp_pg._get_backend_name(), "nccl")
w = tp_pg.allreduce(torch.rand(10).cuda(self.rank))
self.assertTrue(
tp_pg._get_backend(torch.device(f"cuda:{self.rank}"))._verify_work_timeout(
w, timedelta(seconds=60)
)
)
w.wait()
@with_comms
def test_reconstruct_mesh_with_flatten_dim(self):
mesh_3d = init_device_mesh(

View File

@ -273,7 +273,12 @@ class TestFakePG(TestCase):
kwargs = {}
return func(*args, **kwargs)
with self.assertRaisesRegex(TypeError, r"No constructor defined"):
with self.assertRaisesRegex(
RuntimeError,
r"FakeProcessGroup cannot be constructed directly\. "
r"Use torch\.distributed\.init_process_group\(backend='fake'\) instead to ensure "
r"proper dispatch system integration\.",
):
fake_pg = FakeProcessGroup(rank=0, world_size=3)
with SimpleTensorMode():
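The new error message in this hunk points users at the fake backend instead of constructing `FakeProcessGroup` directly. A hedged sketch of that recommended path follows; the `FakeStore` import is the pattern PyTorch's own tests use, but it lives under `torch.testing._internal` and is an internal detail that may move:

```python
# Sketch of the construction path the error message recommends; FakeStore's
# location is an internal detail (assumption: stable enough for test code).
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group(backend="fake", rank=0, world_size=3, store=FakeStore())
pg = dist.distributed_c10d._get_default_group()  # backed by a FakeProcessGroup
dist.destroy_process_group()
```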

View File

@ -1743,67 +1743,6 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = f(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
@parametrize("bucket_mode", ["all"])
def test_all_reduce_bucket(self, bucket_mode):
def func(x, w, ar_0, ar_1, tag, ranks, group_size):
y = torch.mm(x, w)
group_name = (
torch.distributed.distributed_c10d._get_default_group().group_name
)
ar_0_out = torch.ops._c10d_functional.all_reduce.default(
ar_0, "sum", group_name
)
ar_1_out = torch.ops._c10d_functional.all_reduce.default(
ar_1, "sum", group_name
)
ar_0_w = torch.ops.c10d_functional.wait_tensor(ar_0_out)
ar_1_w = torch.ops.c10d_functional.wait_tensor(ar_1_out)
return y, ar_0_w, ar_1_w
f = func
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ar_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ar_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)
inputs = [x, w, ar_0, ar_1]
f(*inputs, **self.get_world_trs())
def _pass(g):
from torch._inductor.fx_passes.bucketing import bucket_all_reduce
bucket_all_reduce(g.owning_module, lambda _: 2000)
torch._inductor.config.post_grad_custom_post_pass = _pass
with torch._inductor.config.patch(
{
"reorder_for_compute_comm_overlap": False,
}
):
compiled = torch.compile(f)
compiled(*inputs, **self.get_world_trs())
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
# NOTE: The first return value should be the output of the first wait_tensor.
# We want to make sure no unnecessary copy is made.
(
FileCheck()
.check_count(
"torch.ops._c10d_functional.all_reduce_.default(",
count=1,
exactly=True,
)
.run(code)
)
out = compiled(*inputs, **self.get_world_trs())
correct = f(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
@parametrize("bucket_mode", ["all", "all_custom_ops"])

Some files were not shown because too many files have changed in this diff.