Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-14 14:15:07 +08:00)

Compare commits: csl/lintru...ciflow/tru (158 commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| eebd57f4fd | |||
| 280d77bd86 | |||
| 54c90ae440 | |||
| fee1ac927d | |||
| 4a7fefd7c7 | |||
| 3b4315940d | |||
| 3eddf04922 | |||
| 7c203b8420 | |||
| 3ca216ae17 | |||
| 9c22bbb2dc | |||
| 6268883f9c | |||
| 16212f0d6b | |||
| c8adc08b3b | |||
| 23b57a445c | |||
| 6c7cad6972 | |||
| bb54296258 | |||
| 5e05a0ae99 | |||
| 298666631b | |||
| e471800dce | |||
| 18f4259626 | |||
| d962bed157 | |||
| 76780b1a3d | |||
| cee03634da | |||
| bc03d7c974 | |||
| f013e804c8 | |||
| 0674e0a0f1 | |||
| b7d348a907 | |||
| 9f9dbe0a9a | |||
| a19e92d433 | |||
| c3dc0c7089 | |||
| 04d6a6f339 | |||
| 0573747b6a | |||
| a663eb9c80 | |||
| 764c54ecae | |||
| 0d81bb7f9c | |||
| 82fafb3304 | |||
| 401c2f9657 | |||
| 13549e0e10 | |||
| 82d86bacf3 | |||
| 3b5d38a3bc | |||
| 84776e1374 | |||
| b3861ac8e7 | |||
| 4cc64d6234 | |||
| 1aef88c72d | |||
| f0745ddb11 | |||
| 4316df857c | |||
| 9d6597b1e9 | |||
| e8fadba28c | |||
| 60333de85d | |||
| 3dc92d69ed | |||
| f91899ca6c | |||
| e2dc32f4ba | |||
| 83cc38d9c1 | |||
| 8d599045cf | |||
| fd5da81fdd | |||
| 9261a1fb12 | |||
| d80ae738c9 | |||
| 51667435f5 | |||
| 2699f5410b | |||
| 9970fb97ff | |||
| dfebdcab86 | |||
| b09fb481e0 | |||
| 4e7232c5da | |||
| 93a70c717a | |||
| d97144d31e | |||
| e4043884c7 | |||
| 4a7bc1d522 | |||
| 8209a0506b | |||
| 70aeb49198 | |||
| cf9a834f39 | |||
| 856a7a5298 | |||
| ef8d97efcf | |||
| d2be06f673 | |||
| 08f4535378 | |||
| 30157d30f0 | |||
| b470e59c38 | |||
| 85b85f6c2c | |||
| b71966f67b | |||
| 0947765eb9 | |||
| 239e7b541a | |||
| ffaa6578b7 | |||
| 365ed62f61 | |||
| fcc1063566 | |||
| 121235956b | |||
| aa9c96af04 | |||
| c3b71d5499 | |||
| 1e3600b528 | |||
| fee7624bd6 | |||
| 24e94e021a | |||
| 69be99ee51 | |||
| 034e951b0c | |||
| 160ab53dd5 | |||
| 5bcfdae71d | |||
| 4e8ba37ce3 | |||
| 26534e9809 | |||
| 657f8c3e21 | |||
| b0831930ed | |||
| c01636e1bc | |||
| fd68d409ad | |||
| 0d3a4f7155 | |||
| 108bb224f7 | |||
| fc8ac1216c | |||
| 030de07aff | |||
| 7d67a41db4 | |||
| 85b035ca9c | |||
| 267d0197bf | |||
| 1dec8a67a8 | |||
| 797cd80b26 | |||
| 7d39401fa0 | |||
| e3ae0594d1 | |||
| f1e4c42b6e | |||
| d3e511f07c | |||
| d3be06cbdc | |||
| 1129605415 | |||
| a6b1ef1717 | |||
| 12577064dd | |||
| 24b6eb7727 | |||
| 32066772b3 | |||
| 47f0024310 | |||
| 98d640bb11 | |||
| 5d288bc3f7 | |||
| bfb47ec50e | |||
| 7a0cd8ed09 | |||
| 984e64b2cd | |||
| b9bcb37f40 | |||
| 7e3b9d105e | |||
| 45c3f02d69 | |||
| f5543e3741 | |||
| 5fc2c7a2a1 | |||
| 7692fa09cd | |||
| df71b70727 | |||
| 80ba6e458f | |||
| 0d50e5d8d4 | |||
| 99b05d1b78 | |||
| f911d64750 | |||
| 52db60170d | |||
| 56838bad5f | |||
| ad3a56ab98 | |||
| a7fd0b4001 | |||
| 181ee3bd42 | |||
| 0ec0549823 | |||
| 8221ee6db9 | |||
| b939de26d1 | |||
| 694db5f549 | |||
| 639a0b1239 | |||
| 398775a43e | |||
| fcd5f8c352 | |||
| 4acc66f119 | |||
| 8f40a0c634 | |||
| a5c3c08d10 | |||
| a553ea9ea4 | |||
| ba71e9ca9a | |||
| 694d205143 | |||
| 629293f568 | |||
| 53dc8a0875 | |||
| 9e119dd8c4 | |||
| 93a6e99edc | |||
| d0892c7792 |
@@ -195,13 +195,16 @@ case "$tag" in
    NINJA_VERSION=1.9.0
    TRITON=yes
    ;;
  pytorch-linux-jammy-xpu-n-py3)
  pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
    ANACONDA_PYTHON_VERSION=3.10
    GCC_VERSION=11
    VISION=yes
    XPU_VERSION=2025.2
    NINJA_VERSION=1.9.0
    TRITON=yes
    if [[ $tag =~ "benchmarks" ]]; then
      INDUCTOR_BENCHMARKS=yes
    fi
    ;;
  pytorch-linux-jammy-py3-gcc11-inductor-benchmarks)
    ANACONDA_PYTHON_VERSION=3.10

@@ -3,7 +3,7 @@

set -eux

ACL_VERSION=${ACL_VERSION:-"v25.02"}
ACL_VERSION=${ACL_VERSION:-"v52.6.0"}
ACL_INSTALL_DIR="/acl"

# Clone ACL

@@ -12,8 +12,8 @@ function do_install() {

  rocm_version_nodot=${rocm_version//./}

  # https://github.com/icl-utk-edu/magma/pull/65
  MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
  # post merge of https://github.com/icl-utk-edu/magma/pull/65
  MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
  magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"

  rocm_dir="/opt/rocm"

@@ -97,7 +97,7 @@ case ${image} in
  manylinux2_28-builder:xpu)
    TARGET=xpu_final
    GPU_IMAGE=amd64/almalinux:8
    DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=11"
    DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13"
    MANY_LINUX_VERSION="2_28"
    ;;
  *)

@@ -54,12 +54,15 @@ ENV OPENSSL_DIR /opt/openssl
RUN rm install_openssl.sh

ARG INDUCTOR_BENCHMARKS
ARG ANACONDA_PYTHON_VERSION
ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION
COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt

# Install XPU Dependencies
ARG XPU_VERSION

@@ -1,7 +1,7 @@
SHELL=/usr/bin/env bash

DOCKER_CMD ?= docker
DESIRED_ROCM ?= 7.0
DESIRED_ROCM ?= 7.1
DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM))
PACKAGE_NAME = magma-rocm
# inherit this from underlying docker image, do not pass this env var to docker

@@ -16,6 +16,7 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
	magma-rocm/build_magma.sh

.PHONY: all
all: magma-rocm71
all: magma-rocm70
all: magma-rocm64

@@ -24,6 +25,11 @@ clean:
	$(RM) -r magma-*
	$(RM) -r output

.PHONY: magma-rocm71
magma-rocm71: DESIRED_ROCM := 7.1
magma-rocm71:
	$(DOCKER_RUN)

.PHONY: magma-rocm70
magma-rocm70: DESIRED_ROCM := 7.0
magma-rocm70:

@@ -6,8 +6,8 @@ set -eou pipefail
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"

# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f

# Folders for the build
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata

@@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE

# Fetch magma sources and verify checksum
pushd ${PACKAGE_DIR}
git clone https://github.com/jeffdaily/magma
git clone https://github.com/icl-utk-edu/magma
pushd magma
git checkout ${MAGMA_VERSION}
popd

@@ -426,7 +426,7 @@ fi
if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
  # export test times so that potential sharded tests that'll branch off this build will use consistent data
  # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
  python tools/stats/export_test_times.py
  PYTHONPATH=. python tools/stats/export_test_times.py
fi
# don't do this for bazel or s390x or riscv64 as they don't use sccache
if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then

@@ -572,6 +572,8 @@ fi

if [[ "${TEST_CONFIG}" == *cpu* ]]; then
  DYNAMO_BENCHMARK_FLAGS+=(--device cpu)
elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
  DYNAMO_BENCHMARK_FLAGS+=(--device xpu)
else
  DYNAMO_BENCHMARK_FLAGS+=(--device cuda)
fi

@@ -665,6 +667,8 @@ test_perf_for_dashboard() {
    device=cuda_b200
  elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
    device=rocm
  elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
    device=xpu
  fi

  for mode in "${modes[@]}"; do

@@ -1757,7 +1761,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
  else
    # Do this after checkout_install_torchbench to ensure we clobber any
    # nightlies that torchbench may pull in
    if [[ "${TEST_CONFIG}" != *cpu* ]]; then
    if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* ]]; then
      install_torchrec_and_fbgemm
    fi
    PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"
.claude/skills/add-uint-support/SKILL.md (new file, 319 lines added)
---
name: add-uint-support
description: Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use when adding support for uint16, uint32, uint64 types to operators, kernels, or when user mentions enabling unsigned types, barebones unsigned types, or uint support.
---

# Add Unsigned Integer (uint) Support to Operators

This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros.

## When to use this skill

Use this skill when:
- Adding uint16, uint32, or uint64 support to an operator
- User mentions "unsigned types", "uint support", "barebones unsigned types"
- Enabling support for kUInt16, kUInt32, kUInt64 in kernels
- Working with operator implementations that need expanded type coverage

## Quick reference

**Add unsigned types to existing dispatch:**
```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));

// After (method 1: add unsigned types explicitly)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));

// After (method 2: use V2 integral types if AT_INTEGRAL_TYPES present)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
```

## Type group reference

**Unsigned type groups:**
- `AT_BAREBONES_UNSIGNED_TYPES`: kUInt16, kUInt32, kUInt64
- `AT_INTEGRAL_TYPES_V2`: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES

**Relationship:**
```cpp
AT_INTEGRAL_TYPES           // kByte, kChar, kInt, kLong, kShort
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
AT_INTEGRAL_TYPES_V2        // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
```

## Instructions

### Step 1: Determine if conversion to V2 is needed

Check if the file uses AT_DISPATCH_V2:

**If using old AT_DISPATCH:**
- First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill
- Then proceed with adding uint support

**If already using AT_DISPATCH_V2:**
- Proceed directly to Step 2
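
For orientation, a minimal sketch of the two styles side by side; the op and kernel names below are placeholders, not taken from a real operator:

```cpp
// Old-style dispatch: the arity is baked into the macro name and the extra
// types come first, before the dtype, name, and lambda.
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "example_op", [&]() {
  example_kernel<scalar_t>(iter);  // placeholder kernel
});

// AT_DISPATCH_V2: dtype and name first, lambda wrapped in AT_WRAP,
// type groups and extra types listed at the end.
AT_DISPATCH_V2(iter.dtype(), "example_op", AT_WRAP([&]() {
  example_kernel<scalar_t>(iter);  // placeholder kernel
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
```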

### Step 2: Analyze the current dispatch macro

Identify what type groups are currently in use:

```cpp
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  // body
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
    ^^^^^^^^^^^^^^^^^^^^^^^^
    Current type coverage
```

Common patterns:
- `AT_EXPAND(AT_ALL_TYPES)` → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
- `AT_EXPAND(AT_INTEGRAL_TYPES)` → signed integers only
- `AT_EXPAND(AT_FLOATING_TYPES)` → floating point types

### Step 3: Choose the uint addition method

Two approaches:

**Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly**
- Use when: You want to be explicit about adding uint support
- Add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the type list

**Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2**
- Use when: The dispatch already uses `AT_EXPAND(AT_INTEGRAL_TYPES)`
- More concise: replaces one type group with its superset
- Only applicable if AT_INTEGRAL_TYPES is present

### Step 4: Apply the transformation

**Method 1 example:**
```cpp
// Before
AT_DISPATCH_V2(
  dtype,
  "min_values_cuda",
  AT_WRAP([&]() {
    kernel_impl<scalar_t>(iter);
  }),
  AT_EXPAND(AT_ALL_TYPES),
  kBFloat16, kHalf, kBool
);

// After (add unsigned types)
AT_DISPATCH_V2(
  dtype,
  "min_values_cuda",
  AT_WRAP([&]() {
    kernel_impl<scalar_t>(iter);
  }),
  AT_EXPAND(AT_ALL_TYPES),
  AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
  kBFloat16, kHalf, kBool
);
```

**Method 2 example:**
```cpp
// Before
AT_DISPATCH_V2(
  dtype,
  "integral_op",
  AT_WRAP([&]() {
    kernel<scalar_t>();
  }),
  AT_EXPAND(AT_INTEGRAL_TYPES)
);

// After (substitute with V2)
AT_DISPATCH_V2(
  dtype,
  "integral_op",
  AT_WRAP([&]() {
    kernel<scalar_t>();
  }),
  AT_EXPAND(AT_INTEGRAL_TYPES_V2)
);
```

### Step 5: Handle AT_ALL_TYPES vs individual type groups

If the dispatch uses `AT_EXPAND(AT_ALL_TYPES)`:
- `AT_ALL_TYPES` = `AT_INTEGRAL_TYPES` + `AT_FLOATING_TYPES`
- To add uint: add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the list

If the dispatch separately lists INTEGRAL and FLOATING:
```cpp
// Before
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)

// After (Method 2 preferred)
AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)
```

### Step 6: Verify all dispatch sites

Check the file for ALL dispatch macros that need uint support:
- Some operators have multiple dispatch sites (CPU, CUDA, different functions)
- Apply the transformation consistently across all sites
- Ensure each gets the same type coverage updates

### Step 7: Validate the changes

Check that:
- [ ] AT_DISPATCH_V2 format is used (not old AT_DISPATCH)
- [ ] Unsigned types are added via one of the two methods
- [ ] All relevant dispatch sites in the file are updated
- [ ] Type groups use `AT_EXPAND()`
- [ ] Arguments are properly formatted and comma-separated

## Common patterns

### Pattern 1: AT_ALL_TYPES + extras

```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
```

### Pattern 2: Separate INTEGRAL + FLOATING

```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
```

### Pattern 3: Old dispatch needs conversion first

```cpp
// Before (needs v2 conversion first)
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
  kernel<scalar_t>();
});

// After v2 conversion
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);

// After adding uint support
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
```

## Multiple dispatch sites example

For a file with multiple functions:

```cpp
void min_values_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
    impl<scalar_t>(iter);
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  //                           Added uint support
}

void min_launch_kernel(TensorIterator &iter) {
  AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
    gpu_reduce_kernel<scalar_t>(iter);
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  //                           Added uint support here too
}
```

## Decision tree

Use this decision tree to determine the approach:

```
Is the file using AT_DISPATCH_V2?
├─ No → Use at-dispatch-v2 skill first, then continue
└─ Yes
   └─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
      ├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
      └─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list
```

## Edge cases

### Case 1: Dispatch with only floating types

If the operator only supports floating point types, don't add uint support:

```cpp
// Leave as-is - floating point only operator
AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);
```

### Case 2: Complex types present

Unsigned types work alongside complex types:

```cpp
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES),
    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
    AT_EXPAND(AT_COMPLEX_TYPES),
    kHalf, kBFloat16);
```

### Case 3: Already has uint support

Check if uint types are already present:
- If `AT_INTEGRAL_TYPES_V2` is used → already has uint support
- If `AT_BAREBONES_UNSIGNED_TYPES` is already in list → already has uint support
- Skip the file if uint support is already present
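
For instance, a dispatch already shaped like the following sketch (placeholder names) needs no further edits, since `AT_INTEGRAL_TYPES_V2` already pulls in kUInt16, kUInt32, and kUInt64:

```cpp
// Unsigned coverage is already present via AT_INTEGRAL_TYPES_V2 - skip this site.
AT_DISPATCH_V2(dtype, "example_op", AT_WRAP([&]() {
  example_kernel<scalar_t>();  // placeholder kernel
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
```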

## Workflow

When asked to add uint support:

1. Read the target file
2. Check if using AT_DISPATCH_V2:
   - If not → use at-dispatch-v2 skill first
3. Identify all dispatch macro sites
4. For each dispatch:
   - Analyze current type groups
   - Choose method (add BAREBONES_UNSIGNED or upgrade to V2)
   - Apply transformation with Edit tool
5. Show the user the changes
6. Explain what was modified

## Important notes

- Always check if v2 conversion is needed first
- Apply changes consistently across all dispatch sites in the file
- Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable
- Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit
- Unsigned types are: kUInt16, kUInt32, kUInt64 (not kByte, which is uint8)
- Some operators may not semantically support unsigned types - use judgment

## Testing

After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing.
.claude/skills/at-dispatch-v2/SKILL.md (new file, 305 lines added)
---
name: at-dispatch-v2
description: Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations.
---

# AT_DISPATCH to AT_DISPATCH_V2 Converter

This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in `aten/src/ATen/Dispatch_v2.h`.

## When to use this skill

Use this skill when:
- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
- Porting ATen kernels to use the new dispatch API
- Working with files in `aten/src/ATen/native/` that use dispatch macros
- User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion

## Quick reference

**Old format:**
```cpp
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
  // lambda body
});
```

**New format:**
```cpp
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
  // lambda body
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
```

## Key transformations

1. **Reorder arguments**: `scalar_type` and `name` come first, then lambda, then types
2. **Wrap the lambda**: Use `AT_WRAP(lambda)` to handle internal commas
3. **Expand type groups**: Use `AT_EXPAND(AT_ALL_TYPES)` instead of implicit expansion
4. **List individual types**: Add extra types (kHalf, kBFloat16, etc.) after expanded groups
5. **Add include**: `#include <ATen/Dispatch_v2.h>` near other Dispatch includes
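
As an illustration only (the op and kernel names are placeholders), the same call with each of the five transformations annotated inline:

```cpp
// 1. dtype expression and debug name moved to the front.
// 2. lambda wrapped in AT_WRAP(...) so internal commas are safe.
// 3. the implicit ALL_TYPES group written out as AT_EXPAND(AT_ALL_TYPES).
// 4. the extra individual types (kHalf, kBool) trailing the expanded group.
// 5. the file additionally gains #include <ATen/Dispatch_v2.h>.
AT_DISPATCH_V2(dtype, "example_op", AT_WRAP([&]() {
  example_kernel<scalar_t>();  // placeholder kernel
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
```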

## Instructions

### Step 1: Add the Dispatch_v2.h include

Add the v2 header near the existing `#include <ATen/Dispatch.h>`:

```cpp
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
```

Keep the old Dispatch.h include for now (other code may still need it).

### Step 2: Identify the old dispatch pattern

Common patterns to convert:

- `AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)`
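
As an illustration, the second pattern might appear in a kernel file like the following sketch (the op and kernel names are invented for the example):

```cpp
// An old-style floating-point dispatch with two extra types; this is the
// shape that Steps 3-5 below convert to AT_DISPATCH_V2.
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
    self.scalar_type(), "example_float_op", [&] {
  example_impl<scalar_t>(self);  // placeholder kernel
});
```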

### Step 3: Map the old macro to type groups

Identify which type group macro corresponds to the base types:

| Old macro base | AT_DISPATCH_V2 type group |
|----------------|---------------------------|
| `ALL_TYPES` | `AT_EXPAND(AT_ALL_TYPES)` |
| `FLOATING_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES)` |
| `INTEGRAL_TYPES` | `AT_EXPAND(AT_INTEGRAL_TYPES)` |
| `COMPLEX_TYPES` | `AT_EXPAND(AT_COMPLEX_TYPES)` |
| `ALL_TYPES_AND_COMPLEX` | `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` |

For combined patterns, use multiple `AT_EXPAND()` entries:
```cpp
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
```

### Step 4: Extract the individual types

From `AT_DISPATCH_*_AND2(type1, type2, ...)` or `AT_DISPATCH_*_AND3(type1, type2, type3, ...)`, extract the individual types (type1, type2, etc.).

These become the trailing arguments after the type group:
```cpp
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
                                             ^^^^^^^^^^^^^^^^^^^^^^^
                                             Individual types from AND3
```

### Step 5: Transform to AT_DISPATCH_V2

Apply the transformation:

**Pattern:**
```cpp
AT_DISPATCH_V2(
  scalar_type,      // 1st: The dtype expression
  "name",           // 2nd: The debug string
  AT_WRAP(lambda),  // 3rd: The lambda wrapped in AT_WRAP
  type_groups,      // 4th+: Type groups with AT_EXPAND()
  individual_types  // Last: Individual types
)
```

**Example transformation:**
```cpp
// BEFORE
AT_DISPATCH_ALL_TYPES_AND3(
  kBFloat16, kHalf, kBool,
  iter.dtype(),
  "min_values_cuda",
  [&]() {
    min_values_kernel_cuda_impl<scalar_t>(iter);
  }
);

// AFTER
AT_DISPATCH_V2(
  iter.dtype(),
  "min_values_cuda",
  AT_WRAP([&]() {
    min_values_kernel_cuda_impl<scalar_t>(iter);
  }),
  AT_EXPAND(AT_ALL_TYPES),
  kBFloat16, kHalf, kBool
);
```

### Step 6: Handle multi-line lambdas

For lambdas with internal commas or complex expressions, AT_WRAP is essential:

```cpp
AT_DISPATCH_V2(
  dtype,
  "complex_kernel",
  AT_WRAP([&]() {
    gpu_reduce_kernel<scalar_t, scalar_t>(
      iter,
      MinOps<scalar_t>{},
      thrust::pair<scalar_t, int64_t>(upper_bound(), 0)  // Commas inside!
    );
  }),
  AT_EXPAND(AT_ALL_TYPES)
);
```

### Step 7: Verify the conversion

Check that:
- [ ] `AT_WRAP()` wraps the entire lambda
- [ ] Type groups use `AT_EXPAND()`
- [ ] Individual types don't have `AT_EXPAND()` (just `kBFloat16`, not `AT_EXPAND(kBFloat16)`)
- [ ] Argument order is: scalar_type, name, lambda, types
- [ ] Include added: `#include <ATen/Dispatch_v2.h>`

## Type group reference

Available type group macros (use with `AT_EXPAND()`):

```cpp
AT_INTEGRAL_TYPES           // kByte, kChar, kInt, kLong, kShort
AT_FLOATING_TYPES           // kDouble, kFloat
AT_COMPLEX_TYPES            // kComplexDouble, kComplexFloat
AT_QINT_TYPES               // kQInt8, kQUInt8, kQInt32
AT_ALL_TYPES                // INTEGRAL_TYPES + FLOATING_TYPES
AT_ALL_TYPES_AND_COMPLEX    // ALL_TYPES + COMPLEX_TYPES
AT_INTEGRAL_TYPES_V2        // INTEGRAL_TYPES + unsigned types
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
AT_FLOAT8_TYPES             // Float8 variants
```

## Common patterns

### Pattern: AT_DISPATCH_ALL_TYPES_AND2

```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
  kernel<scalar_t>(data);
});

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>(data);
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
```

### Pattern: AT_DISPATCH_FLOATING_TYPES_AND3

```cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
    tensor.scalar_type(), "float_op", [&] {
  process<scalar_t>(tensor);
});

// After
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
  process<scalar_t>(tensor);
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);
```

### Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2

```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
  kComplexHalf, kHalf,
  self.scalar_type(),
  "complex_op",
  [&] {
    result = compute<scalar_t>(self);
  }
);

// After
AT_DISPATCH_V2(
  self.scalar_type(),
  "complex_op",
  AT_WRAP([&] {
    result = compute<scalar_t>(self);
  }),
  AT_EXPAND(AT_ALL_TYPES),
  AT_EXPAND(AT_COMPLEX_TYPES),
  kComplexHalf,
  kHalf
);
```

## Edge cases

### Case 1: No extra types (rare)

```cpp
// Before
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));
```

### Case 2: Many individual types (AND4, AND5, etc.)

```cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
    dtype, "float8_op", [&]() { kernel<scalar_t>(); });

// After
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
```

### Case 3: Lambda with no captures

```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
  static_kernel<scalar_t>();
});

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
  static_kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
```

## Benefits of AT_DISPATCH_V2

1. **No arity in macro name**: Don't need different macros for AND2, AND3, AND4
2. **Composable type sets**: Mix and match type groups with `AT_EXPAND()`
3. **Extensible**: Easy to add more types without hitting macro limits
4. **Clearer**: Type groups are explicit, not implicit in macro name
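
A small sketch of the composability point (names are placeholders): one call site can stack several type groups plus individual types, with no AND-arity limit to run into:

```cpp
// Type groups and single types compose freely in one trailing argument list.
AT_DISPATCH_V2(dtype, "example_op", AT_WRAP([&]() {
  example_kernel<scalar_t>();  // placeholder kernel
}), AT_EXPAND(AT_ALL_TYPES),
    AT_EXPAND(AT_COMPLEX_TYPES),
    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
    kHalf, kBFloat16);
```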

## Important notes

- Keep `#include <ATen/Dispatch.h>` - other code may need it
- The `AT_WRAP()` is mandatory - prevents comma parsing issues in the lambda
- Type groups need `AT_EXPAND()`, individual types don't
- The v2 API is in `aten/src/ATen/Dispatch_v2.h` - refer to it for full docs
- See the header file for the Python script to regenerate the macro implementation

## Workflow

When asked to convert AT_DISPATCH macros:

1. Read the file to identify all AT_DISPATCH uses
2. Add `#include <ATen/Dispatch_v2.h>` if not present
3. For each dispatch macro:
   - Identify the pattern and extract components
   - Map the base type group
   - Extract individual types
   - Construct the AT_DISPATCH_V2 call
   - Apply with Edit tool
4. Show the user the complete converted file
5. Explain what was changed

Do NOT compile or test the code - focus on accurate conversion only.
.github/actions/diskspace-cleanup/action.yml (vendored, 4 lines changed)

@@ -27,7 +27,9 @@ runs:
      docker system prune -af
      diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
      if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then
        echo "Error: Available diskspace is less than $diskspace_cutoff percent. Not enough diskspace."
        diskspace_cutoff_int=$((diskspace_cutoff + 0))
        difference=$((100 - diskspace_cutoff_int))
        echo "Error: Available diskspace is less than $difference percent. Not enough diskspace."
        echo "$msg"
        exit 1
      else
.github/ci_commit_pins/vision.txt (vendored, 2 lines changed)

@@ -1 +1 @@
218d2ab791d437309f91e0486eb9fa7f00badc17
cfbc5c2f1c798991715a6b06bb3ce46478c4487c
.github/pytorch-probot.yml (vendored, 1 line changed)

@@ -19,6 +19,7 @@ ciflow_push_tags:
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-perf-test-nightly-xpu
- ciflow/inductor-periodic
- ciflow/inductor-rocm
- ciflow/linux-aarch64
.github/scripts/generate_binary_build_matrix.py (vendored, 89 lines changed)

@@ -11,11 +11,17 @@ architectures:
  * Latest XPU
"""

import json
import os
import re
from pathlib import Path
from typing import Optional


# NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this
SCRIPT_DIR = Path(__file__).absolute().parent
REPO_ROOT = SCRIPT_DIR.parent.parent


CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"]
CUDA_STABLE = "12.8"
CUDA_ARCHES_FULL_VERSION = {

@@ -31,8 +37,7 @@ CUDA_ARCHES_CUDNN_VERSION = {
    "13.0": "9",
}

# NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this
ROCM_ARCHES = ["6.4", "7.0"]
ROCM_ARCHES = ["7.0", "7.1"]

XPU_ARCHES = ["xpu"]

@@ -137,9 +142,48 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
}


def get_nccl_wheel_version(arch_version: str) -> str:
    import re
# Used by tools/nightly.py
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
NIGHTLY_SOURCE_MATRIX = {
    "cpu": dict(
        name="cpu",
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
        supported_platforms=["Linux", "macOS", "Windows"],
        accelerator="cpu",
    )
}
CUDA_NIGHTLY_SOURCE_MATRIX = {
    f"cuda-{major}.{minor}": dict(
        name=f"cuda-{major}.{minor}",
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu{major}{minor}",
        supported_platforms=["Linux", "Windows"],
        accelerator="cuda",
    )
    for major, minor in (map(int, version.split(".")) for version in CUDA_ARCHES)
}
ROCM_NIGHTLY_SOURCE_MATRIX = {
    f"rocm-{major}.{minor}": dict(
        name=f"rocm-{major}.{minor}",
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm{major}.{minor}",
        supported_platforms=["Linux"],
        accelerator="rocm",
    )
    for major, minor in (map(int, version.split(".")) for version in ROCM_ARCHES)
}
XPU_NIGHTLY_SOURCE_MATRIX = {
    "xpu": dict(
        name="xpu",
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/xpu",
        supported_platforms=["Linux"],
        accelerator="xpu",
    )
}
NIGHTLY_SOURCE_MATRIX.update(CUDA_NIGHTLY_SOURCE_MATRIX)
NIGHTLY_SOURCE_MATRIX.update(ROCM_NIGHTLY_SOURCE_MATRIX)
NIGHTLY_SOURCE_MATRIX.update(XPU_NIGHTLY_SOURCE_MATRIX)


def get_nccl_wheel_version(arch_version: str) -> str:
    requirements = map(
        str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version])
    )

@@ -147,17 +191,14 @@ def get_nccl_wheel_version(arch_version: str) -> str:


def read_nccl_pin(arch_version: str) -> str:
    from pathlib import Path

    nccl_pin_path = os.path.join(
        Path(__file__).absolute().parents[2],
        ".ci",
        "docker",
        "ci_commit_pins",
        f"nccl-cu{arch_version[:2]}.txt",
    nccl_pin_path = (
        REPO_ROOT
        / ".ci"
        / "docker"
        / "ci_commit_pins"
        / f"nccl-cu{arch_version[:2]}.txt"
    )
    with open(nccl_pin_path) as f:
        return f.read().strip()
    return nccl_pin_path.read_text().strip()


def validate_nccl_dep_consistency(arch_version: str) -> None:

@@ -165,7 +206,8 @@ def validate_nccl_dep_consistency(arch_version: str) -> None:
    wheel_ver = get_nccl_wheel_version(arch_version)
    if not nccl_release_tag.startswith(f"v{wheel_ver}"):
        raise RuntimeError(
            f"{arch_version} NCCL release tag version {nccl_release_tag} does not correspond to wheel version {wheel_ver}"
            f"{arch_version} NCCL release tag version {nccl_release_tag} "
            f"does not correspond to wheel version {wheel_ver}"
        )


@@ -412,7 +454,14 @@ def generate_wheels_matrix(
    return ret


validate_nccl_dep_consistency("13.0")
validate_nccl_dep_consistency("12.9")
validate_nccl_dep_consistency("12.8")
validate_nccl_dep_consistency("12.6")
arch_version = ""
for arch_version in CUDA_ARCHES:
    validate_nccl_dep_consistency(arch_version)
del arch_version


if __name__ == "__main__":
    # Used by tools/nightly.py
    (SCRIPT_DIR / "nightly_source_matrix.json").write_text(
        json.dumps(NIGHTLY_SOURCE_MATRIX, indent=4) + "\n"
    )
.github/workflows/_xpu-test.yml (vendored, 13 lines changed)

@@ -38,6 +38,10 @@ on:
        default: ""
        description: |
          List of tests to include (empty string implies default list)
      dashboard-tag:
        required: false
        type: string
        default: ""
      disable-monitor:
        description: |
          [Experimental] Disable utilization monitoring for tests.

@@ -58,6 +62,11 @@ on:
        required: false
        type: number
        default: 1
    secrets:
      HUGGING_FACE_HUB_TOKEN:
        required: false
        description: |
          HF Auth token to avoid rate limits when downloading models or datasets from hub
permissions:
  id-token: write
  contents: read

@@ -196,6 +205,8 @@ jobs:
          PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
          PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
          TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }}
          DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
        timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
        run: |
          # Fetch aws credential from IMDs

@@ -246,6 +257,8 @@ jobs:
            -e PYTORCH_TEST_RERUN_DISABLED_TESTS \
            -e TESTS_TO_INCLUDE \
            -e ZE_AFFINITY_MASK \
            -e HUGGING_FACE_HUB_TOKEN \
            -e DASHBOARD_TAG \
            --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
            --ulimit stack=10485760:83886080 \
            --ulimit core=0 \
.github/workflows/build-almalinux-images.yml (vendored, 2 lines changed)

@@ -36,7 +36,7 @@ jobs:
    runs-on: linux.9xlarge.ephemeral
    strategy:
      matrix:
        tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"]
        tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm7.0", "rocm7.1", "cpu"]
    steps:
      - name: Build docker image
        uses: pytorch/pytorch/.github/actions/binary-docker-build@main
.github/workflows/build-libtorch-images.yml (vendored, 2 lines changed)

@@ -52,8 +52,8 @@ jobs:
          { tag: "cuda12.9" },
          { tag: "cuda12.8" },
          { tag: "cuda12.6" },
          { tag: "rocm6.4" },
          { tag: "rocm7.0" },
          { tag: "rocm7.1" },
          { tag: "cpu" },
        ]
    steps:
.github/workflows/build-magma-rocm-linux.yml (vendored, 2 lines changed)

@@ -34,7 +34,7 @@ jobs:
      id-token: write
    strategy:
      matrix:
        rocm_version: ["70", "64"]
        rocm_version: ["71", "70"]
    steps:
      - name: Checkout PyTorch
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
.github/workflows/build-manywheel-images.yml (vendored, 2 lines changed)

@@ -54,8 +54,8 @@ jobs:
          { name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" },
          { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" },
          { name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" },
          { name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" },
          { name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" },
          { name: "manylinux2_28-builder", tag: "rocm7.1", runner: "linux.9xlarge.ephemeral" },
          { name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" },
          { name: "manylinux2_28_aarch64-builder", tag: "cpu-aarch64", runner: "linux.arm64.2xlarge.ephemeral" },
          { name: "manylinux2_28-builder", tag: "xpu", runner: "linux.9xlarge.ephemeral" },
.github/workflows/build-triton-wheel.yml (vendored, 9 lines changed)

@@ -55,7 +55,7 @@ jobs:
        docker-image: ["pytorch/manylinux2_28-builder:cpu"]
        include:
          - device: "rocm"
            rocm_version: "7.0"
            rocm_version: "7.1"
            runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
          - device: "cuda"
            rocm_version: ""

@@ -159,12 +159,7 @@ jobs:
            WITH_CLANG_LDD="--with-clang-ldd"
          fi

          if [[ "${BUILD_DEVICE}" == xpu ]]; then
            docker exec -t "${container_name}" bash -c "dnf install -y gcc-toolset-13-gcc-c++"
            docker exec -t "${container_name}" bash -c "source /opt/rh/gcc-toolset-13/enable && ${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE"
          else
            docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD"
          fi
          docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD"

          if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "xpu") ]]; then
            docker exec -t "${container_name}" bash -c "auditwheel repair --plat ${PLATFORM} //artifacts/*.whl"
.github/workflows/docker-builds.yml (vendored, 1 line changed)

@@ -67,6 +67,7 @@ jobs:
          pytorch-linux-jammy-py3.12-halide,
          pytorch-linux-jammy-xpu-n-1-py3,
          pytorch-linux-jammy-xpu-n-py3,
          pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
          pytorch-linux-jammy-py3-clang18-asan,
          pytorch-linux-jammy-py3-clang12-onnx,
          pytorch-linux-jammy-linter,
.github/workflows/generated-linux-binary-libtorch-nightly.yml (generated, vendored, 236 lines changed)
@ -384,124 +384,6 @@ jobs:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
|
||||
libtorch-rocm6_4-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
uses: ./.github/workflows/_binary-build-linux.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.4
|
||||
GPU_ARCH_VERSION: "6.4"
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: libtorch-cxx11-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
timeout-minutes: 300
|
||||
build_name: libtorch-rocm6_4-shared-with-deps-release
|
||||
build_environment: linux-binary-libtorch
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
libtorch-rocm6_4-shared-with-deps-release-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-rocm6_4-shared-with-deps-release-build
|
||||
- get-label-type
|
||||
runs-on: linux.rocm.gpu.mi250
|
||||
timeout-minutes: 240
|
||||
env:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.4
|
||||
GPU_ARCH_VERSION: "6.4"
|
||||
GPU_ARCH_TYPE: rocm
|
||||
SKIP_ALL_TESTS: 1
|
||||
DOCKER_IMAGE: libtorch-cxx11-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
steps:
|
||||
- name: Setup ROCm
|
||||
uses: ./.github/actions/setup-rocm
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: libtorch-rocm6_4-shared-with-deps-release
|
||||
path: "${{ runner.temp }}/artifacts/"
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: ROCm set GPU_FLAG
|
||||
run: |
|
||||
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
- name: Calculate docker image
|
||||
id: calculate-docker-image
|
||||
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
|
||||
with:
|
||||
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
|
||||
docker-image-name: libtorch-cxx11-builder
|
||||
custom-tag-prefix: rocm6.4
|
||||
docker-build-dir: .ci/docker
|
||||
working-directory: pytorch
|
||||
- name: Pull Docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
- name: Test Pytorch binary
|
||||
uses: ./pytorch/.github/actions/test-pytorch-binary
|
||||
env:
|
||||
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
- name: Teardown ROCm
|
||||
uses: ./.github/actions/teardown-rocm
|
||||
libtorch-rocm6_4-shared-with-deps-release-upload: # Uploading
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
needs: libtorch-rocm6_4-shared-with-deps-release-test
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.4
|
||||
GPU_ARCH_VERSION: "6.4"
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: libtorch-cxx11-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
build_name: libtorch-rocm6_4-shared-with-deps-release
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
|
||||
libtorch-rocm7_0-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
uses: ./.github/workflows/_binary-build-linux.yml
|
||||
@ -619,3 +501,121 @@ jobs:
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
|
||||
libtorch-rocm7_1-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
uses: ./.github/workflows/_binary-build-linux.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm7.1
|
||||
GPU_ARCH_VERSION: "7.1"
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: libtorch-cxx11-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
timeout-minutes: 300
|
||||
build_name: libtorch-rocm7_1-shared-with-deps-release
|
||||
build_environment: linux-binary-libtorch
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
libtorch-rocm7_1-shared-with-deps-release-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-rocm7_1-shared-with-deps-release-build
|
||||
- get-label-type
|
||||
runs-on: linux.rocm.gpu.mi250
|
||||
timeout-minutes: 240
|
||||
env:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm7.1
|
||||
GPU_ARCH_VERSION: "7.1"
|
||||
GPU_ARCH_TYPE: rocm
|
||||
SKIP_ALL_TESTS: 1
|
||||
DOCKER_IMAGE: libtorch-cxx11-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
steps:
|
||||
- name: Setup ROCm
|
||||
uses: ./.github/actions/setup-rocm
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: libtorch-rocm7_1-shared-with-deps-release
|
||||
path: "${{ runner.temp }}/artifacts/"
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: ROCm set GPU_FLAG
|
||||
run: |
|
||||
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
- name: Calculate docker image
|
||||
id: calculate-docker-image
|
||||
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
|
||||
with:
|
||||
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
|
||||
docker-image-name: libtorch-cxx11-builder
|
||||
custom-tag-prefix: rocm7.1
|
||||
docker-build-dir: .ci/docker
|
||||
working-directory: pytorch
|
||||
- name: Pull Docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
- name: Test Pytorch binary
|
||||
uses: ./pytorch/.github/actions/test-pytorch-binary
|
||||
env:
|
||||
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
- name: Teardown ROCm
|
||||
uses: ./.github/actions/teardown-rocm
|
||||
libtorch-rocm7_1-shared-with-deps-release-upload: # Uploading
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
needs: libtorch-rocm7_1-shared-with-deps-release-test
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm7.1
|
||||
GPU_ARCH_VERSION: "7.1"
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: libtorch-cxx11-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
build_name: libtorch-rocm7_1-shared-with-deps-release
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
|
||||
.github/workflows/generated-linux-binary-manywheel-nightly.yml (generated, vendored, 1610 lines changed; diff suppressed because it is too large)
.github/workflows/inductor-perf-test-nightly-xpu.yml (new file, vendored, 148 lines added)
|
||||
name: inductor-perf-nightly-xpu
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/inductor-perf-test-nightly-xpu/*
|
||||
schedule:
|
||||
- cron: 30 17 * * *
|
||||
  workflow_dispatch:
    inputs:
      training:
        description: Run training (on by default)?
        required: false
        type: boolean
        default: true
      inference:
        description: Run inference (on by default)?
        required: false
        type: boolean
        default: true
      default:
        description: Run inductor_default?
        required: false
        type: boolean
        default: false
      dynamic:
        description: Run inductor_dynamic_shapes?
        required: false
        type: boolean
        default: false
      cppwrapper:
        description: Run inductor_cpp_wrapper?
        required: false
        type: boolean
        default: false
      cudagraphs:
        description: Run inductor_cudagraphs?
        required: false
        type: boolean
        default: false
      freezing_cudagraphs:
        description: Run inductor_cudagraphs with freezing for inference?
        required: false
        type: boolean
        default: false
      aotinductor:
        description: Run aot_inductor for inference?
        required: false
        type: boolean
        default: false
      maxautotune:
        description: Run inductor_max_autotune?
        required: false
        type: boolean
        default: false
      benchmark_configs:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf,inductor_timm_perf,inductor_torchbench_perf,cachebench

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
      opt_out_experiments: lf

  xpu-n-py3_10-inductor-benchmark-build:
    name: xpu-n-py3.10-inductor-benchmark
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-xpu-n-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
      runner: linux.c7i.12xlarge
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_xpu", shard: 1, num_shards: 5, runner: "linux.idc.xpu" },
          { config: "inductor_huggingface_perf_xpu", shard: 2, num_shards: 5, runner: "linux.idc.xpu" },
          { config: "inductor_huggingface_perf_xpu", shard: 3, num_shards: 5, runner: "linux.idc.xpu" },
          { config: "inductor_huggingface_perf_xpu", shard: 4, num_shards: 5, runner: "linux.idc.xpu" },
          { config: "inductor_huggingface_perf_xpu", shard: 5, num_shards: 5, runner: "linux.idc.xpu" },
          { config: "inductor_timm_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_timm_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_timm_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_timm_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_timm_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_timm_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_torchbench_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_torchbench_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_torchbench_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_torchbench_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_torchbench_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
          { config: "inductor_torchbench_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
        ]}
    secrets: inherit

  xpu-n-py3_10-inductor-benchmark-test-nightly:
    permissions:
      id-token: write
      contents: read
    if: github.event_name != 'workflow_dispatch'
    name: xpu-n-py3.10-inductor-benchmark
    uses: ./.github/workflows/_xpu-test.yml
    needs: xpu-n-py3_10-inductor-benchmark-build
    with:
      build-environment: linux-jammy-xpu-n-py3.10
      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
      docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
      timeout-minutes: 720
      # Disable monitor in perf tests for more investigation
      disable-monitor: true
      monitor-log-interval: 10
      monitor-data-collect-interval: 2
    secrets: inherit

  xpu-n-py3_10-inductor-benchmark-test:
    permissions:
      id-token: write
      contents: read
    if: github.event_name == 'workflow_dispatch'
    name: xpu-n-py3.10-inductor-test
    uses: ./.github/workflows/_xpu-test.yml
    needs: xpu-n-py3_10-inductor-benchmark-build
    with:
      build-environment: linux-jammy-xpu-n-py3.10
      dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
      docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
      timeout-minutes: 720
      disable-monitor: false
      monitor-log-interval: 15
      monitor-data-collect-interval: 4
    secrets: inherit
.gitignore (vendored, 2 changes)
@ -143,6 +143,7 @@ scripts/release_notes/*.json
sccache-stats*.json
lint.json
merge_record.json
.github/scripts/nightly_source_matrix.json

# These files get copied over on invoking setup.py
torchgen/packaged/*
@ -397,3 +398,4 @@ CLAUDE.local.md
/test_*.py
/debug_*.py
CLAUDE_CONTEXT/
/.claude/settings.local.json

@ -11,7 +11,6 @@ aspects of contributing to PyTorch.
<!-- toc -->

- [Developing PyTorch](#developing-pytorch)
  - [Setup the development environment](#setup-the-development-environment)
  - [Tips and Debugging](#tips-and-debugging)
  - [Nightly Checkout & Pull](#nightly-checkout--pull)
- [Codebase structure](#codebase-structure)
@ -67,23 +66,6 @@ aspects of contributing to PyTorch.

Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions.

### Setup the development environment

First, you need to [fork the PyTorch project on GitHub](https://github.com/pytorch/pytorch/fork) and follow the instructions at [Connecting to GitHub with SSH](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) to setup your SSH authentication credentials.

Then clone the PyTorch project and setup the development environment:

```bash
git clone git@github.com:<USERNAME>/pytorch.git
cd pytorch
git remote add upstream git@github.com:pytorch/pytorch.git

make setup-env
# Or run `make setup-env-cuda` for pre-built CUDA binaries
# Or run `make setup-env-rocm` for pre-built ROCm binaries
source venv/bin/activate  # or `. .\venv\Scripts\activate` on Windows
```

### Tips and Debugging

* If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.

@ -825,6 +825,14 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
  display_vmap_fallback_warnings_ = enabled;
}

bool Context::warnOnAccumulateGradStreamMismatch() const {
  return warn_on_accumulate_grad_stream_mismatch_;
}

void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
  warn_on_accumulate_grad_stream_mismatch_ = enabled;
}

bool Context::isDefaultMobileCPUAllocatorSet() {
  return prev_allocator_ptr_ != nullptr;
}

@ -404,6 +404,9 @@ class TORCH_API Context {
  void setDisplayVmapFallbackWarnings(bool enabled);
  bool areVmapFallbackWarningsEnabled() const;

  void setWarnOnAccumulateGradStreamMismatch(bool enabled);
  bool warnOnAccumulateGradStreamMismatch() const;

  bool isDefaultMobileCPUAllocatorSet();
  void setDefaultMobileCPUAllocator();
  void unsetDefaultMobileCPUAllocator();
@ -494,6 +497,7 @@ class TORCH_API Context {
  bool release_original_weights = false;
#endif
  bool display_vmap_fallback_warnings_ = false;
  bool warn_on_accumulate_grad_stream_mismatch_ = true;
  std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
  bool enable_sparse_tensor_invariant_checks = false;
  bool allow_fp16_reduction_cpu = false;

@ -19,6 +19,13 @@ inline namespace CPU_CAPABILITY {
|
||||
#error "Big endian is not supported."
|
||||
#endif
|
||||
|
||||
// GCC does not properly optimize bf16 operators
|
||||
#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19)
|
||||
#define BF16_ARITHMETIC_SUPPORTED() 1
|
||||
#else
|
||||
#define BF16_ARITHMETIC_SUPPORTED() 0
|
||||
#endif
|
||||
|
||||
// Unlike the float16_t family of types, bfloat16_t is not available
|
||||
// when we're not targeting bfloat16 hardware support on some
|
||||
// platforms (but not Mac, so we have to be careful not to shadow the
|
||||
@ -352,18 +359,72 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
other, &Vectorized<float>::name); \
|
||||
}
|
||||
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
|
||||
Vectorized frac() const;
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
// Flip sign bit
|
||||
Vectorized<c10::BFloat16> neg() const {
|
||||
return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768));
|
||||
}
|
||||
// Fast reciprocal is fine because we are truncating results
|
||||
Vectorized<c10::BFloat16> reciprocal() const {
|
||||
auto x = vcvtq_low_f32_bf16(values);
|
||||
auto y = vcvtq_high_f32_bf16(values);
|
||||
x = vrecpeq_f32(x);
|
||||
y = vrecpeq_f32(y);
|
||||
return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y);
|
||||
}
|
||||
// Clearing the sign bit
|
||||
Vectorized<c10::BFloat16> abs() const {
|
||||
return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF);
|
||||
}
|
||||
#else
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
|
||||
#endif
|
||||
|
||||
// These functions are optimized on clang-21+
|
||||
#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21)
|
||||
Vectorized<c10::BFloat16> operator==(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values == other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator!=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values != other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator<(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values < other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator<=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values <= other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator>(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values > other.values;
|
||||
}
|
||||
|
||||
Vectorized<c10::BFloat16> operator>=(
|
||||
const Vectorized<c10::BFloat16>& other) const {
|
||||
return values >= other.values;
|
||||
}
|
||||
#else
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
|
||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
|
||||
#endif
|
||||
|
||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||
@ -412,28 +473,52 @@ template <>
|
||||
Vectorized<c10::BFloat16> inline operator+(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#if BF16_ARITHMETIC_SUPPORTED()
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x + y;
|
||||
#else
|
||||
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator-(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#if BF16_ARITHMETIC_SUPPORTED()
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x - y;
|
||||
#else
|
||||
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator*(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#if BF16_ARITHMETIC_SUPPORTED()
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x * y;
|
||||
#else
|
||||
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator/(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
#if BF16_ARITHMETIC_SUPPORTED()
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
return x / y;
|
||||
#else
|
||||
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
// frac. Implement this here so we can use subtraction
|
||||
@ -544,12 +629,19 @@ Vectorized<c10::BFloat16> inline fmadd(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#if BF16_ARITHMETIC_SUPPORTED()
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return x * y + z;
|
||||
#else
|
||||
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
|
||||
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
|
||||
// elements, not the bottom and top half, so they don't seem
|
||||
// particularly useful here. Ideally we would include dot product in
|
||||
// the Vectorized interface...
|
||||
return a * b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -557,8 +649,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#if BF16_ARITHMETIC_SUPPORTED()
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return (-x) * y + z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return -a * b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -566,8 +665,15 @@ Vectorized<c10::BFloat16> inline fmsub(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#if BF16_ARITHMETIC_SUPPORTED()
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return x * y - z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -575,8 +681,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
const Vectorized<c10::BFloat16>& c) {
|
||||
#if BF16_ARITHMETIC_SUPPORTED()
|
||||
bfloat16x8_t x = a;
|
||||
bfloat16x8_t y = b;
|
||||
bfloat16x8_t z = c;
|
||||
return (-x) * y - z;
|
||||
#else
|
||||
// See NOTE [BF16 FMA] above.
|
||||
return -a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
||||
|
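As an aside on NOTE [BF16 FMA] above: when BF16_ARITHMETIC_SUPPORTED() is false, `a * b + c` goes through the float fallbacks, which round back to bf16 between the multiply and the add. A hand-written widen-then-fuse alternative would keep the accumulation in fp32; the sketch below is only an illustration (it reuses the NEON conversion intrinsics seen in `reciprocal()` above and is not what this header actually emits).

```cpp
#include <arm_neon.h>  // requires compiling with the +bf16 extension

// Sketch: emulate fmadd(a, b, c) for bfloat16x8_t by widening to fp32,
// fusing the multiply-add there, and truncating back to bf16. Unlike the
// `a * b + c` fallback, no intermediate rounding to bf16 happens.
static inline bfloat16x8_t bf16_fmadd_via_f32(
    bfloat16x8_t a, bfloat16x8_t b, bfloat16x8_t c) {
  float32x4_t a_lo = vcvtq_low_f32_bf16(a), a_hi = vcvtq_high_f32_bf16(a);
  float32x4_t b_lo = vcvtq_low_f32_bf16(b), b_hi = vcvtq_high_f32_bf16(b);
  float32x4_t c_lo = vcvtq_low_f32_bf16(c), c_hi = vcvtq_high_f32_bf16(c);
  float32x4_t r_lo = vfmaq_f32(c_lo, a_lo, b_lo);  // c + a * b, fused in fp32
  float32x4_t r_hi = vfmaq_f32(c_hi, a_hi, b_hi);
  return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(r_lo), r_hi);
}
```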
||||
@ -6,9 +6,9 @@ namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
|
||||
// Enable auto-vectorization for GCC-13+ and clang-17+
|
||||
// Enable auto-vectorization for clang-17+
|
||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
|
||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
|
||||
#if defined(__clang__) && (__clang_major__ >= 17)
|
||||
|
||||
template <typename from_type, typename to_type>
|
||||
inline void convertImpl(
|
||||
|
||||
@ -309,7 +309,7 @@ class Vectorized<float> {
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
|
||||
// Implementation copied from Arm Optimized Routine
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
inline Vectorized<float> vexpq_f32_u20() const {
|
||||
// bail out to sleef if it's a special case:
|
||||
// i.e. there's an input s.t. |input| > 87.3....
|
||||
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
|
||||
@ -348,6 +348,9 @@ class Vectorized<float> {
|
||||
|
||||
return vfmaq_f32(scale, poly, scale);
|
||||
}
|
||||
Vectorized<float> exp_u20() const {
|
||||
return vexpq_f32_u20();
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp_u20();
|
||||
}
|
||||
@ -634,7 +637,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
|
||||
// - exp(- x * x)
|
||||
auto pow_2 = (*this) * (*this);
|
||||
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
|
||||
auto tmp4 = neg_pow_2.exp();
|
||||
auto tmp4 = neg_pow_2.vexpq_f32_u20();
|
||||
auto tmp5 = tmp4 ^ neg_zero_vec;
|
||||
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
|
||||
auto tmp6 = t * tmp5;
|
||||
|
||||
@ -1,78 +1,90 @@
|
||||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
|
||||
namespace at::cuda {
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define HAS_CUDA_GREEN_CONTEXT() 1
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#define HAS_CUDA_GREEN_CONTEXT() 0
|
||||
// Suppress unused private field warnings as this class is not supposed to be used
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field")
|
||||
#endif
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<GreenContext> GreenContext::create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
if (!device_id.has_value()) {
|
||||
device_id = at::cuda::current_device();
|
||||
}
|
||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
||||
return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
@ -80,7 +92,7 @@ namespace at::cuda {
|
||||
|
||||
// Implement move operations
|
||||
GreenContext::GreenContext(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
@ -91,7 +103,7 @@ namespace at::cuda {
|
||||
}
|
||||
|
||||
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
if (this != &other) {
|
||||
// Clean up current resources
|
||||
if (green_ctx_) {
|
||||
@ -120,7 +132,7 @@ namespace at::cuda {
|
||||
}
|
||||
|
||||
GreenContext::~GreenContext() noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
#else
|
||||
@ -128,25 +140,9 @@ namespace at::cuda {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext GreenContext::getContext() const {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
return context_;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx GreenContext::getGreenContext() const {
|
||||
return green_ctx_;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void GreenContext::setContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
||||
parent_stream_ = current_stream.stream();
|
||||
|
||||
@ -175,7 +171,7 @@ namespace at::cuda {
|
||||
}
|
||||
|
||||
void GreenContext::popContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
// see above note about stream being hardcoded to the default stream
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(c10::cuda::getCurrentCUDAStream());
|
||||
|
||||
@ -1,53 +1,38 @@
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define CUDA_HAS_GREEN_CONTEXT 1
|
||||
#else
|
||||
#define CUDA_HAS_GREEN_CONTEXT 0
|
||||
#endif
|
||||
|
||||
// Forward declare green context as opaque ptr
|
||||
typedef struct CUgreenCtx_st* CUgreenCtx;
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
class TORCH_CUDA_CPP_API GreenContext {
|
||||
public:
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
|
||||
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||
// Green context creation
|
||||
static std::unique_ptr<GreenContext> create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id);
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Delete copy constructor and assignment
|
||||
GreenContext(const GreenContext&) = delete;
|
||||
GreenContext& operator=(const GreenContext&) = delete;
|
||||
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext getContext() const;
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx getGreenContext() const;
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void setContext();
|
||||
|
||||
void popContext();
|
||||
|
||||
private:
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
|
||||
int32_t device_id_ = -1;
|
||||
CUgreenCtx green_ctx_ = nullptr;
|
||||
CUcontext context_ = nullptr;
|
||||
cudaStream_t parent_stream_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
} // namespace at::cuda
|
||||
|
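To make the refactor above concrete: creation now goes through the static `create()` factory (the constructor and move operations are private), and the partition is activated and released with `setContext()` / `popContext()`. A minimal, hypothetical call site (a sketch only, assuming a CUDA 12.8+ driver and a build with `PYTORCH_C10_DRIVER_API_SUPPORTED`):

```cpp
#include <ATen/cuda/CUDAGreenContext.h>
#include <optional>

// Hypothetical example: run work on a carve-out of 16 SMs of the current device.
void run_on_sm_partition() {
  // Factory returns a unique_ptr; passing std::nullopt selects the current device.
  auto gctx = at::cuda::GreenContext::create(/*num_sms=*/16, /*device_id=*/std::nullopt);

  gctx->setContext();   // make the green context current (the parent stream is remembered)
  // ... enqueue kernels on the current stream; they only see the 16-SM partition ...
  gctx->popContext();   // synchronize and switch back to the previously active context
}
```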
||||
@ -7,17 +7,6 @@
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
// hipSparse const API added in v2.4.0
|
||||
#if HIPSPARSE_VERSION >= 200400
|
||||
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
||||
#else
|
||||
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
||||
#endif
|
||||
#else // USE_ROCM
|
||||
#define AT_USE_HIPSPARSE_GENERIC_API() 0
|
||||
#endif // USE_ROCM
|
||||
|
||||
// cuSparse Generic API spsv function was added in CUDA 11.3.0
|
||||
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
|
||||
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
|
||||
|
||||
@ -2,8 +2,6 @@
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
|
||||
#include <mutex>
|
||||
|
||||
namespace at {
|
||||
namespace cuda {
|
||||
namespace detail {
|
||||
@ -12,39 +10,36 @@ __device__ __constant__ float cublas_one_device;
|
||||
__device__ __constant__ float cublas_zero_device;
|
||||
|
||||
float *get_cublas_device_one() {
|
||||
static c10::once_flag init_flag;
|
||||
|
||||
c10::call_once(init_flag, []() {
|
||||
static float *ptr = nullptr;
|
||||
static auto init_flag = [&]() {
|
||||
const float one = 1.f;
|
||||
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
|
||||
});
|
||||
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
|
||||
return true;
|
||||
}();
|
||||
|
||||
float *ptr;
|
||||
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
float *get_cublas_device_zero() {
|
||||
static c10::once_flag init_flag;
|
||||
|
||||
c10::call_once(init_flag, []() {
|
||||
static float *ptr = nullptr;
|
||||
static auto init_flag = [&]() {
|
||||
const float zero = 0.f;
|
||||
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
|
||||
});
|
||||
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
|
||||
return true;
|
||||
}();
|
||||
|
||||
float *ptr;
|
||||
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
float *get_user_alpha_ptr() {
|
||||
static float *alpha_ptr;
|
||||
|
||||
static c10::once_flag init_flag;
|
||||
|
||||
c10::call_once(init_flag, []() {
|
||||
static bool init_flag [[maybe_unused]] = []() {
|
||||
AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
|
||||
});
|
||||
return true;
|
||||
}();
|
||||
|
||||
return alpha_ptr;
|
||||
}
|
||||
|
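The three hunks above drop `c10::call_once` and its separate `once_flag` in favor of a function-local `static` initialized by an immediately-invoked lambda; C++11 guarantees that such an initializer runs exactly once and is thread-safe. A generic, standalone illustration of the idiom (not the CUDA code itself):

```cpp
#include <cstdio>

// Lazily initializes a value the first time any thread calls this function.
// The initialization of `initialized` is guarded by the compiler ("magic
// statics"), so the lambda body runs exactly once even under concurrency.
int* get_lazy_value() {
  static int value = 0;
  static bool initialized [[maybe_unused]] = []() {
    value = 42;                     // one-time setup (stands in for cudaMemcpyToSymbol, etc.)
    std::puts("initialized once");  // executes a single time
    return true;
  }();
  return &value;
}
```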
||||
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/CachingDeviceAllocator.h>
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
@ -151,6 +152,36 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
}
|
||||
|
||||
virtual bool isAvailable() const override;
|
||||
|
||||
/* MTIAGraph related APIs */
|
||||
virtual int64_t mtiagraphCreate(bool keep_graph = false) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void mtiagraphCaptureEnd(int64_t handle) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void mtiagraphInstantiate(int64_t handle) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void mtiagraphReplay(int64_t handle) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void mtiagraphReset(int64_t handle) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual MempoolId_t mtiagraphPool(int64_t handle) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
};
|
||||
|
||||
struct TORCH_API MTIAHooksArgs {};
|
||||
|
||||
@ -534,20 +534,20 @@ Tensor trace_decomp(const Tensor& tensor) {
|
||||
std::tuple<Tensor, std::optional<int64_t>> tril_batch_rule(
|
||||
const Tensor& self,
|
||||
std::optional<int64_t> self_bdim,
|
||||
int64_t diagonal = 0) {
|
||||
c10::SymInt diagonal = 0) {
|
||||
TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
auto result = at::tril(self_, diagonal);
|
||||
auto result = at::tril_symint(self_, std::move(diagonal));
|
||||
return std::make_tuple(std::move(result), 0);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, std::optional<int64_t>> triu_batch_rule(
|
||||
const Tensor& self,
|
||||
std::optional<int64_t> self_bdim,
|
||||
int64_t diagonal = 0) {
|
||||
c10::SymInt diagonal = 0) {
|
||||
TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
auto result = at::triu(self_, diagonal);
|
||||
auto result = at::triu_symint(self_, std::move(diagonal));
|
||||
return std::make_tuple(std::move(result), 0);
|
||||
}
|
||||
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <c10/util/CallOnce.h>
|
||||
|
||||
#include <ATen/mps/IndexKernels.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
@ -10,9 +8,6 @@
|
||||
|
||||
namespace at::mps {
|
||||
|
||||
static std::unique_ptr<MPSDevice> mps_device;
|
||||
static c10::once_flag mpsdev_init;
|
||||
|
||||
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
|
||||
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
|
||||
// host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+
|
||||
@ -21,8 +16,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
|
||||
}
|
||||
|
||||
MPSDevice* MPSDevice::getInstance() {
|
||||
c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
|
||||
return mps_device.get();
|
||||
static MPSDevice mps_device;
|
||||
return &mps_device;
|
||||
}
|
||||
|
||||
MPSDevice::~MPSDevice() {
|
||||
|
||||
@ -25,18 +25,19 @@ TORCH_PRECOMPUTE_META_FUNC(avg_pool2d)
|
||||
// #20866, #22032: Guarantee this for the official C++ API?
|
||||
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
|
||||
"avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
|
||||
const int64_t kH = kernel_size[0];
|
||||
const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];
|
||||
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
|
||||
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
|
||||
|
||||
TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
|
||||
"avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
|
||||
const int64_t dH = stride.empty() ? kH : stride[0];
|
||||
const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1];
|
||||
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
|
||||
const int dW = stride.empty() ? kW :
|
||||
stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
|
||||
|
||||
TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
|
||||
"avg_pool2d: padding must either be a single int, or a tuple of two ints");
|
||||
const int64_t padH = padding[0];
|
||||
const int64_t padW = padding.size() == 1 ? padH : padding[1];
|
||||
const int padH = safe_downcast<int, int64_t>(padding[0]);
|
||||
const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
|
||||
|
||||
TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
|
||||
"divisor must be not zero");
|
||||
|
||||
@ -410,8 +410,8 @@ struct ConvParams {
|
||||
return false;
|
||||
}
|
||||
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
|
||||
// broken on cuDNN 9.8
|
||||
if (cudnn_version >= 90800) {
|
||||
// broken on cuDNN 9.8 - 9.14
|
||||
if (cudnn_version >= 90800 && cudnn_version < 91500) {
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
|
||||
(input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) &&
|
||||
weight.dim() == 5) {
|
||||
|
||||
@ -640,7 +640,7 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr
|
||||
}
|
||||
}
|
||||
|
||||
return ops[0];
|
||||
return std::move(ops[0]);
|
||||
}
|
||||
|
||||
// _trilinear computes a trilinear einstein sum with an unrolled dimension
|
||||
|
||||
@ -139,7 +139,7 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
}
|
||||
);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND(kHalf, dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
auto norm_val = norm.to<scalar_t>();
|
||||
scalar_t beta_val(beta);
|
||||
auto norm_val_vec = Vectorized<scalar_t>(norm_val);
|
||||
|
||||
@ -170,10 +170,14 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
const auto scalar_type = mat1.scalar_type();
|
||||
return (beta.toComplexDouble() == 1.0
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
// is to use lt interface only when self is bias.
|
||||
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
|
||||
&& result.dim() == 2 && result.is_contiguous()
|
||||
// Conditions for bias to be fusable
|
||||
&& (
|
||||
self.is_contiguous() &&
|
||||
// NOTE: fine to have 1-len dims to the left from the right-most one
|
||||
(self.dim() == 1 || self.squeeze().dim() == 1) &&
|
||||
self.sizes().back() == mat2_sizes[1]
|
||||
)
|
||||
&& ( // some dtype restrictions
|
||||
#ifndef USE_ROCM
|
||||
scalar_type == at::ScalarType::Double ||
|
||||
|
||||
@ -13,7 +13,7 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
|
||||
if (allow_neg_indices) {
|
||||
ind = (ind < 0) ? ind + ind_dim_size : ind;
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind);
|
||||
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
|
||||
if (off >= slice_size) return;
|
||||
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
|
||||
|
||||
@ -794,6 +794,24 @@ void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const Sc
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
_check_deepseek_support() {
|
||||
#ifndef USE_ROCM
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (dprops->major != 9) {
|
||||
// Only on Hopper GPUs
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
dprops->major == 9,
|
||||
"DeepSeek style (1x128, 128x128) scaling only supported in CUDA for SM90")
|
||||
}
|
||||
// Only in cublasLt >= 12.9
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
|
||||
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_scaled_block1x128_block1x128(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
@ -802,8 +820,12 @@ _scaled_block1x128_block1x128(
|
||||
const c10::ScalarType out_dtype,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
@ -821,6 +843,12 @@ _scaled_block1x128_block1x128(
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"1x128 and 128x128 scaling not available with ROCm"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -831,10 +859,12 @@ _scaled_block128x128_block1x128(
|
||||
const c10::ScalarType out_dtype,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
|
||||
std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
@ -852,6 +882,12 @@ _scaled_block128x128_block1x128(
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"1x128 and 128x128 scaling not available with ROCm"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -862,8 +898,12 @@ _scaled_block1x128_block128x128(
|
||||
const c10::ScalarType out_dtype,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
@ -881,6 +921,12 @@ _scaled_block1x128_block128x128(
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"1x128 and 128x128 scaling not available with ROCm"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
|
||||
@ -160,8 +160,8 @@ struct _cuda_scatter_gather_internal_kernel {
|
||||
auto offsets = offset_calc.get(i);
|
||||
|
||||
int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
|
||||
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
|
||||
&& "scatter gather kernel index out of bounds");
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size
|
||||
&& "scatter gather kernel index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim);
|
||||
|
||||
f(
|
||||
(scalar_t*)(self_ptr + offsets[0]),
|
||||
@ -406,9 +406,8 @@ struct _cuda_scatter_fill_internal_kernel {
|
||||
auto offsets = offset_calc.get(i);
|
||||
|
||||
int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
|
||||
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
|
||||
&& "index out of bounds"
|
||||
);
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size
|
||||
&& "index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim);
|
||||
|
||||
f(
|
||||
(scalar_t*)(self_ptr + offsets[0]),
|
||||
|
||||
@ -141,7 +141,8 @@ WelfordDataLN cuWelfordOnlineSum(
|
||||
if constexpr (!rms_norm){
|
||||
U delta = val - curr_sum.mean;
|
||||
U new_count = curr_sum.count + 1.f;
|
||||
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
|
||||
//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf`
|
||||
#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
|
||||
U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
|
||||
#else
|
||||
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
|
||||
@ -163,7 +164,8 @@ WelfordDataLN cuWelfordCombine(
|
||||
U count = dataA.count + dataB.count;
|
||||
U mean, sigma2;
|
||||
if (count > decltype(dataB.count){0}) {
|
||||
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
|
||||
//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf`
|
||||
#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
|
||||
auto coef = __builtin_amdgcn_rcpf(count);
|
||||
#else
|
||||
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
|
||||
|
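For context on the two reciprocal hunks above: `cuWelfordOnlineSum` and `cuWelfordCombine` implement the standard Welford online/parallel mean update, so the only division they need is by the running count, which is why a fast reciprocal is attractive and why its accuracy on gfx90a matters:

$$
\mu_n = \mu_{n-1} + \frac{x_n - \mu_{n-1}}{n},
\qquad
\mu_{A\cup B} = \mu_A + \left(\mu_B - \mu_A\right)\frac{n_B}{n_A + n_B}.
$$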
||||
@ -86,6 +86,28 @@ struct zeta_functor {
|
||||
}
|
||||
};
|
||||
|
||||
struct logaddexp_functor {
|
||||
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
|
||||
inline T operator()(const T a, const T b) {
|
||||
return c10::metal::logaddexp(a, b);
|
||||
}
|
||||
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
|
||||
inline float operator()(const T a, const T b) {
|
||||
return c10::metal::logaddexp(float(a), float(b));
|
||||
}
|
||||
};
|
||||
|
||||
struct logaddexp2_functor {
|
||||
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
|
||||
inline T operator()(const T a, const T b) {
|
||||
return c10::metal::logaddexp2(a, b);
|
||||
}
|
||||
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
|
||||
inline float operator()(const T a, const T b) {
|
||||
return c10::metal::logaddexp2(float(a), float(b));
|
||||
}
|
||||
};
|
||||
|
||||
struct xlog1py_functor {
|
||||
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
|
||||
inline T operator()(const T a, const T b) {
|
||||
@ -377,6 +399,10 @@ REGISTER_FLOAT_BINARY_OP(fmin);
|
||||
REGISTER_FLOAT_BINARY_OP(nextafter);
|
||||
REGISTER_FLOAT_BINARY_OP(zeta);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(zeta);
|
||||
REGISTER_FLOAT_BINARY_OP(logaddexp);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(logaddexp);
|
||||
REGISTER_FLOAT_BINARY_OP(logaddexp2);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(logaddexp2);
|
||||
REGISTER_FLOAT_BINARY_OP(xlog1py);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(xlog1py);
|
||||
REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t);
|
||||
@ -463,6 +489,8 @@ REGISTER_BINARY_OP(add, float2, float2);
|
||||
REGISTER_BINARY_OP(add, half2, half2);
|
||||
REGISTER_BINARY_OP(sub, float2, float2);
|
||||
REGISTER_BINARY_OP(sub, half2, half2);
|
||||
REGISTER_BINARY_OP(logaddexp, float2, float2);
|
||||
REGISTER_BINARY_OP(logaddexp, half2, half2);
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2, float2);
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2, half2);
|
||||
REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2);
|
||||
|
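For reference, the `logaddexp` / `logaddexp2` kernels registered above compute

$$
\operatorname{logaddexp}(a, b) = \log\!\left(e^{a} + e^{b}\right)
= \max(a, b) + \log\!\left(1 + e^{-\lvert a - b\rvert}\right),
$$

with the base-2 variant replacing $e$ by $2$. The MPSGraph implementations removed later in this diff used the left-hand form directly, which can overflow for large inputs; the rewritten Metal path delegates to `c10::metal::logaddexp`, whose exact formulation is not shown here.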
||||
@ -89,6 +89,14 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "zeta");
|
||||
}
|
||||
|
||||
static void logaddexp_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "logaddexp");
|
||||
}
|
||||
|
||||
static void logaddexp2_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "logaddexp2");
|
||||
}
|
||||
|
||||
static void xlog1py_mps_kernel(TensorIteratorBase& iter) {
|
||||
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types");
|
||||
lib.exec_binary_kernel(iter, "xlog1py");
|
||||
@ -211,6 +219,8 @@ REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
|
||||
REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel)
|
||||
REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel)
|
||||
REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel)
|
||||
REGISTER_DISPATCH(logaddexp_stub, &logaddexp_mps_kernel);
|
||||
REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_mps_kernel);
|
||||
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_mps_kernel)
|
||||
REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_mps_kernel)
|
||||
REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel)
|
||||
|
||||
@ -17,8 +17,6 @@
|
||||
#include <ATen/ops/ge_native.h>
|
||||
#include <ATen/ops/gt_native.h>
|
||||
#include <ATen/ops/le_native.h>
|
||||
#include <ATen/ops/logaddexp2_native.h>
|
||||
#include <ATen/ops/logaddexp_native.h>
|
||||
#include <ATen/ops/logical_and_native.h>
|
||||
#include <ATen/ops/logical_or_native.h>
|
||||
#include <ATen/ops/logical_xor_native.h>
|
||||
@ -277,30 +275,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
MPSGraphTensor* sumTensor =
|
||||
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
|
||||
secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
|
||||
name:nil];
|
||||
return [mpsGraph logarithmWithTensor:sumTensor name:nil];
|
||||
};
|
||||
mps::binaryOpTensor(self, other, output, "logaddexp_out_mps", logaddexp_op_block);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
MPSGraphTensor* sumTensor =
|
||||
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
|
||||
secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
|
||||
name:nil];
|
||||
return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
|
||||
};
|
||||
mps::binaryOpTensor(self, other, output, "logaddexp2_out_mps", logaddexp2_op_block);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
|
||||
@ -370,7 +370,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
|
||||
onValue:-1.0f
|
||||
offValue:0.0f
|
||||
name:nil];
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
|
||||
if (isWeightsArrayValid) {
|
||||
oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
|
||||
secondaryTensor:weightTensor
|
||||
@ -705,6 +705,7 @@ static void smooth_l1_loss_template(const Tensor& input,
|
||||
TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
|
||||
TORCH_CHECK(input.is_mps());
|
||||
TORCH_CHECK(target.is_mps());
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
|
||||
if ((input.numel() == 0) || (target.numel() == 0)) {
|
||||
reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
|
||||
return;
|
||||
@ -771,7 +772,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
|
||||
// xn - yn
|
||||
MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:targetTensor
|
||||
@ -797,7 +798,8 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
name:@"lossTensor"];
|
||||
MPSGraphTensor* outputTensor = lossTensor;
|
||||
if (reduction == Reduction::Mean) {
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
|
||||
dataType:[lossTensor dataType]];
|
||||
outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
|
||||
}
|
||||
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
|
||||
|
||||
@ -84,6 +84,9 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out(const Tensor& self,
|
||||
Tensor& output,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Long batch norm is not supported with MPS");
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
|
||||
"Batch norm for complex is not supported for MPS");
|
||||
using namespace at::native::mps;
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
@ -918,6 +921,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||
const int axis = input_ndim - normalized_ndim;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS");
|
||||
@autoreleasepool {
|
||||
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
// which kernel variant to use based on the normalized axis N size
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/aminmax.h>
|
||||
#include <ATen/ops/avg_pool2d.h>
|
||||
#include <ATen/ops/avg_pool2d_backward.h>
|
||||
#include <ATen/ops/avg_pool2d_backward_native.h>
|
||||
@ -544,8 +545,9 @@ static void max_unpool_out_mps_template(const Tensor& input,
|
||||
if (indices.defined() && indices.numel() > 0) {
|
||||
auto output_image_size = c10::multiply_integers(output_size_);
|
||||
|
||||
int64_t min_idx = indices.min().item<int64_t>();
|
||||
int64_t max_idx = indices.max().item<int64_t>();
|
||||
auto [min_idx_tensor, max_idx_tensor] = indices.aminmax();
|
||||
int64_t min_idx = min_idx_tensor.item<int64_t>();
|
||||
int64_t max_idx = max_idx_tensor.item<int64_t>();
|
||||
|
||||
if (min_idx < 0 || max_idx >= output_image_size) {
|
||||
int64_t error_idx = (min_idx < 0) ? min_idx : max_idx;
|
||||
|
||||
@ -1028,15 +1028,18 @@ TORCH_IMPL_FUNC(prod_out_mps)
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(amax_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) {
|
||||
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amax is not defined for complex types");
|
||||
reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(amin_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) {
|
||||
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amin is not defined for complex types");
|
||||
reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(aminmax_out_mps)
|
||||
(const Tensor& input_t, std::optional<int64_t> dim_opt, bool keepdim, const Tensor& min_t, const Tensor& max_t) {
|
||||
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "aminmax is not defined for complex types");
|
||||
reduction_out_mps(input_t,
|
||||
dim_opt.has_value() ? OptionalIntArrayRef({*dim_opt}) : std::nullopt,
|
||||
keepdim,
|
||||
|
||||
@ -83,6 +83,31 @@ std::string get_type_str<int32_t>() {
|
||||
return "int32_t";
|
||||
}
|
||||
|
||||
// If all tensors are contiguous with the same dtype and the cat dimension is 0,
|
||||
// then we can simply copy each tensor's underlying buffer contiguously into the
|
||||
// output.
|
||||
static void cat_out_mps_contiguous_impl(const ITensorListRef& inputs, const Tensor& output) {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
id<MTLBuffer> output_buffer = getMTLBufferStorage(output);
|
||||
size_t output_offset = output.storage_offset() * output.itemsize();
|
||||
|
||||
for (const Tensor& input : inputs) {
|
||||
if (cat_should_skip_tensor(input)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
id<MTLBuffer> input_buffer = getMTLBufferStorage(input);
|
||||
size_t input_offset = input.storage_offset() * input.itemsize();
|
||||
auto nbytes = input.nbytes();
|
||||
auto profile_id =
|
||||
getMPSProfiler().beginProfileCopy(input_buffer, output_buffer, input, output, nbytes, /*non_blocking=*/true);
|
||||
|
||||
stream->copy(input_buffer, output_buffer, nbytes, input_offset, output_offset, profile_id, SyncType::NONE);
|
||||
|
||||
output_offset += nbytes;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: `output` is expected to already have the correct size.
|
||||
template <typename idx_type_t>
|
||||
static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
|
||||
@ -105,7 +130,7 @@ static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, co
|
||||
// copy all the input tensor data into a packed buffer, which would not be
|
||||
// ideal.
|
||||
for (const Tensor& input : inputs) {
|
||||
if (input.numel() == 0) {
|
||||
if (cat_should_skip_tensor(input)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -243,101 +268,16 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
if (out.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto materialized_inputs = inputs.materialize();
|
||||
auto out_dtype = at::native::result_type(inputs);
|
||||
bool has_large_tensor =
|
||||
isTooLargeForMPSGraph(out) || std::any_of(materialized_inputs.begin(), materialized_inputs.end(), [](auto& t) {
|
||||
return !cat_should_skip_tensor(t) && isTooLargeForMPSGraph(t);
|
||||
});
|
||||
|
||||
int idx = 0;
|
||||
for (const Tensor& t : materialized_inputs) {
|
||||
TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
|
||||
auto lap = at::get_overlap_status(out, t);
|
||||
TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full,
|
||||
"torch.cat(): unsupported operation: the input tensors cannot refer to any "
|
||||
"of the output memory locations. Found overlap in input tensor ",
|
||||
idx);
|
||||
idx++;
|
||||
}
|
||||
// Check for type promotion
|
||||
TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
|
||||
"torch.cat(): input types can't be cast to the desired output type ",
|
||||
out.scalar_type());
|
||||
TORCH_CHECK(!inputs.empty(), "torch.cat(): invalid number of inputs ", inputs.size());
|
||||
|
||||
dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
|
||||
TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
|
||||
|
||||
// previously, size [0] tensors were the only possible empty tensors; thus, it
|
||||
// wasn't possible to cat empty tensors unless all the other tensors were
|
||||
// 1-dimensional, so we allowed these tensors to be "skipped". We maintain
|
||||
// this behavior for backwards compatibility, but only for this specific size
|
||||
// (i.e. other empty sizes are not skipped).
|
||||
// FIXME: warn if this is the case
|
||||
auto should_skip = [](const Tensor& t) { return t.dim() == 1 && t.size(0) == 0; };
|
||||
at::assert_no_internal_overlap(out);
|
||||
|
||||
Tensor notSkippedTensor;
|
||||
// Indices of tensors to be skipped because they're empty
|
||||
std::vector<int64_t> skipped_tensor_indices;
|
||||
// Tensors to be read
|
||||
std::vector<Tensor> input_tensors;
|
||||
int tensor_idx = 0;
|
||||
for (const Tensor& t : materialized_inputs) {
|
||||
if (t.numel() == 0 || should_skip(t)) {
|
||||
skipped_tensor_indices.push_back(tensor_idx);
|
||||
tensor_idx++;
|
||||
continue;
|
||||
}
|
||||
input_tensors.push_back(t);
|
||||
// TODO: Is this OK?
|
||||
notSkippedTensor = t;
|
||||
tensor_idx++;
|
||||
}
|
||||
// If all inputs are empty tensors, return an empty tensor
|
||||
if (!notSkippedTensor.defined()) {
|
||||
return;
|
||||
}
|
||||
for (const Tensor& t : inputs) {
|
||||
TORCH_CHECK(t.device() == notSkippedTensor.device(),
|
||||
"torch.cat(): all input tensors must be on the same device. Received ",
|
||||
t.device(),
|
||||
" and ",
|
||||
notSkippedTensor.device());
|
||||
}
|
||||
TORCH_CHECK(out.device() == notSkippedTensor.device(),
|
||||
"torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
|
||||
notSkippedTensor.device(),
|
||||
" and out is on ",
|
||||
out.device());
|
||||
|
||||
std::vector<int64_t> size(notSkippedTensor.sizes().vec());
|
||||
|
||||
// Compute size of the result in the cat dimension
|
||||
int64_t cat_dim_size = 0;
|
||||
idx = 0;
|
||||
bool has_large_tensor = false;
|
||||
for (const Tensor& tensor : materialized_inputs) {
|
||||
if (isTooLargeForMPSGraph(tensor)) {
|
||||
has_large_tensor |= true;
|
||||
}
|
||||
if (!should_skip(tensor)) {
|
||||
// TODO: Factor out `check_shape_except_dim`
|
||||
check_shape_except_dim(notSkippedTensor, tensor, dimension, idx);
|
||||
cat_dim_size += tensor.size(dimension);
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
// Compute the size of the result
|
||||
size[dimension] = cat_dim_size;
|
||||
// skip resizing if size of result is same as expected
|
||||
if (out.sizes() != size) {
|
||||
out.resize_(size, MemoryFormat::Contiguous);
|
||||
}
|
||||
if (out.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
has_large_tensor |= isTooLargeForMPSGraph(out);
|
||||
|
||||
if (has_large_tensor) {
|
||||
if (all_contiguous && all_same_dtype && (memory_format == MemoryFormat::Contiguous) && (dimension == 0)) {
|
||||
return mps::cat_out_mps_contiguous_impl(materialized_inputs, out);
|
||||
} else if (has_large_tensor) {
|
||||
return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out);
|
||||
} else {
|
||||
return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out);
|
||||
|
||||
@@ -31,6 +31,7 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v
indices.copy_(values.toType(at::ScalarType::Long));
return;
}
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "kthvalue is not implemented for complex types");
// issue #154890, raising error to prevent crash within MPSGraph until
// workaround is implemented.
TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");

@ -2602,12 +2602,16 @@
|
||||
device_check: NoCheck # TensorIterator
|
||||
structured_delegate: exp.out
|
||||
variants: function, method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA, SparseMPS: exp_sparse
|
||||
tags: [core, pointwise]
|
||||
|
||||
- func: exp_(Tensor(a!) self) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
structured_delegate: exp.out
|
||||
variants: function, method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA, SparseMPS: exp_sparse_
|
||||
tags: pointwise
|
||||
|
||||
- func: exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
@ -2616,6 +2620,7 @@
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA, MPS, MTIA: exp_out
|
||||
SparseCPU, SparseCUDA, SparseMPS: exp_sparse_out
|
||||
tags: pointwise
|
||||
|
||||
- func: exp2(Tensor self) -> Tensor
|
||||
@ -3622,8 +3627,7 @@
|
||||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA: logaddexp_out
|
||||
MPS: logaddexp_out_mps
|
||||
CPU, CUDA, MPS: logaddexp_out
|
||||
tags: pointwise
|
||||
|
||||
- func: logaddexp(Tensor self, Tensor other) -> Tensor
|
||||
@ -3635,8 +3639,7 @@
|
||||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA: logaddexp2_out
|
||||
MPS: logaddexp2_out_mps
|
||||
CPU, CUDA, MPS: logaddexp2_out
|
||||
tags: pointwise
|
||||
|
||||
- func: logaddexp2(Tensor self, Tensor other) -> Tensor
|
||||
@ -8867,11 +8870,11 @@
|
||||
autogen: bitwise_right_shift.Scalar_Tensor_out
|
||||
tags: pointwise
|
||||
|
||||
- func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
|
||||
- func: tril_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)
|
||||
structured_delegate: tril.out
|
||||
variants: method
|
||||
|
||||
- func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
|
||||
- func: triu_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)
|
||||
structured_delegate: triu.out
|
||||
variants: method
|
||||
|
||||
@ -8995,25 +8998,25 @@
|
||||
- func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
|
||||
variants: method, function
|
||||
|
||||
- func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
|
||||
- func: triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
|
||||
structured: True
|
||||
dispatch:
|
||||
CPU: triu_cpu
|
||||
CUDA: triu_cuda
|
||||
MPS: triu_mps_out
|
||||
|
||||
- func: triu(Tensor self, int diagonal=0) -> Tensor
|
||||
- func: triu(Tensor self, SymInt diagonal=0) -> Tensor
|
||||
structured_delegate: triu.out
|
||||
variants: method, function
|
||||
|
||||
- func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
|
||||
- func: tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
|
||||
structured: True
|
||||
dispatch:
|
||||
CPU: tril_cpu
|
||||
CUDA: tril_cuda
|
||||
MPS: tril_mps_out
|
||||
|
||||
- func: tril(Tensor self, int diagonal=0) -> Tensor
|
||||
- func: tril(Tensor self, SymInt diagonal=0) -> Tensor
|
||||
structured_delegate: tril.out
|
||||
variants: method, function
|
||||
|
||||
|
||||
@@ -467,6 +467,28 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
!options.has_layout() || options.layout() == kSparse,
"expected sparse layout, but got layout ",
options.layout());

if (indices.numel() > 0) {
Tensor min_indices =
std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
Tensor cpu_min_indices;
if (!indices.is_cpu()) {
cpu_min_indices = min_indices.to(at::DeviceType::CPU);
} else {
cpu_min_indices = min_indices;
}
auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
for (const auto d : c10::irange(indices.size(0))) {
int64_t min_index_in_dim = cpu_min_indices_accessor[d];
TORCH_CHECK(
min_index_in_dim >= 0,
"found negative index ",
min_index_in_dim,
" for dim ",
d);
}
}

return at::native::_sparse_coo_tensor_unsafe(
indices,
values,
|
||||
@ -26,6 +26,8 @@
|
||||
#include <ATen/ops/erf_native.h>
|
||||
#include <ATen/ops/erfinv.h>
|
||||
#include <ATen/ops/erfinv_native.h>
|
||||
#include <ATen/ops/exp.h>
|
||||
#include <ATen/ops/exp_native.h>
|
||||
#include <ATen/ops/expm1.h>
|
||||
#include <ATen/ops/expm1_native.h>
|
||||
#include <ATen/ops/floor.h>
|
||||
@ -175,6 +177,7 @@ COALESCED_UNARY_UFUNC(atanh)
|
||||
COALESCED_UNARY_UFUNC(ceil)
|
||||
COALESCED_UNARY_UFUNC(deg2rad)
|
||||
COALESCED_UNARY_UFUNC(erf)
|
||||
COALESCED_UNARY_UFUNC(exp)
|
||||
COALESCED_UNARY_UFUNC(erfinv)
|
||||
COALESCED_UNARY_UFUNC(expm1)
|
||||
COALESCED_UNARY_UFUNC(floor)
|
||||
|
||||
@ -1837,6 +1837,10 @@ class BenchmarkRunner:
|
||||
def skip_models_for_cuda(self):
|
||||
return set()
|
||||
|
||||
@property
|
||||
def skip_models_for_xpu(self):
|
||||
return set()
|
||||
|
||||
@property
|
||||
def skip_models_for_cpu(self):
|
||||
return set()
|
||||
@ -3927,6 +3931,8 @@ def run(runner, args, original_dir=None):
|
||||
runner.skip_models.update(runner.skip_models_for_cpu_aarch64)
|
||||
elif args.devices == ["cuda"]:
|
||||
runner.skip_models.update(runner.skip_models_for_cuda)
|
||||
elif args.devices == ["xpu"]:
|
||||
runner.skip_models.update(runner.skip_models_for_xpu)
|
||||
|
||||
if not args.multiprocess:
|
||||
runner.skip_models.update(runner.skip_multiprocess_models)
|
||||
|
||||
@ -56,6 +56,20 @@ def list_benchmarks():
|
||||
print(f"Available benchmarks: {list(BENCHMARK_REGISTRY.keys())}")
|
||||
|
||||
|
||||
def _run_benchmark(
|
||||
benchmark_cls,
|
||||
script_args,
|
||||
):
|
||||
benchmark = benchmark_cls(script_args)
|
||||
benchmark.benchmark()
|
||||
benchmark.report_geomean_speedup()
|
||||
if script_args.print_benchmark_result:
|
||||
print(f"Benchmarking results {benchmark.name}:")
|
||||
print(benchmark.profiling_results)
|
||||
if script_args.visualize:
|
||||
benchmark.visualize()
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
benchmark_name: str,
|
||||
script_args,
|
||||
@ -71,10 +85,7 @@ def run_benchmark(
|
||||
print("=" * 60)
|
||||
|
||||
benchmark_class = BENCHMARK_REGISTRY[benchmark_name]
|
||||
benchmark = benchmark_class(script_args)
|
||||
benchmark.benchmark()
|
||||
if script_args.visualize:
|
||||
benchmark.visualize()
|
||||
_run_benchmark(benchmark_class, script_args)
|
||||
|
||||
return True
|
||||
|
||||
@ -87,10 +98,7 @@ def run_all_benchmarks(script_args):
|
||||
|
||||
for name, cls in BENCHMARK_REGISTRY.items():
|
||||
print(f"\n{'=' * 20} {name.upper()} {'=' * 20}")
|
||||
benchmark = cls(script_args)
|
||||
benchmark.benchmark()
|
||||
if script_args.visualize:
|
||||
benchmark.visualize()
|
||||
_run_benchmark(cls, script_args)
|
||||
print()
|
||||
|
||||
|
||||
@ -149,8 +157,43 @@ Examples:
|
||||
help="Whether to exit with an error message for accuracy failure",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--print-benchmark-result",
|
||||
action="store_true",
|
||||
help="Whether to print the raw benchmarking result. Easier to quickly check the benchmark results on a server without GUI",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--custom-compile-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name for the curve with customized compilation options",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--custom-compile-options",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Json string for the custom compile options.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.custom_compile_options:
|
||||
import json
|
||||
|
||||
try:
|
||||
args.custom_compile_options = json.loads(args.custom_compile_options)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
raise RuntimeError(
|
||||
f"Invalid json string for --custom-compile-options: {args.custom_compile_options}"
|
||||
) from e
|
||||
|
||||
if not args.custom_compile_options:
|
||||
raise RuntimeError("Found no options for --custom-compile-options")
|
||||
if not args.custom_compile_name:
|
||||
raise RuntimeError("Missing label name for the custom compilation")
|
||||
|
||||
# Handle list option
|
||||
if args.list:
|
||||
list_benchmarks()
|
||||
|
||||
@ -8,6 +8,15 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# more important shapes used by internal models
|
||||
extra_shapes_for_norm = (
|
||||
(1152 * 500, 384),
|
||||
(1152 * 500, 512),
|
||||
(1152 * 1000, 384),
|
||||
(1152 * 1000, 512),
|
||||
)
|
||||
|
||||
|
||||
class CrossEntropyForward(BenchmarkKernel):
|
||||
def __init__(self, script_args):
|
||||
super().__init__(script_args)
|
||||
@ -346,7 +355,7 @@ class RMSNormForward(BenchmarkKernel):
|
||||
(32768, 65536),
|
||||
(16384, 131072),
|
||||
(8192, 262144),
|
||||
)
|
||||
) + extra_shapes_for_norm
|
||||
|
||||
def get_memory_bytes(self, args, kwargs) -> int:
|
||||
x, w = args
|
||||
@ -438,8 +447,7 @@ class RMSNormBackward(BenchmarkKernel):
|
||||
(32768, 4096),
|
||||
(32768, 8192),
|
||||
(32768, 16384),
|
||||
(32768, 32768),
|
||||
)
|
||||
) + extra_shapes_for_norm
|
||||
|
||||
def get_memory_bytes(self, args, kwargs) -> int:
|
||||
x, w, dy = args
|
||||
@ -553,7 +561,7 @@ class LayerNormForward(BenchmarkKernel):
|
||||
(32768, 16384),
|
||||
(32768, 32768),
|
||||
(32768, 65536),
|
||||
)
|
||||
) + extra_shapes_for_norm
|
||||
|
||||
def get_memory_bytes(self, args, kwargs) -> int:
|
||||
x, w = args
|
||||
@ -627,7 +635,7 @@ class LayerNormBackward(BenchmarkKernel):
|
||||
(32768, 16384),
|
||||
(32768, 32768),
|
||||
(32768, 65536),
|
||||
)
|
||||
) + extra_shapes_for_norm
|
||||
|
||||
def get_memory_bytes(self, args, kwargs) -> int:
|
||||
x, w, dy = args
|
||||
|
||||
@ -6,6 +6,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.stats import gmean
|
||||
|
||||
import torch
|
||||
from torch._inductor.runtime.benchmarking import benchmarker
|
||||
@ -107,6 +108,18 @@ class BenchmarkKernel:
|
||||
for backend in self.available_backends:
|
||||
args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
|
||||
res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
|
||||
|
||||
if (
|
||||
"compiled" in self.available_backends
|
||||
and self.script_args.custom_compile_options
|
||||
):
|
||||
torch._dynamo.reset() # cause recompile
|
||||
with torch._inductor.config.patch(self.script_args.custom_compile_options):
|
||||
args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
|
||||
res[self.script_args.custom_compile_name] = self.compiled(
|
||||
args_ref, kwargs_ref
|
||||
)()
|
||||
|
||||
gold = res["eager"]
|
||||
|
||||
tol = {}
|
||||
@ -115,7 +128,7 @@ class BenchmarkKernel:
|
||||
"atol": self.script_args.tolerance,
|
||||
"rtol": self.script_args.tolerance,
|
||||
}
|
||||
for backend in self.available_backends:
|
||||
for backend in res:
|
||||
if backend == "eager":
|
||||
continue
|
||||
try:
|
||||
@ -134,37 +147,83 @@ class BenchmarkKernel:
|
||||
print("Exit right away since --exit-on-accuracy-failure is set")
|
||||
sys.exit(1)
|
||||
|
||||
def benchmark_single_shape_for_backend(
|
||||
self, backend, args, kwargs, setting, fn=None
|
||||
) -> bool:
|
||||
if fn is None:
|
||||
fn = getattr(self, backend)
|
||||
args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
|
||||
try:
|
||||
avg_time = benchmark_kernel_in_milliseconds(fn(args_ref, kwargs_ref))
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
|
||||
)
|
||||
self.available_backends.remove(backend) # noqa: B909
|
||||
return False
|
||||
mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
|
||||
perf = Performance(setting, avg_time, mem_bytes)
|
||||
print(f"{self.name} kernel on {backend} backend. {perf}")
|
||||
self.profiling_results[backend].append(perf)
|
||||
return True
|
||||
|
||||
def benchmark_single_shape(
|
||||
self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
|
||||
):
|
||||
for backend in self.available_backends:
|
||||
args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
|
||||
try:
|
||||
avg_time = benchmark_kernel_in_milliseconds(
|
||||
getattr(self, backend)(args_ref, kwargs_ref)
|
||||
self.benchmark_single_shape_for_backend(backend, args, kwargs, setting)
|
||||
if (
|
||||
"compiled" in self.available_backends
|
||||
and self.script_args.custom_compile_options
|
||||
):
|
||||
torch._dynamo.reset() # cause recompile
|
||||
with torch._inductor.config.patch(self.script_args.custom_compile_options):
|
||||
status = self.benchmark_single_shape_for_backend(
|
||||
self.script_args.custom_compile_name,
|
||||
args,
|
||||
kwargs,
|
||||
setting,
|
||||
fn=self.compiled,
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
|
||||
if not status:
|
||||
self.script_args.custom_compile_options = (
|
||||
None # once fail, don't run again
|
||||
)
|
||||
self.available_backends.remove(backend) # noqa: B909
|
||||
continue
|
||||
mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
|
||||
perf = Performance(setting, avg_time, mem_bytes)
|
||||
print(f"{self.name} kernel on {backend} backend. {perf}")
|
||||
self.profiling_results[backend].append(perf)
|
||||
|
||||
if should_check_accuracy:
|
||||
self.check_accuracy(args, kwargs)
|
||||
|
||||
def visualize(self) -> None:
|
||||
device_name = torch.cuda.get_device_name(0)
|
||||
visualize_comparison(
|
||||
self.profiling_results,
|
||||
title=f"{self.name}",
|
||||
title=f"{self.name} ({device_name})",
|
||||
output_path=f"{self.name}_bench",
|
||||
)
|
||||
return
|
||||
|
||||
def report_geomean_speedup(self) -> None:
|
||||
print(f"Geomean speedup for benchmark {self.name}")
|
||||
eager_result = {
|
||||
result.setting: result for result in self.profiling_results["eager"]
|
||||
}
|
||||
print(f" eager {len(eager_result)} data points")
|
||||
for backend, backend_result in self.profiling_results.items():
|
||||
if backend == "eager":
|
||||
continue
|
||||
speeduplist = []
|
||||
for result in backend_result:
|
||||
eager_latency = eager_result[result.setting].latency
|
||||
backend_latency = result.latency
|
||||
speeduplist.append(
|
||||
eager_latency / backend_latency if backend_latency != 0 else 0.0
|
||||
)
|
||||
|
||||
if len(speeduplist) > 0:
|
||||
print(
|
||||
f" {backend} {len(speeduplist)} data points, {gmean(speeduplist):.2f}x speedup"
|
||||
)
|
||||
|
||||
|
||||
def get_backend_colors() -> dict[str, str]:
|
||||
"""Get consistent color scheme for different backends."""
|
||||
@ -252,5 +311,6 @@ def visualize_comparison(
|
||||
os.makedirs("pics", exist_ok=True)
|
||||
full_path = os.path.join("pics", output_path + ".png")
|
||||
plt.savefig(full_path, dpi=300, bbox_inches="tight", facecolor="white")
|
||||
print(f"Chart saved to {full_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
@ -74,7 +74,8 @@ REQUIRE_HIGHER_TOLERANCE = {
|
||||
REQUIRE_HIGHER_TOLERANCE_AMP = {}
|
||||
|
||||
REQUIRE_EVEN_HIGHER_TOLERANCE = {
|
||||
"beit_base_patch16_224",
|
||||
"deit_base_distilled_patch16_224",
|
||||
"vit_base_patch16_siglip_256",
|
||||
}
|
||||
|
||||
# These models need higher tolerance in MaxAutotune mode
|
||||
@ -354,7 +355,9 @@ class TimmRunner(BenchmarkRunner):
|
||||
if is_training:
|
||||
from torch._inductor import config as inductor_config
|
||||
|
||||
if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
|
||||
if name == "beit_base_patch16_224":
|
||||
tolerance = 16 * 1e-2
|
||||
elif name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
|
||||
inductor_config.max_autotune
|
||||
and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
|
||||
):
|
||||
|
||||
@ -124,6 +124,10 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
||||
def skip_models_for_cuda(self):
|
||||
return self._skip["device"]["cuda"]
|
||||
|
||||
@property
|
||||
def skip_models_for_xpu(self):
|
||||
return self._skip["device"]["xpu"]
|
||||
|
||||
@property
|
||||
def skip_models_for_freezing_cuda(self):
|
||||
return self._skip["freezing"]["cuda"]
|
||||
|
||||
@ -217,6 +217,9 @@ skip:
|
||||
|
||||
cuda: []
|
||||
|
||||
xpu:
|
||||
- *DETECTRON2_MODELS
|
||||
|
||||
test:
|
||||
training:
|
||||
- *DETECTRON2_MODELS
|
||||
|
||||
@@ -15,7 +15,6 @@ namespace c10::cuda {
namespace {

// Global stream state and constants
c10::once_flag init_flag;
DeviceIndex num_gpus = -1;
constexpr int kStreamsPerPoolBits = 5;
constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
@@ -226,7 +225,10 @@ void initDeviceStreamState(DeviceIndex device_index) {
// Init front-end to ensure initialization only occurs once
void initCUDAStreamsOnce() {
// Inits default streams (once, globally)
c10::call_once(init_flag, initGlobalStreamState);
auto static init_flag [[maybe_unused]] = [] {
initGlobalStreamState();
return true;
}();

if (current_streams) {
return;

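The replacement of `c10::call_once` with a function-local static is the C++11 "magic statics" idiom: the initializer of a block-scope static runs exactly once, and the compiler supplies the required synchronization. A standalone sketch of the pattern (illustration only, not PyTorch code):

```cpp
// Each call after the first skips the lambda entirely; initialization is
// thread-safe, and concurrent first callers wait for it to finish.
#include <iostream>

void global_init() {
  std::cout << "expensive one-time setup\n";
}

void init_once() {
  static bool initialized [[maybe_unused]] = [] {
    global_init();
    return true;
  }();
}

int main() {
  init_once();  // prints once
  init_once();  // no-op on subsequent calls
  return 0;
}
```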
@ -624,6 +624,64 @@ inline T spherical_bessel_j0(T x) {
|
||||
return static_cast<T>(::metal::sin(x) / x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline ::metal::enable_if_t<is_scalar_floating_point_v<T>, T> logaddexp(
|
||||
T a,
|
||||
T b) {
|
||||
float a0 = static_cast<float>(a);
|
||||
float b0 = static_cast<float>(b);
|
||||
if (::metal::isinf(a0) && a0 == b0) {
|
||||
return static_cast<T>(a0);
|
||||
} else {
|
||||
float m0 = ::metal::max(a0, b0);
|
||||
return static_cast<T>(
|
||||
m0 + ::c10::metal::log1p(::metal::exp(-::metal::abs(a0 - b0))));
|
||||
}
|
||||
}
|
||||
|
||||
// The function is ported from mlx
|
||||
template <typename T>
|
||||
inline ::metal::enable_if_t<is_complex_v<T>, T> logaddexp(T a, T b) {
|
||||
if (::metal::isnan(a.x) || ::metal::isnan(a.y) || ::metal::isnan(b.x) ||
|
||||
::metal::isnan(b.y)) {
|
||||
return T(NAN, NAN);
|
||||
}
|
||||
|
||||
T maxval = a.x > b.x ? a : b;
|
||||
T minval = a.x < b.x ? a : b;
|
||||
constexpr auto inf = ::metal::numeric_limits<T>::infinity().x;
|
||||
|
||||
if (minval.x == -inf || maxval.x == inf) {
|
||||
return maxval;
|
||||
}
|
||||
|
||||
float2 maxval_ = static_cast<float2>(maxval);
|
||||
float2 minval_ = static_cast<float2>(minval);
|
||||
float m = ::metal::exp(minval_.x - maxval_.x);
|
||||
float2 dexp{
|
||||
m * ::metal::cos(minval_.y - maxval_.y),
|
||||
m * ::metal::sin(minval_.y - maxval_.y),
|
||||
};
|
||||
return static_cast<T>(maxval_ + ::c10::metal::log1p(dexp));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T logaddexp2(T a, T b) {
|
||||
constexpr auto log_2 = float(0.693147180559945309417232121458176);
|
||||
constexpr auto inv_log_2 = float(1) / log_2;
|
||||
float a0 = static_cast<float>(a);
|
||||
float b0 = static_cast<float>(b);
|
||||
if (::metal::isinf(a0) && a0 == b0) {
|
||||
return static_cast<T>(a0);
|
||||
} else {
|
||||
float m0 = ::metal::max(a0, b0);
|
||||
return static_cast<T>(
|
||||
m0 +
|
||||
::c10::metal::log1p(::metal::pow(float(2), -::metal::abs(a0 - b0))) *
|
||||
inv_log_2);
|
||||
}
|
||||
}
|
||||
|
||||
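The real-valued kernels above (`logaddexp`, `logaddexp2`) implement the usual overflow-safe rearrangement; with m = max(a, b),

```latex
\log\!\left(e^{a}+e^{b}\right) = m + \log\!\left(1+e^{-|a-b|}\right),
\qquad
\log_{2}\!\left(2^{a}+2^{b}\right) = m + \frac{\log\!\left(1+2^{-|a-b|}\right)}{\ln 2},
```

which is exactly the `m0 + log1p(...)` form used in the code. The `isinf` branch covers a = b = ±∞, where the generic formula would otherwise produce NaN from ∞ − ∞.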
template <typename T>
|
||||
inline float xlog1py(T x, T y) {
|
||||
if (::metal::isnan(y)) {
|
||||
|
||||
@@ -322,6 +322,24 @@ inline float log1p(float x) {
return rc;
}

// The function is ported from mlx
inline float2 log1p(float2 in) {
float x = in.x;
float y = in.y;
float zabs = ::metal::precise::sqrt(x * x + y * y);
float theta = ::metal::atan2(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1p(r), theta};
} else {
auto z0 = ::metal::sqrt((x + 1) * (x + 1) + y * y);
return {::metal::log(z0), theta};
}
}

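For reference, with z = x + iy the function computes the standard complex identity

```latex
\log(1+z) = \tfrac{1}{2}\,\log\!\bigl((1+x)^2 + y^2\bigr) + i\,\operatorname{atan2}(y,\,1+x),
\qquad
(1+x)^2 + y^2 = 1 + \underbrace{x(2+x) + y^2}_{r}.
```

When |z| < 0.5 the real part is evaluated as ½·log1p(r), which avoids the precision loss of taking the log of a value near 1; the r == 0 branch returns x, the leading-order value of ½·log1p(r) for tiny x and y.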
template <typename T1, typename T2 = T1>
|
||||
struct pair {
|
||||
T1 first;
|
||||
|
||||
@ -239,7 +239,7 @@ struct Class2 {
|
||||
|
||||
struct mapper_call_func {
|
||||
template <class T>
|
||||
decltype(auto) operator()(T) {
|
||||
auto operator()(T) {
|
||||
return T::type::func();
|
||||
}
|
||||
};
|
||||
@ -254,7 +254,7 @@ TEST(TypeListTest, MapTypesToValues_members) {
|
||||
|
||||
struct mapper_call_nonexistent_function {
|
||||
template <class T>
|
||||
decltype(auto) operator()(T) {
|
||||
auto operator()(T) {
|
||||
return T::type::this_doesnt_exist();
|
||||
}
|
||||
};
|
||||
|
||||
@ -53,7 +53,7 @@ namespace guts {
|
||||
// member functions.
|
||||
namespace detail {
|
||||
template <class F, class Tuple, std::size_t... INDEX>
|
||||
C10_HOST_DEVICE constexpr decltype(auto) apply_impl(
|
||||
C10_HOST_DEVICE constexpr auto apply_impl(
|
||||
F&& f,
|
||||
Tuple&& t,
|
||||
std::index_sequence<INDEX...>) {
|
||||
@ -62,7 +62,7 @@ C10_HOST_DEVICE constexpr decltype(auto) apply_impl(
|
||||
} // namespace detail
|
||||
|
||||
template <class F, class Tuple>
|
||||
C10_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) {
|
||||
C10_HOST_DEVICE constexpr auto apply(F&& f, Tuple&& t) {
|
||||
return detail::apply_impl(
|
||||
std::forward<F>(f),
|
||||
std::forward<Tuple>(t),
|
||||
|
||||
@ -469,7 +469,7 @@ C10_API std::string GetExceptionString(const std::exception& e);
|
||||
|
||||
namespace c10::detail {
|
||||
template <typename... Args>
|
||||
decltype(auto) torchCheckMsgImpl(const char* /*msg*/, const Args&... args) {
|
||||
auto torchCheckMsgImpl(const char* /*msg*/, const Args&... args) {
|
||||
return ::c10::str(args...);
|
||||
}
|
||||
inline C10_API const char* torchCheckMsgImpl(const char* msg) {
|
||||
|
||||
@ -135,7 +135,7 @@ struct _str_wrapper<> final {
|
||||
|
||||
// Convert a list of string-like arguments into a single string.
|
||||
template <typename... Args>
|
||||
inline decltype(auto) str(const Args&... args) {
|
||||
inline auto str(const Args&... args) {
|
||||
return detail::_str_wrapper<
|
||||
typename detail::CanonicalizeStrTypes<Args>::type...>::call(args...);
|
||||
}
|
||||
|
||||
@ -507,7 +507,7 @@ struct map_types_to_values<typelist<Types...>> final {
|
||||
} // namespace detail
|
||||
|
||||
template <class TypeList, class Func>
|
||||
decltype(auto) map_types_to_values(Func&& func) {
|
||||
auto map_types_to_values(Func&& func) {
|
||||
return detail::map_types_to_values<TypeList>::call(std::forward<Func>(func));
|
||||
}
|
||||
|
||||
|
||||
@ -554,6 +554,17 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
double getMemoryFraction() {
|
||||
if (!set_fraction) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
return static_cast<double>(allowed_memory_maximum) /
|
||||
static_cast<double>(device_prop.global_mem_size);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction) {
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
@ -724,6 +735,11 @@ class XPUAllocator : public DeviceAllocator {
|
||||
device_allocators[device]->resetAccumulatedStats();
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction, DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
TORCH_CHECK_VALUE(
|
||||
@ -777,6 +793,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
|
||||
return allocator.recordStream(dataPtr, stream);
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
return allocator.getMemoryFraction(device);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction, DeviceIndex device) {
|
||||
return allocator.setMemoryFraction(fraction, device);
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ C10_XPU_API void raw_delete(void* ptr);

C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);

C10_XPU_API double getMemoryFraction(DeviceIndex device);

C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);

} // namespace c10::xpu::XPUCachingAllocator

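A hypothetical usage sketch of the new accessor pair (the include paths are assumptions; the signatures are the ones declared above, and the allocator is assumed to be initialized, e.g. after a first XPU allocation):

```cpp
// Cap the XPU caching allocator and read the limit back.
#include <c10/core/Device.h>
#include <c10/xpu/XPUCachingAllocator.h>  // assumed header location
#include <iostream>

int main() {
  const c10::DeviceIndex device = 0;
  // Limit the caching allocator to 80% of the device's global memory.
  c10::xpu::XPUCachingAllocator::setMemoryFraction(0.8, device);
  // Per the implementation above, returns 1.0 when no fraction has been set.
  std::cout << c10::xpu::XPUCachingAllocator::getMemoryFraction(device) << "\n";
  return 0;
}
```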
@ -1,4 +1,3 @@
|
||||
#include <c10/util/CallOnce.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/xpu/XPUFunctions.h>
|
||||
|
||||
@ -33,7 +32,6 @@ namespace {
|
||||
* one iGPU and enumerate all iGPUs on that platform.
|
||||
* 3. If neither dGPUs nor iGPUs are found, conclude that no GPUs are available.
|
||||
*/
|
||||
c10::once_flag init_flag;
|
||||
thread_local DeviceIndex curDeviceIndex = 0;
|
||||
|
||||
struct DevicePool {
|
||||
@ -149,7 +147,10 @@ inline void initGlobalDevicePoolState() {
|
||||
}
|
||||
|
||||
inline void initDevicePoolCallOnce() {
|
||||
c10::call_once(init_flag, initGlobalDevicePoolState);
|
||||
auto static init_flag [[maybe_unused]] = [] {
|
||||
initGlobalDevicePoolState();
|
||||
return true;
|
||||
}();
|
||||
}
|
||||
|
||||
void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) {
|
||||
|
||||
@ -12,7 +12,6 @@ namespace c10::xpu {
|
||||
namespace {
|
||||
|
||||
// Global stream state and constants
|
||||
c10::once_flag init_flag;
|
||||
DeviceIndex num_gpus = -1;
|
||||
constexpr int kStreamsPerPoolBits = 5;
|
||||
constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
|
||||
@ -163,7 +162,10 @@ void initDeviceStreamState(DeviceIndex device) {
|
||||
}
|
||||
|
||||
void initXPUStreamsOnce() {
|
||||
c10::call_once(init_flag, initGlobalStreamState);
|
||||
auto static init_flag [[maybe_unused]] = [] {
|
||||
initGlobalStreamState();
|
||||
return true;
|
||||
}();
|
||||
|
||||
if (current_streams) {
|
||||
return;
|
||||
|
||||
@@ -423,8 +423,10 @@ Also see {ref}`saved-tensors-hooks-doc`.

```{eval-rst}
.. autofunction:: torch.autograd.graph.get_gradient_edge
```

```{eval-rst}
.. autofunction:: torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch
```

% This module needs to be documented. Adding here in the meantime

@@ -394,10 +394,6 @@ an opaque group handle that can be given as a `group` argument to all collective
.. autofunction:: new_group
```

```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.shrink_group
```

```{eval-rst}
.. autofunction:: get_group_rank
```

@@ -2,9 +2,9 @@

## Overview

The LibTorch Stable ABI (Application Binary Interface) provides an interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases.
The LibTorch Stable ABI (Application Binary Interface) provides a limited interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases. This limited set of APIs is not intended to replace existing LibTorch, but rather to provide a stable foundation for a majority of custom extension use cases. If there is any API you would like to see added to the stable ABI, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues).

The stable ABI consists of three main components:
The limited stable ABI consists of three main components:

1. **Stable C headers** - Low-level C API implemented by libtorch (primarily `torch/csrc/inductor/aoti_torch/c/shim.h`)
2. **Header-only C++ library** - Standalone utilities implemented in only headers such that there is no dependence on libtorch (`torch/headeronly/*`)
@@ -14,8 +14,8 @@ We discuss each of these in detail

### `torch/headeronly`

This is a set of inlined C++ headers are completely decoupled from libtorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`.
The inlined C++ headers living in [`torch/headeronly`](https://github.com/pytorch/pytorch/tree/main/torch/headeronly) are completely decoupled from LibTorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`, as well as a libtorch-independent version of `TORCH_CHECK` that is `STD_TORCH_CHECK`. You can trust all APIs in the `torch::headeronly` namespace to not depend on `libtorch.so`. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt).

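A minimal sketch of what "libtorch-independent" means in practice. The symbols `torch::headeronly::ScalarType` and `STD_TORCH_CHECK` are the ones named in the paragraph above; the exact include paths below are assumptions, not taken from this diff:

```cpp
// Hypothetical example: builds against the header-only utilities alone, with no
// link dependency on libtorch.so (include paths are assumptions).
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

using torch::headeronly::ScalarType;

// Reject non-floating-point dtypes without pulling in TORCH_CHECK from libtorch.
void check_is_float(ScalarType dtype) {
  STD_TORCH_CHECK(
      dtype == ScalarType::Float || dtype == ScalarType::Double,
      "expected a floating point dtype");
}
```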
### `torch/csrc/stable`

@@ -34,8 +34,14 @@ We are continuing to improve coverage in our `torch/csrc/stable` APIs. Please fi

### Stable C headers

The stable C headers used by AOTInductor form the foundation of the stable ABI. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs.
Further, the stack-based APIs discussed below which allow the user to call the PyTorch dispatcher don't provide strong guarantees on forward and backward compatibility.
The stable C headers started by AOTInductor form the foundation of the stable ABI. Presently, the available C headers include:

- [torch/csrc/inductor/aoti_torch/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/c/shim.h): Includes C-style shim APIs for commonly used operations regarding Tensors, dtypes, CUDA, and the like.
- [torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h): Includes C-style shim APIs for ATen ops from `native_functions.yaml` (e.g. `aoti_torch_aten_new_empty`).
- [torch/csrc/inductor/aoti_torch/generated/c_shim_*.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated): Includes C-style shim APIs for specific backend kernels dispatched from `native_functions.yaml` (e.g. `aoti_torch_cuda_pad`). These APIs should only be used for the specific backend they are named after (e.g. `aoti_torch_cuda_pad` should only be used within CUDA kernels), as they opt out of the dispatcher.
- [torch/csrc/stable/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/stable/c/shim.h): We are building out more ABIs to logically live in `torch/csrc/stable/c` instead of continuing the AOTI naming that no longer makes sense for our general use case.

These headers are promised to be ABI stable across releases and adhere to a stronger backwards compatibility policy than LibTorch. Specifically, we promise not to modify them for at least 2 years after they are released. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs. Further, the stack-based APIs discussed below which allow the user to call into the PyTorch dispatcher do not provide strong guarantees on forward and backward compatibility of the underlying op that is called.

Unless absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable`
which will handle all the rough edges of the C API for the user.

@ -76,6 +76,7 @@
|
||||
:nosignatures:
|
||||
|
||||
empty_cache
|
||||
get_per_process_memory_fraction
|
||||
max_memory_allocated
|
||||
max_memory_reserved
|
||||
mem_get_info
|
||||
|
||||
@ -266,7 +266,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
|
||||
model(inp).sum().backward()
|
||||
param_names = {param_name for param_name, _ in model.named_parameters()}
|
||||
self.assertEqual(param_names, set(param_name_to_hook_count.keys()))
|
||||
for param_name, count in param_name_to_hook_count.items():
|
||||
for count in param_name_to_hook_count.values():
|
||||
self.assertEqual(count, 1)
|
||||
|
||||
|
||||
|
||||
@ -827,7 +827,7 @@ class TestFullyShardShardPlacementFnMultiProcess(FSDPTest):
|
||||
|
||||
torch.manual_seed(42 + self.rank)
|
||||
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
|
||||
for iter_idx in range(5):
|
||||
for _ in range(5):
|
||||
ref_loss = ref_model(inp).sum()
|
||||
loss = model(inp).sum()
|
||||
self.assertEqual(ref_loss, loss)
|
||||
|
||||
@ -800,6 +800,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
||||
stderr_redirects={0: stderr_redir},
|
||||
ret_vals={0: queue},
|
||||
queue_finished_reading_event=worker_finished_event_mock,
|
||||
numa_options=None,
|
||||
)
|
||||
self.assertEqual("hello_0", queue.get())
|
||||
if stdout_redir:
|
||||
|
||||
@ -514,18 +514,17 @@ class TestFSDPMiscMultiProcess(FSDPTest):
|
||||
def test_fsdp_cpu_training(self):
|
||||
"""Tests FSDP training on CPU."""
|
||||
gloo_pg = dist.new_group(backend="gloo")
|
||||
for ss in [ # noqa: F841
|
||||
for ss in [
|
||||
ShardingStrategy.NO_SHARD,
|
||||
ShardingStrategy.FULL_SHARD,
|
||||
ShardingStrategy.SHARD_GRAD_OP,
|
||||
ShardingStrategy.HYBRID_SHARD,
|
||||
ShardingStrategy._HYBRID_SHARD_ZERO2,
|
||||
]:
|
||||
torch.manual_seed(42)
|
||||
model = MyModel()
|
||||
ref_model = DDP(deepcopy(model), process_group=gloo_pg)
|
||||
model = FSDP(
|
||||
model,
|
||||
sharding_strategy=ss,
|
||||
auto_wrap_policy=always_wrap_policy,
|
||||
process_group=gloo_pg,
|
||||
device_id=torch.device("cpu"),
|
||||
|
||||
@ -1,43 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
||||
_start_time = time.time()
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ts():
|
||||
return time.time() - _start_time
|
||||
|
||||
|
||||
def configure(level=logging.INFO, force=False):
|
||||
try:
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
|
||||
force=force,
|
||||
)
|
||||
except TypeError:
|
||||
logging.basicConfig(
|
||||
level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
|
||||
)
|
||||
|
||||
|
||||
def log_test_info(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_success(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_validation(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_warning(rank, message):
|
||||
_logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_error(rank, message):
|
||||
_logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)
|
||||
@ -2,7 +2,6 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -45,53 +44,19 @@ class TestInstantiator(TestCase):
|
||||
self.assertEqual(return_type_str, "Tuple[Tensor, int, str]")
|
||||
|
||||
def test_instantiate_scripted_remote_module_template(self):
|
||||
dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
|
||||
|
||||
# Cleanup.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
for file_path in file_paths:
|
||||
file_path.unlink()
|
||||
|
||||
# Check before run.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
num_files_before = len(list(file_paths))
|
||||
self.assertEqual(num_files_before, 0)
|
||||
|
||||
generated_module = instantiator.instantiate_scriptable_remote_module_template(
|
||||
MyModuleInterface
|
||||
)
|
||||
self.assertTrue(hasattr(generated_module, "_remote_forward"))
|
||||
self.assertTrue(hasattr(generated_module, "_generated_methods"))
|
||||
|
||||
# Check after run.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
num_files_after = len(list(file_paths))
|
||||
self.assertEqual(num_files_after, 1)
|
||||
|
||||
def test_instantiate_non_scripted_remote_module_template(self):
|
||||
dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
|
||||
|
||||
# Cleanup.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
for file_path in file_paths:
|
||||
file_path.unlink()
|
||||
|
||||
# Check before run.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
num_files_before = len(list(file_paths))
|
||||
self.assertEqual(num_files_before, 0)
|
||||
|
||||
generated_module = (
|
||||
instantiator.instantiate_non_scriptable_remote_module_template()
|
||||
)
|
||||
self.assertTrue(hasattr(generated_module, "_remote_forward"))
|
||||
self.assertTrue(hasattr(generated_module, "_generated_methods"))
|
||||
|
||||
# Check after run.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
num_files_after = len(list(file_paths))
|
||||
self.assertEqual(num_files_after, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -64,6 +64,38 @@ class TestDTensorDebugMode(TestCase):
|
||||
self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall))
|
||||
self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default)
|
||||
|
||||
# check stringification
|
||||
self.assertTrue(hasattr(debug_mode.operators[0], "args_str"))
|
||||
self.assertFalse(hasattr(debug_mode.operators[0], "args"))
|
||||
|
||||
# check recording hook
|
||||
def mm(x, y):
|
||||
return (x @ y).sum()
|
||||
|
||||
eager_out = mm(x_dtensor, y_dtensor)
|
||||
|
||||
# check recording hook for compiled variant
|
||||
with (
|
||||
DebugMode() as debug_mode,
|
||||
DebugMode.record_outputs(),
|
||||
DebugMode.log_tensor_hashes(),
|
||||
):
|
||||
compiled_out = torch.compile(mm, backend="aot_eager")(x_dtensor, y_dtensor)
|
||||
|
||||
# check numerical equivalence
|
||||
self.assertTrue(torch.equal(eager_out, compiled_out))
|
||||
sum_op = next(
|
||||
iter(
|
||||
op
|
||||
for op in debug_mode.operators
|
||||
if isinstance(op, _OpCall) and str(op.op) == "aten.sum.default"
|
||||
)
|
||||
)
|
||||
self.assertTrue(torch.equal(sum_op.record["output"], eager_out.to_local()))
|
||||
self.assertTrue(
|
||||
"aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string()
|
||||
)
|
||||
|
||||
def test_debug_string_inside_context(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
@ -86,7 +118,9 @@ class TestDTensorDebugMode(TestCase):
|
||||
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
|
||||
y_dtensor = DTensor.from_local(y, mesh, [Shard(1)], run_check=False)
|
||||
|
||||
with DebugMode(record_torchfunction=True) as debug_mode:
|
||||
with DebugMode(
|
||||
record_torchfunction=True, record_stack_trace=True
|
||||
) as debug_mode:
|
||||
z = x_dtensor + y_dtensor
|
||||
z.sum().backward()
|
||||
|
||||
@ -119,6 +153,9 @@ class TestDTensorDebugMode(TestCase):
|
||||
aten::detach(t: f32[1, 8])""",
|
||||
)
|
||||
|
||||
# check stack trace
|
||||
self.assertTrue("z.sum().backward()" in debug_mode.operators[-1].stack_trace)
|
||||
|
||||
def test_debug_mode_densor_redistribution_trace(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).view(4, 2))
|
||||
|
||||
@ -267,6 +304,7 @@ class TestDTensorDebugMode(TestCase):
|
||||
record_torchfunction=True,
|
||||
record_faketensor=True,
|
||||
record_tensor_attributes=["a1", "a2"],
|
||||
store_original_args=True,
|
||||
) as debug_mode:
|
||||
torch.matmul(y, x)
|
||||
|
||||
@ -279,6 +317,9 @@ class TestDTensorDebugMode(TestCase):
|
||||
aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
|
||||
)
|
||||
|
||||
self.assertTrue(hasattr(debug_mode.operators[0], "args"))
|
||||
self.assertEqual(id(debug_mode.operators[0].args[0]), id(y))
|
||||
|
||||
@parametrize("has_inner_mode", [True, False])
|
||||
@parametrize("has_outer_mode", [True, False])
|
||||
def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):
|
||||
|
||||
@ -20,18 +20,18 @@ from torch.distributed.tensor.experimental._attention import (
|
||||
_cp_options,
|
||||
_disable_context_parallel_dispatcher,
|
||||
_enable_context_parallel_dispatcher,
|
||||
_HeadTailLoadBalancer,
|
||||
_is_causal_behavior,
|
||||
_LoadBalancer,
|
||||
_PerDocumentHeadTailLoadBalancer,
|
||||
_PTRRLoadBalancer,
|
||||
_RotateMethod,
|
||||
context_parallel,
|
||||
context_parallel_unshard,
|
||||
set_rotate_method,
|
||||
)
|
||||
from torch.distributed.tensor.experimental._cp_custom_ops import flex_cp_allgather
|
||||
from torch.distributed.tensor.experimental._load_balancer import (
|
||||
_HeadTailLoadBalancer,
|
||||
_LoadBalancer,
|
||||
_PerDocumentHeadTailLoadBalancer,
|
||||
_PTRRLoadBalancer,
|
||||
from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import (
|
||||
flex_cp_allgather,
|
||||
)
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
|
||||
@ -204,32 +204,28 @@ class DistConvolutionOpsTest(DTensorTestBase):
|
||||
self.assertTrue(b_dt.grad is not None)
|
||||
self.assertTrue(x_dt.grad is None)
|
||||
|
||||
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
|
||||
device_mesh = self.build_device_mesh()
|
||||
model_copy = copy.deepcopy(model).to(device=self.device_type)
|
||||
dist_model = distribute_module(model, device_mesh, _conv_fn)
|
||||
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
|
||||
out_dt = dist_model(arg_dt.to(device=self.device_type))
|
||||
out = model_copy(arg)
|
||||
return (out_dt.full_tensor(), out)
|
||||
|
||||
@with_comms
|
||||
def test_conv1d(self):
|
||||
device_mesh = self.build_device_mesh()
|
||||
model = nn.Conv1d(64, 64, 3, padding=1)
|
||||
model_gt = copy.deepcopy(model)
|
||||
x = torch.randn(1, 64, 8)
|
||||
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
|
||||
model_dt = distribute_module(
|
||||
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
|
||||
)
|
||||
out_dt = model_dt(x_dt)
|
||||
out = model_gt(x)
|
||||
x = torch.randn(1, 64, 8, device=self.device_type)
|
||||
out_dt, out = self._run_single_arg_fwd(model, x)
|
||||
self.assertEqual(out_dt.shape, out.shape)
|
||||
|
||||
@with_comms
|
||||
def test_conv3d(self):
|
||||
device_mesh = self.build_device_mesh()
|
||||
model = nn.Conv3d(64, 64, 3, padding=1)
|
||||
model_gt = copy.deepcopy(model).to(device=self.device_type)
|
||||
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
|
||||
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
|
||||
model_dt = distribute_module(
|
||||
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
|
||||
)
|
||||
out_dt = model_dt(x_dt)
|
||||
out = model_gt(x)
|
||||
out_dt, out = self._run_single_arg_fwd(model, x)
|
||||
self.assertEqual(out_dt.shape, out.shape)
|
||||
|
||||
|
||||
|
||||
@ -520,8 +520,6 @@ class DTensorExportTest(TestCase):
|
||||
2,
|
||||
)
|
||||
|
||||
# "Explanation: SourcelessBuilder.create does not know how to wrap <class 'types.UnionType'>"
|
||||
@unittest.expectedFailure
|
||||
def test_union_typed_annotation(self):
|
||||
def fn(leaf: torch.Tensor | DTensor):
|
||||
def nest_fn(leaf: torch.Tensor | DTensor):
|
||||
@ -535,7 +533,7 @@ class DTensorExportTest(TestCase):
|
||||
z = torch.randn(16, 16)
|
||||
gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,))
|
||||
|
||||
print(gm)
|
||||
self.assertEqual(fn(z), gm(z)[0])
|
||||
|
||||
|
||||
instantiate_parametrized_tests(DTensorExportTest)
|
||||
|
||||
@ -981,6 +981,41 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
correct = func(a, b, c, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_basic_all_reduce_bucketing(self):
|
||||
"""Test that independent all_reduce operations get bucketed together."""
|
||||
|
||||
def func(a, b, c):
|
||||
# Three independent all_reduces that should be bucketed
|
||||
ar1 = _functional_collectives.all_reduce(a, "sum", "0")
|
||||
ar2 = _functional_collectives.all_reduce(b, "sum", "0")
|
||||
ar3 = _functional_collectives.all_reduce(c, "sum", "0")
|
||||
|
||||
return ar1.sum() + ar2.sum() + ar3.sum()
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3
|
||||
|
||||
compiled = torch.compile(func)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
|
||||
|
||||
# Should see a single bucketed all_reduce
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Verify correctness
|
||||
correct = func(a, b, c)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
@ -22,7 +21,6 @@ from unittest import mock, SkipTest
|
||||
import torch
|
||||
import torch.distributed as c10d
|
||||
import torch.distributed._functional_collectives as _functional_collectives
|
||||
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT
|
||||
|
||||
|
||||
if not c10d.is_available() or not c10d.is_nccl_available():
|
||||
@ -49,15 +47,12 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
|
||||
from torch.testing._internal.common_distributed import (
|
||||
get_required_world_size,
|
||||
get_timeout,
|
||||
init_multigpu_helper,
|
||||
MultiProcessTestCase,
|
||||
requires_multicast_support,
|
||||
requires_nccl,
|
||||
requires_nccl_shrink,
|
||||
requires_nccl_version,
|
||||
requires_world_size,
|
||||
skip_if_lt_x_gpu,
|
||||
skip_if_rocm_multiprocess,
|
||||
sm_is_or_higher_than,
|
||||
@ -92,17 +87,6 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
|
||||
torch.version.cuda is not None or torch.version.hip is not None
|
||||
)
|
||||
|
||||
from logging_utils import (
|
||||
configure as _log_configure,
|
||||
log_test_info,
|
||||
log_test_success,
|
||||
log_test_validation,
|
||||
log_test_warning,
|
||||
)
|
||||
|
||||
|
||||
_log_configure(level=logging.INFO, force=True)
|
||||
|
||||
|
||||
class RendezvousEnvTest(TestCase):
|
||||
@retry_on_connect_failures
|
||||
@ -333,7 +317,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return get_required_world_size(self, 2)
|
||||
return 2
|
||||
|
||||
@property
|
||||
def rank_to_GPU(self):
|
||||
@ -1271,628 +1255,6 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
pg_2 = c10d.new_group([0, 1])
|
||||
self.assertEqual(pg_2.group_desc, "undefined")
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_basic(self):
|
||||
"""Test basic shrink_group functionality."""
|
||||
self._perform_shrink_test([1], "Basic shrink test")
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_validation(self):
|
||||
"""Test input validation in shrink_group."""
|
||||
device, pg = self._setup_shrink_test("validation")
|
||||
|
||||
        def _test_invalid_input(ranks, description, expected_exception):
            """Helper to test invalid inputs."""
            try:
                c10d.shrink_group(ranks)
                self.fail(f"Expected {expected_exception.__name__} for {description}")
            except expected_exception:
                log_test_validation(self.rank, f"✓ {description}")
            except Exception:
                if expected_exception is Exception:  # Accept any exception
                    log_test_validation(self.rank, f"✓ {description}")
                else:
                    raise

        # Test cases
        _test_invalid_input([], "Empty exclusion list", ValueError)
        if self.world_size > 1:
            _test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
            _test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)

        log_test_success(self.rank, "All validation tests passed")
        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_backend_properties(self):
        """Test that backend properties are preserved after shrinking."""

        test_name = "Backend Properties Test"
        ranks_to_exclude = [0]

        # Reuse _setup_shrink_test for complete setup (device, environment, and process group)
        device, pg = self._setup_shrink_test("backend_properties")

        # Follow _perform_shrink_test pattern from here
        log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")

        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        # Store original backend property values (not references) before shrinking
        original_timeout = None
        original_high_priority = None
        if not is_excluded:
            original_backend = pg._get_backend(device)
            original_timeout = original_backend.options._timeout
            original_high_priority = original_backend.options.is_high_priority_stream
            log_test_info(
                self.rank,
                f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
            )

        if is_excluded:
            log_test_info(
                self.rank,
                f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
            )
            dist.destroy_process_group()  # required: remaining ranks hang without it
            return

        # Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
        log_test_info(self.rank, "Non-excluded rank calling shrink_group")
        shrunk_pg = c10d.shrink_group(ranks_to_exclude)

        # Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
        expected_size = self.world_size - len(ranks_to_exclude)
        _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)

        # Add custom backend properties validation
        new_backend = shrunk_pg._get_backend(device)
        log_test_info(self.rank, "Validating backend properties are preserved")

        new_timeout = new_backend.options._timeout
        new_high_priority = new_backend.options.is_high_priority_stream

        log_test_info(
            self.rank,
            f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
        )
        self.assertEqual(
            original_timeout, new_timeout, f"{test_name}: timeout not preserved"
        )

        log_test_info(
            self.rank,
            f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
        )
        self.assertEqual(
            original_high_priority,
            new_high_priority,
            f"{test_name}: high_priority_stream not preserved",
        )

        log_test_validation(
            self.rank, f"{test_name}: Backend properties preserved successfully"
        )
        log_test_success(
            self.rank, f"{test_name} successful (shrink + backend validation)"
        )

        # Cleanup (same as _perform_shrink_test)
        dist.destroy_process_group()

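# --- Illustrative sketch (not part of the test class above) -----------------
# The tests above exercise the shrink_group API introduced in this diff. The
# sketch below shows the minimal user-facing pattern they encode, assuming an
# NCCL default group launched via torchrun, at least two GPUs, and that
# shrink_group is exported at torch.distributed as in these tests.
import torch
import torch.distributed as dist


def shrink_away_last_rank_sketch():
    """Standalone, hedged sketch: survivors exclude the last rank."""
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)
    excluded = [world_size - 1]

    if rank in excluded:
        # Excluded ranks never call shrink_group; they simply tear down.
        dist.destroy_process_group()
        return

    # Surviving ranks collectively create the smaller group and keep going.
    new_pg = dist.shrink_group(excluded)
    assert new_pg.size() == world_size - len(excluded)
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t, group=new_pg)
    dist.destroy_process_group()
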
    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_multiple_comms(self):
        """Test shrink_group with multiple communicators and subgroup invalidation."""

        device, pg = self._setup_shrink_test("multiple_comms")

        # Create subgroup [0, 1] and test shrinking it
        subgroup = c10d.new_group([0, 1])
        if self.rank <= 1:
            # Shrink subgroup: exclude rank 1
            if self.rank == 0:  # Only rank 0 remains
                shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
                self.assertEqual(shrunk_subgroup.size(), 1)
                # Test communication on shrunk subgroup
                tensor = torch.full((1,), self.rank).cuda(device)
                c10d.all_reduce(tensor, group=shrunk_subgroup)
                self.assertEqual(tensor.item(), 0)  # Only rank 0
                log_test_success(self.rank, "Subgroup shrinking successful")

        dist.barrier()  # Sync before default group test

        # Shrink default group: exclude last rank
        ranks_to_exclude = [self.world_size - 1]
        if self.rank not in ranks_to_exclude:
            shrunk_default = c10d.shrink_group(ranks_to_exclude)
            expected_size = self.world_size - 1
            self.assertEqual(shrunk_default.size(), expected_size)

            # Test collective on shrunk default group
            tensor = torch.full((1,), self.rank).cuda(device)
            c10d.all_reduce(tensor, group=shrunk_default)
            expected_sum = sum(
                range(self.world_size - 1)
            )  # 0 + 1 + ... + (world_size-2)
            self.assertEqual(tensor.item(), expected_sum)
            log_test_success(self.rank, "Default group shrinking successful")

        # Note: After shrinking default group, the old subgroup is invalid
        # due to global rank reassignment

        dist.destroy_process_group()

    def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
        """Helper method to test shrink_group with a specific flag."""
        if self.world_size < 2:
            log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
            return
        ranks_to_exclude = [rank_to_exclude]
        log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
        if flag_name == "NCCL_SHRINK_ABORT":
            log_test_info(
                self.rank,
                "ABORT flag will terminate ongoing operations before shrinking",
            )

        self._perform_shrink_test(
            ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
        )

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_flags(self):
        """Test shrink_group with different shrink flags."""
        # Test ABORT flags
        log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
        self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_nccl_config(self):
        """Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
        device, pg = self._setup_shrink_test("config")
        if self.rank == self.world_size - 1:
            # excluded rank should not call shrink_group
            dist.destroy_process_group()
            return

        # Prepare pg_options with NCCL config overrides
        # Capture parent's current backend options to ensure we can prove override vs inherit
        parent_backend = pg._get_backend(torch.device("cuda"))
        parent_hp = parent_backend.options.is_high_priority_stream
        parent_blocking = parent_backend.options.config.blocking

        # Choose overrides that differ from the parent (flip where possible)
        override_hp = not parent_hp
        if parent_blocking in (0, 1):
            override_blocking = 1 - parent_blocking
        else:
            # If undefined or unexpected, set to 1 which is a concrete value
            override_blocking = 1

        opts = c10d.ProcessGroupNCCL.Options()
        opts.is_high_priority_stream = override_hp
        opts.config.blocking = override_blocking

        shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)

        # Validate backend options propagated
        backend = shrunk_pg._get_backend(torch.device("cuda"))
        # is_high_priority_stream should exactly match our override and differ from parent
        self.assertEqual(backend.options.is_high_priority_stream, override_hp)
        self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
        # config is a struct; check representative field and difference from parent when meaningful
        self.assertEqual(backend.options.config.blocking, override_blocking)
        if parent_blocking in (0, 1):
            self.assertNotEqual(backend.options.config.blocking, parent_blocking)

        dist.destroy_process_group()

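# --- Illustrative sketch (not part of the test class above) -----------------
# test_shrink_group_nccl_config shows that shrink_group accepts pg_options and
# applies them to the new group's backend. A hedged sketch of that override
# pattern follows; it assumes the pg_options behaviour from this diff and a
# CUDA/NCCL build where ProcessGroupNCCL is available.
import torch.distributed as dist


def shrink_with_custom_nccl_options_sketch(excluded_ranks):
    """Standalone sketch: give the shrunk group its own NCCL options."""
    opts = dist.ProcessGroupNCCL.Options()
    opts.is_high_priority_stream = True  # e.g. prioritize recovery traffic
    opts.config.blocking = 1             # e.g. blocking communicator init

    # Only ranks that are NOT excluded should make this call.
    return dist.shrink_group(excluded_ranks, pg_options=opts)
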
    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_performance(self):
        """Test shrink_group performance and regression detection."""
        import time

        ranks_to_exclude = self._get_default_ranks_to_exclude()
        is_excluded = self.rank in ranks_to_exclude

        if not ranks_to_exclude:
            log_test_info(self.rank, "Skipping performance test (world_size=1)")
            return

        log_test_info(self.rank, f"Performance test with {self.world_size} processes")
        device, pg = self._setup_shrink_test("performance")

        if not is_excluded:
            log_test_info(self.rank, "Measuring shrink_group performance")
            start_time = time.time()
            shrunk_pg = c10d.shrink_group(ranks_to_exclude)
            end_time = time.time()

            elapsed_time = end_time - start_time
            log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")

            # Regression check: should complete within reasonable time
            self.assertLess(
                elapsed_time,
                30.0,
                f"shrink_group took {elapsed_time:.3f}s, possible regression",
            )

            # Test collective performance
            expected_size = self.world_size - len(ranks_to_exclude)
            self._validate_shrunk_group(shrunk_pg, expected_size, "performance")

            collective_start = time.time()
            _ = self._test_collective_on_shrunk_group(
                shrunk_pg, device, ranks_to_exclude, "performance"
            )
            collective_time = time.time() - collective_start

            log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
            log_test_success(self.rank, "Performance test passed")
        else:
            log_test_info(self.rank, "Excluded rank - waiting")

        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(4)
    def test_shrink_group_multiple_exclusions(self):
        """Test shrink_group with multiple ranks excluded at once."""
        # Scale exclusions with world size
        ranks_to_exclude = list(range(2, self.world_size, 2))  # Every other rank from 2

        self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")

    @requires_nccl_shrink()
    @requires_world_size(3)
    def test_shrink_group_multiple_iterations(self):
        """Test multiple shrink operations in sequence."""
        log_test_info(
            self.rank,
            f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
        )

        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        _ = self._create_process_group_nccl(store, self.opts(), device_id=device)

        # Track current effective world size throughout shrinking operations
        current_world_size = self.world_size
        log_test_info(self.rank, f"Initial world_size: {current_world_size}")

        # First shrinking: exclude the last rank(s)
        first_exclusion = [self.world_size - 1]
        if self.world_size >= 6:
            first_exclusion.append(
                self.world_size - 2
            )  # Exclude last two ranks for larger sizes

        log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")

        if self.rank not in first_exclusion:
            # Only non-excluded ranks should call shrink_group
            first_pg = c10d.shrink_group(first_exclusion)
            self.assertIsNotNone(first_pg)
            # IMPORTANT: Update world size after first shrinking
            current_world_size = first_pg.size()
            expected_first_size = self.world_size - len(first_exclusion)
            log_test_info(
                self.rank,
                f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
            )
            self.assertEqual(first_pg.size(), expected_first_size)

            # Second shrinking: exclude another rank from the remaining group
            # Rank indices here refer to the already-shrunk group
            if current_world_size >= 3:
                second_exclusion = [
                    current_world_size - 1
                ]  # Exclude the new "last" rank
                log_test_info(
                    self.rank,
                    f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
                )

                if self.rank not in second_exclusion:
                    # Only non-excluded ranks should call shrink_group for second iteration
                    second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
                    self.assertIsNotNone(second_pg)
                    # IMPORTANT: Update world size after second shrinking
                    final_world_size = second_pg.size()
                    expected_final_size = current_world_size - len(second_exclusion)
                    log_test_info(
                        self.rank,
                        f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
                    )
                    self.assertEqual(second_pg.size(), expected_final_size)

                    # Test collective on final group
                    tensor = torch.full((1,), self.rank).cuda(device)
                    log_test_info(
                        self.rank,
                        f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
                    )
                    c10d.all_reduce(tensor, group=second_pg)
                    log_test_info(
                        self.rank,
                        f"Final all_reduce completed, result: {tensor.item()}",
                    )

                    # Calculate expected sum of remaining ranks
                    all_excluded = set(first_exclusion + second_exclusion)
                    remaining_ranks = [
                        r for r in range(self.world_size) if r not in all_excluded
                    ]
                    expected_sum = sum(remaining_ranks)
                    log_test_info(
                        self.rank,
                        f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
                    )
                    self.assertEqual(tensor.item(), expected_sum)
                    log_test_info(self.rank, "Final verification passed")
                else:
                    log_test_info(
                        self.rank,
                        "This rank excluded in second shrinking, not calling shrink_group",
                    )
            else:
                log_test_info(
                    self.rank, "Skipping second shrinking (remaining group too small)"
                )
        else:
            log_test_info(
                self.rank,
                "This rank excluded in first shrinking, not calling shrink_group",
            )

        log_test_info(self.rank, "Destroying process group")
        dist.destroy_process_group()
        log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")

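# --- Illustrative sketch (not part of the test class above) -----------------
# test_shrink_group_multiple_iterations shrinks twice in a row. The key detail
# is that the second call passes the previously shrunk group via `group=` and
# that sizes must be re-read from the returned group. A hedged sketch of that
# chaining pattern, assuming the shrink_group API from this diff and that the
# exclusion lists are expressed as in the test above:
import torch.distributed as dist


def shrink_twice_sketch(first_exclusion, second_exclusion):
    """Standalone sketch: chain two shrink operations (surviving ranks only)."""
    first_pg = dist.shrink_group(first_exclusion)
    size_after_first = first_pg.size()  # sizes/ranks are per returned group

    # The second call operates on the already-shrunk group, not the default one.
    second_pg = dist.shrink_group(second_exclusion, group=first_pg)
    return size_after_first, second_pg.size()
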
    # Helper methods for optimized shrink group tests
    def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
        """Common setup for shrink group tests."""
        os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
        world_size = world_size or self.world_size
        store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
        device = torch.device(f"cuda:{self.rank}")
        c10d.init_process_group(
            "nccl",
            world_size=world_size,
            rank=self.rank,
            store=store,
            pg_options=self.opts(),
            device_id=device,
        )
        pg = c10d.distributed_c10d._get_default_group()

        if warmup:
            c10d.all_reduce(torch.ones(1).cuda(device), group=pg)

        return device, pg

    def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
        """Validate properties of a shrunk process group."""
        self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
        actual_size = shrunk_pg.size()
        self.assertEqual(
            actual_size, expected_size, f"{test_name}: group size mismatch"
        )

        new_rank = shrunk_pg.rank()
        self.assertTrue(
            0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
        )

        log_test_info(
            self.rank,
            f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
        )
        return new_rank

    def _test_collective_on_shrunk_group(
        self, shrunk_pg, device, ranks_to_exclude, test_name=""
    ):
        """Test collective communication on shrunk group and verify correctness."""
        test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
        c10d.all_reduce(test_tensor, group=shrunk_pg)

        result = test_tensor.item()
        expected_sum = sum(
            r for r in range(self.world_size) if r not in ranks_to_exclude
        )

        self.assertEqual(
            result, expected_sum, f"{test_name}: collective result mismatch"
        )
        log_test_info(
            self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
        )
        return result

    def _perform_shrink_test(
        self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
    ):
        """Complete shrink test flow: setup, shrink, validate, test collective, cleanup.

        Consistent API: All ranks perform setup to initialize distributed environment.
        ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
        Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
        """
        log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")

        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        # All ranks (including excluded ones) perform setup to initialize distributed environment
        device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
        is_default_group = pg == c10d.distributed_c10d._get_default_group()

        if is_excluded:
            log_test_info(
                self.rank,
                f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
            )
            if shrink_flags & NCCL_SHRINK_ABORT:
                log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
                pg._get_backend(torch.device(device)).abort()
                log_test_info(
                    self.rank, f"cleanup resources for excluded rank {self.rank}"
                )
                dist.destroy_process_group()
                log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
            else:
                log_test_info(
                    self.rank, f"Using regular destroy for excluded rank {self.rank}"
                )
                dist.destroy_process_group()
            return None

        # Only non-excluded ranks proceed with shrink
        log_test_info(
            self.rank,
            f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
        )
        shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
        log_test_info(
            self.rank,
            f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
        )

        # Non-excluded ranks: validate and test the new group
        expected_size = self.world_size - len(ranks_to_exclude)
        _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)

        if with_collective:
            _ = self._test_collective_on_shrunk_group(
                shrunk_pg, device, ranks_to_exclude, test_name
            )
            log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
        else:
            log_test_success(self.rank, f"{test_name} successful (shrink only)")

        dist.destroy_process_group()
        return shrunk_pg

    def _get_default_ranks_to_exclude(self):
        """Get default ranks to exclude based on world size."""
        if self.world_size <= 1:
            return []
        return [self.world_size - 1]  # Exclude last rank by default

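# --- Illustrative sketch (not part of the test class above) -----------------
# _perform_shrink_test encodes the intended division of labour when the
# NCCL_SHRINK_ABORT flag is used: survivors call shrink_group with the flag,
# while excluded (e.g. failed or evicted) ranks abort their NCCL backend and
# tear down so pending collectives cannot hang. A hedged sketch, assuming the
# flag and API introduced in this diff:
import torch
import torch.distributed as dist


def recover_by_shrinking_sketch(excluded_ranks, shrink_flags):
    """Standalone sketch of the survivor/excluded split used by the tests."""
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")

    if rank in excluded_ranks:
        # Excluded ranks never call shrink_group; with ABORT semantics they
        # abort their backend so in-flight operations are torn down.
        dist.distributed_c10d._get_default_group()._get_backend(device).abort()
        dist.destroy_process_group()
        return None

    # Survivors collectively build the smaller group, aborting in-flight work.
    return dist.shrink_group(excluded_ranks, shrink_flags=shrink_flags)
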
    @requires_nccl_shrink()
    @requires_world_size(3)
    def test_shrink_group_vs_abort_reinit_performance(self):
        """Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
        log_test_info(self.rank, "=== TEST 1: abort+reinit ===")

        device, pg1 = self._setup_shrink_test("_perf_reinit")
        torch.cuda.synchronize(device)

        # Test 1: Traditional abort + reinit
        start_time = time.perf_counter()
        dist.destroy_process_group()

        device, new_pg = self._setup_shrink_test("perf_shrink_test1")
        reinit_time = time.perf_counter() - start_time

        # Test collective with original rank values for fair comparison (non-blocking mode)
        test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
        work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
        work.wait()

        torch.cuda.synchronize(device)

        # Verify correctness
        expected_sum = sum(r for r in range(self.world_size))
        self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")

        log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
        dist.destroy_process_group(new_pg)

        # Test 2: shrink_group with NCCL_SHRINK_ABORT
        log_test_info(self.rank, "=== TEST 2: shrink_group ===")

        ranks_to_exclude = [self.world_size - 1]
        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        device, pg1 = self._setup_shrink_test("perf_shrink_test2")  # Unique suffix

        shrink_time = 0
        if not is_excluded:
            torch.cuda.synchronize(device)  # Ensure accurate timing
            start_time = time.perf_counter()
            shrunk_pg = c10d.shrink_group(
                ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
            )
            c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
            shrink_time = time.perf_counter() - start_time

            # Test collective communication on shrunk group (non-blocking mode)
            test_tensor = torch.full(
                (1,), self.rank, device=device, dtype=torch.float32
            )
            work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
            work.wait()

            # Verify correctness
            expected_sum = sum(
                r for r in range(self.world_size) if r not in ranks_to_exclude
            )
            self.assertEqual(
                test_tensor.item(),
                expected_sum,
                "shrink_test: collective result mismatch",
            )

            torch.cuda.synchronize(device)  # Ensure operations complete
            log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
            dist.destroy_process_group()
        else:
            log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
            dist.destroy_process_group()
            return

        # Performance analysis (only for participating ranks)
        if shrink_time > 0 and reinit_time > 0:
            speedup = reinit_time / shrink_time
            time_saved = reinit_time - shrink_time

            log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
            log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
            log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
            log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
            log_test_info(self.rank, f"speedup: {speedup:.2f}x")

            if speedup > 1.1:
                log_test_success(self.rank, "shrink_group significantly faster")
            elif speedup > 0.9:
                log_test_info(self.rank, "≈ comparable performance")
            else:
                log_test_warning(self.rank, "abort+reinit faster")

        log_test_info(self.rank, "Performance test completed")

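# --- Illustrative sketch (not part of the test class above) -----------------
# The comparison above times both recovery paths with time.perf_counter() and
# brackets the measured region with torch.cuda.synchronize(), so previously
# queued CUDA work does not leak into (or out of) the interval. A hedged
# sketch of that timing pattern:
import time

import torch


def time_gpu_region_sketch(fn, device):
    """Standalone sketch: wall-clock a callable that launches CUDA work."""
    torch.cuda.synchronize(device)   # drain prior work before starting the clock
    start = time.perf_counter()
    result = fn()
    torch.cuda.synchronize(device)   # wait for the timed work to finish
    return result, time.perf_counter() - start
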
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_deterministic_mode_no_break(self):

@ -195,7 +195,7 @@ if not TEST_WITH_DEV_DBG_ASAN:
                for i, t in enumerate(tensors):
                    self.assertEqual(t, torch.ones(5, 5, device=device) + i)
            elif self.rank == 0:
                for i, t in enumerate(tensors):
                for t in tensors:
                    zeros = torch.zeros(5, 5, device=device)
                    self.assertEqual(t, zeros)
        y = torch.sum(torch.stack(tensors), axis=0)

@ -408,6 +408,79 @@ class TestOverlapPreservingBucketing(InductorTestCase):
            "%all_gather_into_tensor_out", 1, exactly=False
        ).run(graph_str)

    def test_can_bucket_all_reduce(self):
        """
        Test that all_reduce operations CAN bucket together.

        Graph structure:
        ar1_start -> ar2_start -> mm1 (hides ar1) -> mm2 (hides ar2) -> ar1_wait -> ar2_wait
        """

        def func(a, b):
            group_name = "0"

            # Start both all_reduce operations
            ar1 = torch.ops._c10d_functional.all_reduce(a, "sum", group_name)
            ar2 = torch.ops._c10d_functional.all_reduce(b, "sum", group_name)

            # Independent compute that can hide both
            mm1 = torch.mm(a, a)
            mm2 = torch.mm(b, b)

            # Wait for both
            ar1_out = torch.ops._c10d_functional.wait_tensor(ar1)
            ar2_out = torch.ops._c10d_functional.wait_tensor(ar2)

            return ar1_out.sum() + ar2_out.sum() + mm1.sum() + mm2.sum()

        # Use fake mode to trace without executing
        with FakeTensorMode():
            a = torch.ones(4, 4, device=self.device)
            b = torch.ones(4, 4, device=self.device) * 2

            # Trace with make_fx
            traced = make_fx(func)(a, b)

        # Find nodes
        ar1, ar2 = traced.graph.find_nodes(
            op="call_function",
            target=torch.ops._c10d_functional.all_reduce.default,
        )
        mm1, mm2 = traced.graph.find_nodes(
            op="call_function", target=torch.ops.aten.mm.default
        )

        # For all_reduce, start_node == wait_node (no separate wait)
        hiding_annotations = {
            ar1: mm1,
            ar2: mm2,
        }

        # Build collective info
        collective_info = build_collective_info(traced.graph, hiding_annotations)
        node_ancestors = compute_ancestors(traced.graph)
        scheduled = OrderedSet(traced.graph.nodes)

        # Run bucketing
        from torch._inductor.fx_passes.overlap_preserving_bucketer import (
            OverlapPreservingBucketer,
        )

        bucketer = OverlapPreservingBucketer(
            traced.graph,
            collective_info,
            node_ancestors,
            scheduled,
        )
        bucketer.bucket_collectives()

        # Verify: should have 1 bucketed all_reduce
        # After bucketing, there should be only one all_reduce node (the bucketed one)
        graph_str = str(traced.graph)
        FileCheck().check_count("%all_reduce", 1, exactly=True).check_count(
            "%mm", 2
        ).run(graph_str)

    def test_can_bucket_multidtype_collectives(self):
        """
        Test that all_gathers with different dtypes CAN bucket together.

@ -1718,6 +1718,39 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
        self.assertEqual(eager_no_sq, comp_ind_no_sq)
        self.assertEqual(eager_no_sq.stride(), comp_ind_no_sq.stride())

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_unbacked_activation_specialized_in_inductor(self):
        """Test compilation with unbacked operations like nonzero."""
        torch._dynamo.reset()

        def fuzzed_program(arg_0, sentinel):
            var_node_1 = arg_0
            var_node_5 = torch.full((1, 2), -66, dtype=torch.int32)
            var_node_6 = torch.full((1, 2), 77, dtype=torch.int64)
            var_node_4 = torch.ops.aten.add(var_node_5, var_node_6)
            var_node_7 = torch.full((1, 2), -64, dtype=torch.int32)
            var_node_3 = torch.ops.aten.mul(var_node_4, var_node_7)
            var_node_9 = torch.full((3, 4), False, dtype=torch.bool)
            var_node_8 = torch.nonzero(var_node_9)
            var_node_2 = torch.ops.aten.add(var_node_3, var_node_8)
            var_node_0 = torch.ops.aten.div(var_node_1, var_node_2)
            result = var_node_0 * sentinel
            if result.is_complex():
                result = result.real
            return result

        sentinel = torch.tensor(1.0, requires_grad=True)
        arg_0 = torch.randint(0, 3, (1, 2), dtype=torch.int64)
        args = (arg_0,) + (sentinel,)

        result_original = fuzzed_program(*args)

        compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
        result_compiled = compiled_program(*args)

        self.assertTrue(torch.allclose(result_original, result_compiled))

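# --- Illustrative sketch (not part of the test above) -----------------------
# The test compiles a graph whose intermediate shape depends on data
# (torch.nonzero), i.e. an "unbacked" symbolic size. A hedged minimal example
# of compiling such a function, assuming the same dynamo config the test
# patches (capture_dynamic_output_shape_ops):
import torch


def run_unbacked_example_sketch():
    def f(x):
        idx = torch.nonzero(x)  # output length is data-dependent
        return idx.float().sum()

    # Data-dependent output shapes are only captured under fullgraph
    # compilation when this config is enabled.
    with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True):
        compiled = torch.compile(f, fullgraph=True, dynamic=True)
        out = compiled(torch.tensor([0, 1, 0, 2, 3]))
    return out
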
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

@ -41,6 +41,20 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor


def aot_eager_regional_inductor():
    """
    Regional inductor backend for AOT autograd.
    Uses regional_inductor as both forward and backward compiler.
    """
    from torch._dynamo.backends.common import aot_autograd
    from torch.fx.passes.regional_inductor import regional_inductor

    return aot_autograd(
        fw_compiler=regional_inductor,
        bw_compiler=regional_inductor,
    )

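# --- Illustrative sketch (not part of this test file) -----------------------
# aot_eager_regional_inductor() above is used later as a torch.compile
# backend: only regions wrapped in fx_traceback.annotate(...) are compiled by
# inductor, while the rest stays on the aot_eager path. A hedged usage sketch,
# assuming the helper above and torch.fx.traceback.annotate as exercised in
# the tests below:
import torch
import torch.fx.traceback as fx_traceback


def regional_inductor_usage_sketch():
    def fn(x, y):
        sin = torch.sin(x)
        with fx_traceback.annotate({"compile_with_inductor": 0}):
            mul = sin * y
            add = mul + 1
        return torch.sin(add)

    compiled = torch.compile(
        fn, backend=aot_eager_regional_inductor(), fullgraph=True
    )
    return compiled(torch.randn(10), torch.randn(10))
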
def saved_tensors_hooks_to_gm(
    pack_fn,
    unpack_fn,

@ -1898,6 +1912,171 @@ class AOTAutogradCacheTests(InductorTestCase):
        # no recompiles
        self.assertFalse(counters)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @functorch_config.patch({"bundled_autograd_cache": True})
    def test_regional_inductor_basic(self):
        """
        Basic test for regional inductor with bundled autograd cache.
        Tests that regional inductor compilation results can be cached and hit.
        """
        import torch.fx.traceback as fx_traceback

        def fn(x, y):
            sin = torch.sin(x)
            # Mark this region to be compiled with inductor
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1
            return torch.sin(add)

        x = torch.randn(10, device="cpu")
        y = torch.randn(10, device="cpu")

        # Compile with regional inductor backend
        compiled_fn = torch.compile(
            fn, backend=aot_eager_regional_inductor(), fullgraph=True
        )

        # First call should miss in cache
        result1 = compiled_fn(x, y)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Second call should hit (after clearing dynamo)
        self._clear_dynamo_and_codecache()
        result2 = compiled_fn(x, y)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Results should be the same
        self.assertEqual(result1, result2)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @functorch_config.patch({"bundled_autograd_cache": True})
    def test_regional_inductor_with_backward(self):
        """
        Test regional inductor with backward pass and bundled autograd cache.
        Note: Regional inductor triggers multiple AOT autograd compilations:
        - One for the outer graph (with regional inductor backend)
        - One for each marked region (via standalone_compile)
        """
        import torch.fx.traceback as fx_traceback

        def fn(x, y):
            sin = torch.sin(x)
            # Mark this region to be compiled with inductor
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1
            return torch.sin(add)

        x = torch.randn(10, requires_grad=True)
        y = torch.randn(10, requires_grad=True)
        x2 = x.detach().clone().requires_grad_(True)
        y2 = y.detach().clone().requires_grad_(True)

        # Compile with regional inductor backend
        compiled_fn = torch.compile(
            fn, backend=aot_eager_regional_inductor(), fullgraph=True
        )

        # First call: AOT autograd compiles the outer graph (1 miss)
        # Regional inductor then compiles the marked region (1 more miss)
        result1 = compiled_fn(x, y)
        result1.sum().backward()

        # We expect 2 cache misses: outer graph + marked region
        initial_misses = counters["aot_autograd"]["autograd_cache_miss"]
        initial_saves = counters["aot_autograd"]["autograd_cache_saved"]
        self.assertGreater(initial_misses, 0)
        self.assertGreater(initial_saves, 0)

        # Second call should hit (after clearing dynamo)
        self._clear_dynamo_and_codecache()
        result2 = compiled_fn(x2, y2)
        result2.sum().backward()

        # Should have cache hits now
        final_hits = counters["aot_autograd"]["autograd_cache_hit"]
        self.assertGreater(final_hits, 0)

        # Cache misses and saves should not increase
        self.assertEqual(
            counters["aot_autograd"]["autograd_cache_miss"], initial_misses
        )
        self.assertEqual(
            counters["aot_autograd"]["autograd_cache_saved"], initial_saves
        )

        # Results and gradients should be the same
        self.assertEqual(result1, result2)
        self.assertEqual(x.grad, x2.grad)
        self.assertEqual(y.grad, y2.grad)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @functorch_config.patch({"bundled_autograd_cache": True})
    def test_regional_inductor_cache_miss_on_change(self):
        """
        Test that changing the function causes a cache miss with regional inductor.
        Regional inductor creates multiple AOT compilations, so we track
        the change in cache misses rather than absolute counts.
        """
        import torch.fx.traceback as fx_traceback

        def fn1(x, y):
            sin = torch.sin(x)
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1
            return torch.sin(add)

        def fn2(x, y):
            sin = torch.sin(x)
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 2  # Changed from +1 to +2
            return torch.sin(add)

        x = torch.randn(10)
        y = torch.randn(10)

        # Compile first function
        compiled_fn1 = torch.compile(
            fn1, backend=aot_eager_regional_inductor(), fullgraph=True
        )
        result1 = compiled_fn1(x, y)
        first_misses = counters["aot_autograd"]["autograd_cache_miss"]
        first_saves = counters["aot_autograd"]["autograd_cache_saved"]
        self.assertGreater(first_misses, 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertGreater(first_saves, 0)

        # Compile second function (different graph)
        self._clear_dynamo_and_codecache()
        compiled_fn2 = torch.compile(
            fn2, backend=aot_eager_regional_inductor(), fullgraph=True
        )
        result2 = compiled_fn2(x, y)
        # Should miss because graph is different (more misses than before)
        self.assertGreater(
            counters["aot_autograd"]["autograd_cache_miss"], first_misses
        )
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertGreater(
            counters["aot_autograd"]["autograd_cache_saved"], first_saves
        )

        # Results should be different
        self.assertNotEqual(result1, result2)


@functorch_config.patch({"bundled_autograd_cache": True})
class AOTAutogradCacheBundledTests(AOTAutogradCacheTests):

@ -582,6 +582,23 @@ from user code:
        actual = compiled_fn(fn, *inputs)
        self.assertEqual(expected, actual)

    def test_aot_compile_with_default_args(self):
        def fn(x, y=1):
            return x + x

        compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
            ((torch.randn(3, 4),), {})
        )
        inputs = (torch.randn(3, 4),)
        expected = fn(*inputs)
        actual = compiled_fn(*inputs)
        self.assertEqual(expected, actual)
        compiled_fn.save_compiled_function(self.path())
        with open(self.path(), "rb") as f:
            compiled_fn = torch.compiler.load_compiled_function(f)
        actual = compiled_fn(*inputs)
        self.assertEqual(expected, actual)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests